Merge branch 'main' into feature/dashen/model_mapping

This commit is contained in:
2944321442@qq.com
2025-03-05 10:15:00 +08:00
289 changed files with 79473 additions and 1783 deletions
+1
View File
@@ -27,6 +27,7 @@ type AiModel struct {
Id string `json:"id"`
Config string `json:"config"`
Provider string `json:"provider"`
Type string `json:"type"`
}
type EditAPI struct {
+34
View File
@@ -4,6 +4,39 @@ import (
"github.com/eolinker/go-common/auto"
)
type ModelType string
const (
ModelTypeOnline ModelType = "online"
ModelTypeLocal ModelType = "local"
)
func (m ModelType) String() string {
return string(m)
}
func (m ModelType) Int() int {
switch m {
case ModelTypeOnline:
return 0
case ModelTypeLocal:
return 1
default:
return -1
}
}
func FromModelType(m int) ModelType {
switch m {
case 0:
return ModelTypeOnline
case 1:
return ModelTypeLocal
default:
return ""
}
}
type API struct {
Id string `json:"id"`
Name string `json:"name"`
@@ -19,6 +52,7 @@ type API struct {
type APIItem struct {
Id string `json:"id"`
Name string `json:"name"`
ModelType ModelType `json:"model_type"`
RequestPath string `json:"request_path"`
Description string `json:"description"`
Disable bool `json:"disabled"`
+12 -3
View File
@@ -54,7 +54,7 @@ func (i *imlAPIModule) getAPIDoc(ctx context.Context, serviceId string) (*openap
return openapi3Loader.LoadFromData([]byte(doc.Content))
}
func (i *imlAPIModule) updateAPIDoc(ctx context.Context, serviceId string, serviceName string, path string, summary string, description string, aiPrompt *ai_api_dto.AiPrompt) error {
func (i *imlAPIModule) updateAPIDoc(ctx context.Context, serviceId, serviceName, orgPath, path, summary, description string, aiPrompt *ai_api_dto.AiPrompt) error {
doc, err := i.getAPIDoc(ctx, serviceId)
if err != nil {
return err
@@ -64,6 +64,10 @@ func (i *imlAPIModule) updateAPIDoc(ctx context.Context, serviceId string, servi
if aiPrompt != nil {
variables = aiPrompt.Variables
}
if doc.Paths != nil {
doc.Paths.Delete(orgPath)
}
doc.AddOperation(path, http.MethodPost, genOperation(summary, description, variables))
result, err := doc.MarshalJSON()
if err != nil {
@@ -103,10 +107,11 @@ func (i *imlAPIModule) Create(ctx context.Context, serviceId string, input *ai_a
input.Id = uuid.New().String()
}
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
err = i.updateAPIDoc(ctx, serviceId, info.Name, input.Path, input.Name, input.Description, input.AiPrompt)
err = i.updateAPIDoc(ctx, serviceId, info.Name, "", input.Path, input.Name, input.Description, input.AiPrompt)
if err != nil {
return err
}
return i.aiAPIService.Create(ctx, &ai_api.Create{
ID: input.Id,
Name: input.Name,
@@ -118,6 +123,7 @@ func (i *imlAPIModule) Create(ctx context.Context, serviceId string, input *ai_a
Retry: input.Retry,
Model: input.AiModel.Id,
Provider: input.AiModel.Provider,
Type: ai_api_dto.ModelType(input.AiModel.Type).Int(),
AdditionalConfig: map[string]interface{}{
"ai_prompt": input.AiPrompt,
"ai_model": input.AiModel,
@@ -141,13 +147,14 @@ func (i *imlAPIModule) Edit(ctx context.Context, serviceId string, apiId string,
if err != nil {
return err
}
orgPath := apiInfo.Path
if input.Path != nil {
apiInfo.Path = *input.Path
}
if input.Description != nil {
apiInfo.Description = *input.Description
}
err = i.updateAPIDoc(ctx, serviceId, info.Name, apiInfo.Path, apiInfo.Name, apiInfo.Description, input.AiPrompt)
err = i.updateAPIDoc(ctx, serviceId, info.Name, orgPath, apiInfo.Path, apiInfo.Name, apiInfo.Description, input.AiPrompt)
if err != nil {
return err
}
@@ -163,6 +170,7 @@ func (i *imlAPIModule) Edit(ctx context.Context, serviceId string, apiId string,
if input.AiModel != nil {
apiInfo.AdditionalConfig["ai_model"] = input.AiModel
}
typ := ai_api_dto.ModelType(input.AiModel.Type).Int()
return i.aiAPIService.Save(ctx, apiId, &ai_api.Edit{
Name: input.Name,
Path: input.Path,
@@ -171,6 +179,7 @@ func (i *imlAPIModule) Edit(ctx context.Context, serviceId string, apiId string,
Retry: input.Retry,
Model: modelId,
Provider: providerId,
Type: &typ,
AdditionalConfig: &apiInfo.AdditionalConfig,
Disable: input.Disable,
})
+9 -11
View File
@@ -13,12 +13,7 @@ func genOpenAPI3Template(title string, description string) *openapi3.T {
Description: description,
Version: "beta",
}
//result.Tags = openapi3.Tags{
// {
// Name: title,
// Description: description,
// },
//}
result.Components = components
result.Paths = new(openapi3.Paths)
return result
@@ -26,7 +21,6 @@ func genOpenAPI3Template(title string, description string) *openapi3.T {
func genOperation(summary string, description string, variables []*ai_api_dto.AiPromptVariable) *openapi3.Operation {
operation := openapi3.NewOperation()
//operation.Parameters = genRequestParameters(variables)
operation.Summary = summary
operation.Description = description
operation.RequestBody = genRequestBody(variables)
@@ -78,9 +72,13 @@ func genResponse() *openapi3.ResponseRef {
func genRequestBodySchema(variables []*ai_api_dto.AiPromptVariable) *openapi3.Schema {
result := openapi3.NewObjectSchema()
result.WithProperty("variables", genVariableSchema(variables))
if len(variables) > 0 {
result.WithProperty("variables", genVariableSchema(variables))
result.WithRequired([]string{"variables", "messages"})
}
result.WithPropertyRef("messages", messagesSchemaRef)
result.WithRequired([]string{"variables", "messages"})
return result
}
@@ -136,10 +134,10 @@ func genMessageSchema() *openapi3.Schema {
result.Description = "Chat Message"
roleSchema := openapi3.NewStringSchema()
roleSchema.Description = "Role of the message sender"
roleSchema.Example = "assistant"
roleSchema.Example = "user"
contentSchema := openapi3.NewStringSchema()
contentSchema.Description = "The message content"
contentSchema.Example = "Hello, how can I help you?"
contentSchema.Example = "Hello, who are you?"
result.WithProperties(map[string]*openapi3.Schema{
"role": roleSchema,
"content": contentSchema,
+14
View File
@@ -0,0 +1,14 @@
package ai_balance_dto
type Create struct {
Id string `json:"id"`
Type string `json:"type"`
Provider string `json:"provider"`
Model string `json:"model"`
}
type Sort struct {
Origin string `json:"origin"`
Target string `json:"target"`
Sort string `json:"sort"`
}
+81
View File
@@ -0,0 +1,81 @@
package ai_balance_dto
const (
ModelTypeOnline = "online"
ModelTypeLocal = "local"
StateNormal = "normal"
StateAbnormal = "abnormal"
)
type ModelType string
func (m ModelType) String() string {
return string(m)
}
func (m ModelType) Int() int {
switch m {
case ModelTypeOnline:
return 0
case ModelTypeLocal:
return 1
default:
return -1
}
}
func ModelTypeFromInt(i int) ModelType {
switch i {
case 0:
return ModelTypeOnline
case 1:
return ModelTypeLocal
default:
return "unknown"
}
}
type ModelState string
func (m ModelState) String() string {
return string(m)
}
func (m ModelState) Int() int {
switch m {
case StateNormal:
return 1
case StateAbnormal:
return 0
default:
return -1
}
}
func ModelStateFromInt(i int) ModelState {
switch i {
case 1:
return StateNormal
case 0:
return StateAbnormal
default:
return "unknown"
}
}
type Item struct {
Id string `json:"id"`
Priority int `json:"priority"`
Provider *BasicItem `json:"provider"`
Model *BasicItem `json:"model"`
Type ModelType `json:"type"`
State ModelState `json:"state"`
APICount int64 `json:"api_count"`
KeyCount int64 `json:"key_count"`
}
type BasicItem struct {
Id string `json:"id"`
Name string `json:"name"`
}
+329
View File
@@ -0,0 +1,329 @@
package ai_balance
import (
"context"
"errors"
"fmt"
"sort"
"github.com/APIParkLab/APIPark/service/setting"
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
"gorm.io/gorm"
ai_key "github.com/APIParkLab/APIPark/service/ai-key"
ai_api "github.com/APIParkLab/APIPark/service/ai-api"
"github.com/google/uuid"
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
"github.com/eolinker/go-common/store"
"github.com/APIParkLab/APIPark/gateway"
"github.com/APIParkLab/APIPark/service/cluster"
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
"github.com/eolinker/eosc/log"
)
var _ IBalanceModule = (*imlBalanceModule)(nil)
type imlBalanceModule struct {
clusterService cluster.IClusterService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
aiKeyService ai_key.IKeyService `autowired:""`
balanceService ai_balance.IBalanceService `autowired:""`
settingService setting.ISettingService `autowired:""`
transaction store.ITransaction `autowired:""`
}
func (i *imlBalanceModule) SyncLocalBalances(ctx context.Context, address string) error {
releases, err := i.getLocalBalances(ctx, address)
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
}
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
has, err := i.balanceService.Exist(ctx, input.Provider, input.Model)
if err != nil {
return err
}
if has {
return fmt.Errorf("model already exists")
}
priority, err := i.balanceService.MaxPriority(ctx)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
priority = 0
}
if input.Id == "" {
input.Id = uuid.NewString()
}
providerName := ""
modelName := ""
base := ""
switch input.Type {
case ai_balance_dto.ModelTypeOnline:
p, has := model_runtime.GetProvider(input.Provider)
if !has {
return fmt.Errorf("provider not found")
}
providerName = p.Name()
modelName = input.Model
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
case ai_balance_dto.ModelTypeLocal:
input.Provider = "ollama"
providerName = "Ollama"
modelName = input.Model
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return fmt.Errorf("ollama address not found")
}
base = v
}
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
err = i.balanceService.Create(ctx, &ai_balance.Create{
Id: input.Id,
Priority: priority + 1,
Provider: input.Provider,
ProviderName: providerName,
Model: input.Model,
ModelName: modelName,
Type: ai_balance_dto.ModelType(input.Type).Int(),
})
if err != nil {
return err
}
item, err := i.balanceService.Get(ctx, input.Id)
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, base)}, true)
})
}
func newRelease(item *ai_balance.Balance, base string) *gateway.DynamicRelease {
cfg := make(map[string]interface{})
cfg["provider"] = item.Provider
cfg["model"] = item.Model
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["base"] = base
cfg["priority"] = item.Priority
return &gateway.DynamicRelease{
BasicItem: &gateway.BasicItem{
ID: item.Id,
Description: item.ModelName,
Resource: "ai-provider",
Version: item.UpdateAt.Format("20060102150405"),
MatchLabels: map[string]string{
"module": "ai-provider",
},
},
Attr: cfg,
}
}
func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort) error {
var list []*ai_balance.Balance
var err error
switch input.Sort {
case "after":
list, err = i.balanceService.SortAfter(ctx, input.Origin, input.Target)
default:
list, err = i.balanceService.SortBefore(ctx, input.Origin, input.Target)
}
if err != nil {
return err
}
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return fmt.Errorf("ollama address not found")
}
releases := make([]*gateway.DynamicRelease, 0, len(list))
for _, item := range list {
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
if err != nil {
return err
}
return nil
}
func (i *imlBalanceModule) List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error) {
list, err := i.balanceService.Search(ctx, keyword, nil, "priority asc")
if err != nil {
return nil, err
}
sort.Slice(list, func(i, j int) bool {
return list[i].Priority < list[j].Priority
})
aiAPIMap, err := i.aiAPIService.CountMapByModel(ctx, "", nil)
if err != nil {
return nil, fmt.Errorf("get ai api count error:%v", err)
}
keyMap, err := i.aiKeyService.CountMapByProvider(ctx, "", nil)
if err != nil {
return nil, fmt.Errorf("get ai key count error:%v", err)
}
result := make([]*ai_balance_dto.Item, 0, len(list))
for i, item := range list {
priority := i + 1
result = append(result, &ai_balance_dto.Item{
Id: item.Id,
Provider: &ai_balance_dto.BasicItem{
Id: item.Provider,
Name: item.ProviderName,
},
Model: &ai_balance_dto.BasicItem{
Id: item.Model,
Name: item.ModelName,
},
Priority: priority,
Type: ai_balance_dto.ModelTypeFromInt(item.Type),
State: ai_balance_dto.ModelStateFromInt(item.State),
APICount: aiAPIMap[item.Model],
KeyCount: keyMap[item.Provider],
})
}
return result, nil
}
func (i *imlBalanceModule) Delete(ctx context.Context, id string) error {
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
err := i.balanceService.Delete(ctx, id)
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,
},
},
}, false)
})
}
func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
client, err := i.clusterService.GatewayClient(ctx, clusterId)
if err != nil {
log.Errorf("get apinto client error: %v", err)
return nil
}
defer func() {
err := client.Close(ctx)
if err != nil {
log.Warn("close apinto client:", err)
}
}()
for _, releaseInfo := range releases {
dynamicClient, err := client.Dynamic(releaseInfo.Resource)
if err != nil {
return err
}
if online {
err = dynamicClient.Online(ctx, releaseInfo)
} else {
dynamicClient.Offline(ctx, releaseInfo)
}
if err != nil {
return err
}
}
return nil
}
func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
balances, err := i.balanceService.Search(ctx, "", map[string]interface{}{"provider": "ollama"}, "priority asc")
if err != nil {
return nil, err
}
if v == "" {
var has bool
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, fmt.Errorf("ollama address not found")
}
}
releases := make([]*gateway.DynamicRelease, 0, len(balances))
for _, item := range balances {
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
return releases, nil
}
func (i *imlBalanceModule) getBalances(ctx context.Context) ([]*gateway.DynamicRelease, error) {
balances, err := i.balanceService.Search(ctx, "", nil, "priority asc")
if err != nil {
return nil, err
}
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, fmt.Errorf("ollama address not found")
}
releases := make([]*gateway.DynamicRelease, 0, len(balances))
for _, item := range balances {
base := v
if item.Provider != "ollama" {
p, has := model_runtime.GetProvider(item.Provider)
if !has {
continue
}
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
}
releases = append(releases, newRelease(item, base))
}
return releases, nil
}
func (i *imlBalanceModule) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
releases, err := i.getBalances(ctx)
if err != nil {
return err
}
for _, p := range releases {
client, err := clientDriver.Dynamic(p.Resource)
if err != nil {
return err
}
err = client.Online(ctx, p)
if err != nil {
return err
}
}
return nil
}
+28
View File
@@ -0,0 +1,28 @@
package ai_balance
import (
"context"
"reflect"
"github.com/APIParkLab/APIPark/gateway"
"github.com/eolinker/go-common/autowire"
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
)
type IBalanceModule interface {
Create(ctx context.Context, input *ai_balance_dto.Create) error
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
Delete(ctx context.Context, id string) error
SyncLocalBalances(ctx context.Context, address string) error
}
func init() {
balanceModule := new(imlBalanceModule)
autowire.Auto[IBalanceModule](func() reflect.Value {
gateway.RegisterInitHandleFunc(balanceModule.initGateway)
return reflect.ValueOf(balanceModule)
})
}
+15
View File
@@ -0,0 +1,15 @@
package ai_local_dto
type Update struct {
Disable bool `json:"disable"`
}
type CancelDeploy struct {
Model string `json:"model"`
}
type DeployInput struct {
Model string `json:"model"`
Service string `json:"service"`
Team string `json:"team"`
}
+141
View File
@@ -0,0 +1,141 @@
package ai_local_dto
import "github.com/eolinker/go-common/auto"
type LocalModelState string
const (
LocalModelStateNormal LocalModelState = "normal"
LocalModelStateDisable LocalModelState = "disabled"
LocalModelStateDeployingError LocalModelState = "deploying_error"
LocalModelStateError LocalModelState = "error"
LocalModelStateDeploying LocalModelState = "deploying"
DeployStateDownload DeployState = "download"
DeployStateDeploy DeployState = "deploy"
DeployStateInitializing DeployState = "initializing"
DeployStateFinish DeployState = "finish"
DeployStateDownloadError DeployState = "download error"
DeployStateDeployError DeployState = "deploy error"
DeployStateInitializingError DeployState = "initializing error"
)
func (l LocalModelState) String() string {
return string(l)
}
func (l LocalModelState) Int() int {
switch l {
case LocalModelStateDisable:
return 0
case LocalModelStateNormal:
return 1
case LocalModelStateError:
return 2
case LocalModelStateDeploying:
return 3
case LocalModelStateDeployingError:
return 4
default:
return 0
}
}
func FromLocalModelState(state int) LocalModelState {
switch state {
case 0:
return LocalModelStateDisable
case 1:
return LocalModelStateNormal
case 2:
return LocalModelStateError
case 3:
return LocalModelStateDeploying
case 4:
return LocalModelStateDeployingError
default:
return LocalModelStateDisable
}
}
type OllamaConfig struct {
Address string `json:"address"`
}
type SimpleItem struct {
Id string `json:"id"`
Name string `json:"name"`
DefaultConfig string `json:"default_config"`
Logo string `json:"logo"`
}
type LocalModelItem struct {
Id string `json:"id"`
Name string `json:"name"`
State LocalModelState `json:"state"`
APICount int64 `json:"api_count"`
Provider string `json:"provider"`
UpdateTime auto.TimeLabel `json:"update_time"`
CanDelete bool `json:"can_delete"`
}
type LocalModelPackageItem struct {
Id string `json:"id"`
Name string `json:"name"`
Size string `json:"size"`
IsPopular bool `json:"is_popular"`
}
type DeployState string
func (d DeployState) String() string {
return string(d)
}
func (d DeployState) Int() int {
switch d {
case DeployStateDownload:
return 1
case DeployStateDeploy:
return 2
case DeployStateInitializing:
return 3
case DeployStateFinish:
return 4
case DeployStateDownloadError:
return 5
case DeployStateDeployError:
return 6
case DeployStateInitializingError:
return 7
default:
return 1
}
}
func FromDeployState(state int) DeployState {
switch state {
case 1:
return DeployStateDownload
case 2:
return DeployStateDeploy
case 3:
return DeployStateInitializing
case 4:
return DeployStateFinish
case 5:
return DeployStateDownloadError
case 6:
return DeployStateDeployError
case 7:
return DeployStateInitializingError
default:
return DeployStateDownload
}
}
type ModelInfo struct {
Current int64 `json:"current"`
Total int64 `json:"total"`
LastMessage string `json:"last_message"`
}
+635
View File
@@ -0,0 +1,635 @@
package ai_local
import (
"context"
"errors"
"fmt"
"strings"
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
"github.com/APIParkLab/APIPark/service/setting"
"github.com/APIParkLab/APIPark/service/api"
"github.com/APIParkLab/APIPark/gateway"
"github.com/eolinker/eosc/log"
"github.com/APIParkLab/APIPark/service/cluster"
"github.com/APIParkLab/APIPark/service/service"
"github.com/eolinker/go-common/auto"
ai_api "github.com/APIParkLab/APIPark/service/ai-api"
"github.com/eolinker/go-common/register"
"github.com/eolinker/go-common/server"
"github.com/eolinker/go-common/utils"
"gorm.io/gorm"
ai_local "github.com/APIParkLab/APIPark/service/ai-local"
"github.com/eolinker/go-common/store"
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
ai_local_dto "github.com/APIParkLab/APIPark/module/ai-local/dto"
)
var (
_ ILocalModelModule = (*imlLocalModel)(nil)
)
type imlLocalModel struct {
localModelService ai_local.ILocalModelService `autowired:""`
localModelPackageService ai_local.ILocalModelPackageService `autowired:""`
localModelStateService ai_local.ILocalModelInstallStateService `autowired:""`
localModelCacheService ai_local.ILocalModelCacheService `autowired:""`
balanceService ai_balance.IBalanceService `autowired:""`
clusterService cluster.IClusterService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
routerService api.IAPIService `autowired:""`
serviceService service.IServiceService `autowired:""`
settingService setting.ISettingService `autowired:""`
transaction store.ITransaction `autowired:""`
}
func (i *imlLocalModel) SyncLocalModels(ctx context.Context, address string) error {
releases, err := i.getLocalModels(ctx, address)
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
}
func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) {
list, err := i.localModelService.List(ctx)
if err != nil {
return nil, err
}
return utils.SliceToSlice(list, func(s *ai_local.LocalModel) *ai_local_dto.SimpleItem {
return &ai_local_dto.SimpleItem{
Id: s.Id,
Name: s.Name,
DefaultConfig: ai_provider_local.OllamaConfig,
Logo: ai_provider_local.OllamaSvg,
}
}, func(l *ai_local.LocalModel) bool {
if l.State != ai_local_dto.LocalModelStateNormal.Int() && l.State != ai_local_dto.LocalModelStateDisable.Int() {
return false
}
return true
}), nil
}
func (i *imlLocalModel) ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) {
info, err := i.localModelStateService.Get(ctx, model)
if err != nil {
return nil, nil, err
}
state := ai_local_dto.FromDeployState(info.State)
return &state, &ai_local_dto.ModelInfo{
Current: info.Complete,
Total: info.Total,
LastMessage: info.Msg,
}, nil
}
func (i *imlLocalModel) Search(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelItem, error) {
list, err := i.localModelService.Search(ctx, keyword, nil, "update_at desc")
if err != nil {
return nil, err
}
apiCountMap, err := i.aiAPIService.CountMapByModel(ctx, "", map[string]interface{}{
"type": 1,
})
if err != nil {
return nil, err
}
return utils.SliceToSlice(list, func(s *ai_local.LocalModel) *ai_local_dto.LocalModelItem {
count := apiCountMap[s.Id]
return &ai_local_dto.LocalModelItem{
Id: s.Id,
Name: s.Name,
State: ai_local_dto.FromLocalModelState(s.State),
APICount: count,
CanDelete: count < 1 && s.State != ai_local_dto.LocalModelStateDeploying.Int(),
UpdateTime: auto.TimeLabel(s.UpdateAt),
Provider: "ollama",
}
}), nil
}
func (i *imlLocalModel) ListCanInstall(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelPackageItem, error) {
if keyword == "" {
list, err := i.localModelPackageService.Search(ctx, keyword, nil)
if err != nil {
return nil, err
}
return utils.SliceToSlice(list, func(s *ai_local.LocalModelPackage) *ai_local_dto.LocalModelPackageItem {
return &ai_local_dto.LocalModelPackageItem{
Id: s.Id,
Name: s.Name,
Size: s.Size,
IsPopular: s.IsPopular,
}
}), nil
} else {
info, err := i.localModelPackageService.Get(ctx, keyword)
if err != nil {
return nil, err
}
result := make([]*ai_local_dto.LocalModelPackageItem, 0)
//for _, v := range list {
models := ai_provider_local.ModelsCanInstallById(info.Id)
for _, model := range models {
result = append(result, &ai_local_dto.LocalModelPackageItem{
Id: model.Id,
Name: model.Name,
Size: model.Size,
IsPopular: model.IsPopular,
})
}
//}
return result, nil
}
}
func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local.PullMessage) error {
return func(msg ai_provider_local.PullMessage) error {
return i.transaction.Transaction(context.Background(), func(ctx context.Context) error {
state := ai_local_dto.DeployStateFinish.Int()
modelState := ai_local_dto.LocalModelStateNormal.Int()
if msg.Status == "error" {
state = ai_local_dto.DeployStateDownloadError.Int()
modelState = ai_local_dto.LocalModelStateDeployingError.Int()
}
err := i.localModelService.Save(ctx, msg.Model, &ai_local.EditLocalModel{State: &modelState})
if err != nil {
return err
}
info, err := i.localModelStateService.Get(ctx, msg.Model)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
err = i.localModelStateService.Create(ctx, &ai_local.CreateLocalModelInstallState{
Id: msg.Model,
Complete: msg.Completed,
Total: msg.Total,
State: state,
Msg: msg.Msg,
})
if err != nil {
return err
}
info, err = i.localModelStateService.Get(ctx, msg.Model)
if err != nil {
return err
}
} else {
if info.Complete < msg.Completed {
info.Complete = msg.Completed
}
if info.Total < msg.Total {
info.Total = msg.Total
}
if msg.Msg != "" {
info.Msg = msg.Msg
}
err = i.localModelStateService.Save(ctx, msg.Model, &ai_local.EditLocalModelInstallState{State: &state, Complete: &info.Complete, Total: &info.Total, Msg: &info.Msg})
if err != nil {
return err
}
}
serviceState := 0
if msg.Status == "error" {
state = 2
}
list, err := i.localModelCacheService.List(ctx, msg.Model, ai_local.CacheTypeService)
if err != nil {
return err
}
for _, l := range list {
serviceInfo, err := i.serviceService.Get(ctx, l.Target)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
continue
}
return err
}
if serviceInfo.State == serviceState {
continue
}
err = i.serviceService.Save(ctx, l.Target, &service.Edit{State: &serviceState})
if err != nil {
return err
}
}
if state == ai_local_dto.DeployStateFinish.Int() {
for _, f := range fn {
err = f()
if err != nil {
return err
}
}
v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address")
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = msg.Model
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["priority"] = 0
cfg["base"] = v
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: msg.Model,
Description: msg.Model,
Resource: "ai-provider",
Version: info.UpdateAt.Format("20060102150405"),
MatchLabels: map[string]string{
"module": "ai-provider",
},
},
Attr: cfg,
}}, true)
}
return nil
})
}
}
func (i *imlLocalModel) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
client, err := i.clusterService.GatewayClient(ctx, clusterId)
if err != nil {
log.Errorf("get apinto client error: %v", err)
return nil
}
defer func() {
err := client.Close(ctx)
if err != nil {
log.Warn("close apinto client:", err)
}
}()
for _, releaseInfo := range releases {
dynamicClient, err := client.Dynamic(releaseInfo.Resource)
if err != nil {
return err
}
if online {
err = dynamicClient.Online(ctx, releaseInfo)
} else {
err = dynamicClient.Offline(ctx, releaseInfo)
}
if err != nil {
return err
}
}
return nil
}
func (i *imlLocalModel) Deploy(ctx context.Context, model string, session string, fn ...func() error) (*ai_provider_local.Pipeline, error) {
var p *ai_provider_local.Pipeline
err := i.transaction.Transaction(ctx, func(txCtx context.Context) error {
item, err := i.localModelCacheService.GetByTarget(ctx, ai_local.CacheTypeService, model)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
} else {
model = item.Model
}
info, err := i.localModelService.Get(ctx, model)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
err = i.localModelService.Create(ctx, &ai_local.CreateLocalModel{
Id: model,
Name: model,
Provider: "ollama",
State: ai_local_dto.LocalModelStateDeploying.Int(),
})
} else {
if info.State == ai_local_dto.LocalModelStateDeployingError.Int() {
state := ai_local_dto.LocalModelStateDeploying.Int()
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &state})
}
}
if err != nil {
return err
}
p, err = ai_provider_local.PullModel(model, session, i.pullHook(fn...))
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return p, nil
}
func (i *imlLocalModel) SaveCache(ctx context.Context, model string, target string) error {
return i.localModelCacheService.Save(ctx, model, ai_local.CacheTypeService, target)
}
func (i *imlLocalModel) CancelDeploy(ctx context.Context, model string) error {
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
item, err := i.localModelCacheService.GetByTarget(ctx, ai_local.CacheTypeService, model)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
} else {
model = item.Model
}
list, err := i.localModelCacheService.List(ctx, model, ai_local.CacheTypeService)
if err != nil {
return err
}
for _, l := range list {
info, err := i.serviceService.Get(ctx, l.Target)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
continue
}
return err
}
if info.State == 0 {
continue
}
err = i.serviceService.Delete(ctx, info.Id)
if err != nil {
return err
}
err = i.aiAPIService.DeleteByService(ctx, info.Id)
if err != nil {
return err
}
err = i.routerService.DeleteByService(ctx, info.Id)
if err != nil {
return err
}
}
err = i.localModelCacheService.Delete(ctx, model)
if err != nil {
return err
}
// 删除模型
err = i.localModelService.Delete(ctx, model)
if err != nil {
return err
}
ai_provider_local.StopPull(model)
return nil
})
}
func (i *imlLocalModel) RemoveModel(ctx context.Context, model string) error {
// 判断是否有api
count, err := i.aiAPIService.CountByModel(ctx, model)
if err != nil {
return err
}
if count > 0 {
return fmt.Errorf("model %s has api, can not remove", model)
}
info, err := i.localModelService.Get(ctx, model)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
return ai_provider_local.RemoveModel(model)
}
if info.State == ai_local_dto.LocalModelStateDeploying.Int() {
return fmt.Errorf("model %s is deploying, can not remove", model)
}
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
err = i.localModelService.Delete(ctx, model)
if err != nil {
return err
}
return ai_provider_local.RemoveModel(model)
})
}
func (i *imlLocalModel) Enable(ctx context.Context, model string) error {
info, err := i.localModelService.Get(ctx, model)
if err != nil {
return err
}
if info.State == ai_local_dto.LocalModelStateDisable.Int() || info.State == ai_local_dto.LocalModelStateError.Int() {
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
status := ai_local_dto.LocalModelStateNormal.Int()
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status})
if err != nil {
return err
}
v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address")
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = info.Id
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["priority"] = 0
cfg["base"] = v
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: info.Id,
Description: info.Id,
Resource: "ai-provider",
Version: info.UpdateAt.Format("20060102150405"),
MatchLabels: map[string]string{
"module": "ai-provider",
},
},
Attr: cfg,
}}, true)
})
}
return fmt.Errorf("model %s is not disabled state,can not enable", model)
}
func (i *imlLocalModel) Disable(ctx context.Context, model string) error {
info, err := i.localModelService.Get(ctx, model)
if err != nil {
return err
}
if info.State == ai_local_dto.LocalModelStateNormal.Int() {
disable := ai_local_dto.LocalModelStateDisable.Int()
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable})
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: info.Id,
Resource: "ai-provider",
},
}}, false)
})
}
return fmt.Errorf("model %s is not enabled state,can not disable", model)
}
func (i *imlLocalModel) OnInit() {
register.Handle(func(v server.Server) {
ctx := context.Background()
list, err := i.localModelPackageService.List(ctx)
if err != nil {
return
}
oldModels := utils.SliceToMapO(list, func(s *ai_local.LocalModelPackage) (string, *ai_local.LocalModelPackage) {
return s.Id, s
})
models, version := ai_provider_local.ModelsCanInstall()
for _, model := range models {
delete(oldModels, model.Id)
if v, ok := oldModels[model.Id]; ok {
if v.Version == version {
continue
}
err = i.localModelPackageService.Save(ctx, model.Id, &ai_local.EditLocalModelPackage{
Size: &model.Size,
Hash: &model.Digest,
Description: &model.Description,
Version: &version,
Popular: &model.IsPopular,
})
if err != nil {
return
}
} else {
err = i.localModelPackageService.Create(ctx, &ai_local.CreateLocalModelPackage{
Id: model.Id,
Name: model.Name,
Size: model.Size,
Hash: model.Digest,
Description: model.Description,
Version: version,
Popular: model.IsPopular,
})
if err != nil {
return
}
}
}
for id := range oldModels {
err = i.localModelPackageService.Delete(ctx, id)
if err != nil {
return
}
}
installModels, err := ai_provider_local.ModelsInstalled()
if err != nil {
return
}
for _, model := range installModels {
id := strings.TrimSuffix(model.Name, ":latest")
name := strings.TrimSuffix(model.Name, ":latest")
_, err = i.localModelService.Get(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return
}
err = i.localModelService.Create(ctx, &ai_local.CreateLocalModel{
Id: id,
Name: name,
State: 1,
})
if err != nil {
return
}
}
}
})
}
func (i *imlLocalModel) getLocalModels(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
list, err := i.localModelService.List(ctx)
if err != nil {
return nil, err
}
if v == "" {
var has bool
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
if !has {
return nil, errors.New("ollama_address not set")
}
}
releases := make([]*gateway.DynamicRelease, 0, len(list))
for _, l := range list {
if l.State != ai_local_dto.LocalModelStateNormal.Int() {
continue
}
cfg := make(map[string]interface{})
cfg["provider"] = "ollama"
cfg["model"] = l.Id
cfg["model_config"] = ai_provider_local.OllamaConfig
cfg["base"] = v
releases = append(releases, &gateway.DynamicRelease{
BasicItem: &gateway.BasicItem{
ID: l.Id,
Description: l.Name,
Resource: "ai-provider",
Version: l.UpdateAt.Format("20060102150405"),
MatchLabels: map[string]string{
"module": "ai-provider",
},
},
Attr: cfg,
})
}
return releases, nil
}
func (i *imlLocalModel) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
releases, err := i.getLocalModels(ctx, "")
if err != nil {
return err
}
for _, p := range releases {
client, err := clientDriver.Dynamic(p.Resource)
if err != nil {
return err
}
err = client.Online(ctx, p)
if err != nil {
return err
}
}
return nil
}
+37
View File
@@ -0,0 +1,37 @@
package ai_local
import (
"context"
"reflect"
"github.com/APIParkLab/APIPark/gateway"
"github.com/eolinker/go-common/autowire"
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
ai_local_dto "github.com/APIParkLab/APIPark/module/ai-local/dto"
)
type ILocalModelModule interface {
Search(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelItem, error)
ListCanInstall(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelPackageItem, error)
Deploy(ctx context.Context, model string, session string, fn ...func() error) (*ai_provider_local.Pipeline, error)
CancelDeploy(ctx context.Context, model string) error
RemoveModel(ctx context.Context, model string) error
Enable(ctx context.Context, model string) error
Disable(ctx context.Context, model string) error
ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error)
SaveCache(ctx context.Context, model string, target string) error
SyncLocalModels(ctx context.Context, address string) error
}
func init() {
localModel := new(imlLocalModel)
autowire.Auto[ILocalModelModule](func() reflect.Value {
gateway.RegisterInitHandleFunc(localModel.initGateway)
return reflect.ValueOf(localModel)
})
}
+13 -13
View File
@@ -13,15 +13,15 @@ type SimpleProvider struct {
}
type Provider struct {
Id string `json:"id"`
Name string `json:"name"`
Config string `json:"config"`
GetAPIKeyUrl string `json:"get_apikey_url"`
DefaultLLM string `json:"default_llm"`
DefaultLLMConfig string `json:"-"`
Priority int `json:"priority"`
Status ProviderStatus `json:"status"`
Configured bool `json:"configured"`
Id string `json:"id"`
Name string `json:"name"`
Config string `json:"config"`
GetAPIKeyUrl string `json:"get_apikey_url"`
DefaultLLM string `json:"default_llm"`
DefaultLLMConfig string `json:"-"`
//Priority int `json:"priority"`
Status ProviderStatus `json:"status"`
Configured bool `json:"configured"`
}
type ConfiguredProviderItem struct {
@@ -31,9 +31,8 @@ type ConfiguredProviderItem struct {
DefaultLLM string `json:"default_llm"`
Status ProviderStatus `json:"status"`
APICount int64 `json:"api_count"`
KeyCount int `json:"key_count"`
KeyStatus []*KeyStatus `json:"keys"`
Priority int `json:"priority"`
KeyCount int64 `json:"key_count"`
CanDelete bool `json:"can_delete"`
}
type KeyStatus struct {
@@ -59,13 +58,14 @@ type SimpleProviderItem struct {
DefaultConfig string `json:"default_config"`
Status ProviderStatus `json:"status"`
Model *BasicInfo `json:"model,omitempty"`
Priority int `json:"-"`
Type string `json:"type"`
}
type BackupProvider struct {
Id string `json:"id"`
Name string `json:"name"`
Model *BasicInfo `json:"model,omitempty"`
Type string `json:"type"`
}
type LLMItem struct {
+241 -229
View File
@@ -8,9 +8,18 @@ import (
"sort"
"time"
"github.com/APIParkLab/APIPark/service/service"
"github.com/google/uuid"
ai_key_dto "github.com/APIParkLab/APIPark/module/ai-key/dto"
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
"github.com/eolinker/go-common/register"
"github.com/eolinker/go-common/server"
ai_local "github.com/APIParkLab/APIPark/service/ai-local"
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
"github.com/APIParkLab/APIPark/service/service"
ai_key "github.com/APIParkLab/APIPark/service/ai-key"
@@ -54,11 +63,105 @@ func newKey(key *ai_key.Key) *gateway.DynamicRelease {
var _ IProviderModule = (*imlProviderModule)(nil)
type imlProviderModule struct {
providerService ai.IProviderService `autowired:""`
clusterService cluster.IClusterService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
aiKeyService ai_key.IKeyService `autowired:""`
transaction store.ITransaction `autowired:""`
providerService ai.IProviderService `autowired:""`
clusterService cluster.IClusterService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
aiKeyService ai_key.IKeyService `autowired:""`
aiBalanceService ai_balance.IBalanceService `autowired:""`
transaction store.ITransaction `autowired:""`
}
func (i *imlProviderModule) OnInit() {
register.Handle(func(v server.Server) {
ctx := context.Background()
list, err := i.providerService.List(ctx)
if err != nil {
return
}
i.transaction.Transaction(ctx, func(ctx context.Context) error {
for _, l := range list {
if l.Priority < 1 {
continue
}
has, err := i.aiBalanceService.Exist(ctx, l.Id, l.DefaultLLM)
if err != nil {
return err
}
if has {
continue
}
p, has := model_runtime.GetProvider(l.Id)
if !has {
continue
}
err = i.aiBalanceService.Create(ctx, &ai_balance.Create{
Id: uuid.NewString(),
Priority: l.Priority,
Provider: l.Id,
ProviderName: p.Name(),
Model: l.DefaultLLM,
ModelName: l.DefaultLLM,
Type: 0,
})
if err != nil {
return err
}
priority := 0
err = i.providerService.Save(ctx, l.Id, &ai.SetProvider{
Priority: &priority,
})
if err != nil {
return err
}
}
return nil
})
})
}
func (i *imlProviderModule) Delete(ctx context.Context, id string) error {
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
// 判断是否有api
count, err := i.aiAPIService.CountByProvider(ctx, id)
if err != nil {
return err
}
if count > 0 {
return fmt.Errorf("provider has api")
}
keys, err := i.aiKeyService.KeysByProvider(ctx, id)
if err != nil {
return err
}
err = i.aiKeyService.DeleteByProvider(ctx, id)
if err != nil {
return err
}
err = i.providerService.Delete(ctx, id)
if err != nil {
return err
}
releases := make([]*gateway.DynamicRelease, 0, len(keys))
for _, key := range keys {
releases = append(releases, newKey(key))
}
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, false)
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,
Resource: "ai-provider",
},
},
}, false)
})
}
func (i *imlProviderModule) SimpleProvider(ctx context.Context, id string) (*ai_dto.SimpleProvider, error) {
@@ -75,83 +178,19 @@ func (i *imlProviderModule) SimpleProvider(ctx context.Context, id string) (*ai_
}, nil
}
func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error {
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
list, err := i.providerService.List(ctx)
if err != nil {
return err
}
providerMap := utils.SliceToMap(list, func(e *ai.Provider) string {
return e.Id
})
releases := make([]*gateway.DynamicRelease, 0, len(list))
offlineReleases := make([]*gateway.DynamicRelease, 0, len(list))
for index, id := range input.Providers {
p, has := model_runtime.GetProvider(id)
if !has {
continue
}
l, has := providerMap[id]
if !has {
continue
}
model, has := p.GetModel(l.DefaultLLM)
if !has {
continue
}
priority := index + 1
err = i.providerService.Save(txCtx, id, &ai.SetProvider{
Priority: &priority,
})
if err != nil {
return err
}
if ai_dto.ToProviderStatus(l.Status) == ai_dto.ProviderDisabled {
offlineReleases = append(offlineReleases, &gateway.DynamicRelease{
BasicItem: &gateway.BasicItem{
ID: l.Id,
Resource: "ai-provider",
}})
} else {
cfg := make(map[string]interface{})
cfg["provider"] = l.Id
cfg["model"] = l.DefaultLLM
cfg["model_config"] = model.DefaultConfig()
cfg["priority"] = l.Priority
cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
releases = append(releases, &gateway.DynamicRelease{
BasicItem: &gateway.BasicItem{
ID: l.Id,
Description: l.Name,
Resource: "ai-provider",
Version: l.UpdateAt.Format("20060102150405"),
MatchLabels: map[string]string{
"module": "ai-provider",
},
},
Attr: cfg,
})
}
}
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
if err != nil {
return err
}
return i.syncGateway(ctx, cluster.DefaultClusterID, offlineReleases, false)
})
}
func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.ConfiguredProviderItem, *ai_dto.BackupProvider, error) {
func (i *imlProviderModule) ConfiguredProviders(ctx context.Context, keyword string) ([]*ai_dto.ConfiguredProviderItem, error) {
// 获取已配置的AI服务商
list, err := i.providerService.List(ctx)
list, err := i.providerService.Search(ctx, keyword, nil, "update_at")
if err != nil {
return nil, nil, fmt.Errorf("get provider list error:%v", err)
return nil, fmt.Errorf("get provider list error:%v", err)
}
aiAPIMap, err := i.aiAPIService.CountMapByProvider(ctx, "", nil)
if err != nil {
return nil, nil, fmt.Errorf("get ai api count error:%v", err)
return nil, fmt.Errorf("get ai api count error:%v", err)
}
keyMap, err := i.aiKeyService.CountMapByProvider(ctx, "", nil)
if err != nil {
return nil, fmt.Errorf("get ai key count error:%v", err)
}
providers := make([]*ai_dto.ConfiguredProviderItem, 0, len(list))
for _, l := range list {
@@ -159,7 +198,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
_, err = i.aiKeyService.DefaultKey(ctx, l.Id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, err
return nil, err
}
err = i.aiKeyService.Create(ctx, &ai_key.Create{
ID: l.Id,
@@ -173,7 +212,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
Default: true,
})
if err != nil {
return nil, nil, fmt.Errorf("create default key error:%v", err)
return nil, fmt.Errorf("create default key error:%v", err)
}
}
@@ -181,29 +220,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
if !has {
continue
}
keys, err := i.aiKeyService.KeysByProvider(ctx, l.Id)
if err != nil {
return nil, nil, fmt.Errorf("get provider keys error:%v", err)
}
keysStatus := make([]*ai_dto.KeyStatus, 0, len(keys))
for _, k := range keys {
status := ai_key_dto.ToKeyStatus(k.Status)
switch status {
case ai_key_dto.KeyNormal, ai_key_dto.KeyDisable, ai_key_dto.KeyError:
default:
status = ai_key_dto.KeyError
}
keysStatus = append(keysStatus, &ai_dto.KeyStatus{
Id: k.ID,
Name: k.Name,
Status: status.String(),
Priority: k.Priority,
})
}
sort.Slice(keysStatus, func(i, j int) bool {
return keysStatus[i].Priority < keysStatus[j].Priority
})
apiCount := aiAPIMap[l.Id]
providers = append(providers, &ai_dto.ConfiguredProviderItem{
Id: l.Id,
@@ -211,35 +228,13 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
Logo: p.Logo(),
DefaultLLM: l.DefaultLLM,
Status: ai_dto.ToProviderStatus(l.Status),
APICount: aiAPIMap[l.Id],
KeyCount: len(keysStatus),
KeyStatus: keysStatus,
Priority: l.Priority,
APICount: apiCount,
KeyCount: keyMap[l.Id],
CanDelete: apiCount < 1,
})
}
sort.Slice(providers, func(i, j int) bool {
if providers[i].Priority != providers[j].Priority {
if providers[i].Priority == 0 {
return false
}
if providers[j].Priority == 0 {
return true
}
return providers[i].Priority < providers[j].Priority
}
return providers[i].Name < providers[j].Name
})
var backup *ai_dto.BackupProvider
for _, p := range providers {
if p.Status == ai_dto.ProviderEnabled {
backup = &ai_dto.BackupProvider{
Id: p.Id,
Name: p.Name,
}
break
}
}
return providers, backup, nil
return providers, nil
}
func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.SimpleProviderItem, error) {
@@ -252,6 +247,7 @@ func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.Simp
providerMap := utils.SliceToMap(list, func(e *ai.Provider) string {
return e.Id
})
items := make([]*ai_dto.SimpleProviderItem, 0, len(providers))
for _, v := range providers {
item := &ai_dto.SimpleProviderItem{
@@ -264,31 +260,35 @@ func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.Simp
if info, has := providerMap[v.ID()]; has {
item.Configured = true
item.Status = ai_dto.ToProviderStatus(info.Status)
item.Priority = info.Priority
}
items = append(items, item)
}
sort.Slice(items, func(i, j int) bool {
if items[i].Priority != items[j].Priority {
if items[i].Priority == 0 {
return false
}
if items[j].Priority == 0 {
return true
}
return items[i].Priority < items[j].Priority
}
return items[i].Name < items[j].Name
})
return items, nil
}
func (i *imlProviderModule) SimpleConfiguredProviders(ctx context.Context) ([]*ai_dto.SimpleProviderItem, *ai_dto.BackupProvider, error) {
func (i *imlProviderModule) SimpleConfiguredProviders(ctx context.Context, all bool) ([]*ai_dto.SimpleProviderItem, *ai_dto.BackupProvider, error) {
list, err := i.providerService.List(ctx)
if err != nil {
return nil, nil, err
}
items := make([]*ai_dto.SimpleProviderItem, 0, len(list))
healthProvider := make(map[string]struct{})
if all {
healthProvider["ollama"] = struct{}{}
items = append(items, &ai_dto.SimpleProviderItem{
Id: "ollama",
Name: "Ollama",
Logo: ai_provider_local.OllamaSvg,
Configured: true,
DefaultConfig: "",
Status: ai_dto.ProviderEnabled,
Type: "local",
})
}
var backup *ai_dto.BackupProvider
for _, l := range list {
p, has := model_runtime.GetProvider(l.Id)
@@ -308,34 +308,32 @@ func (i *imlProviderModule) SimpleConfiguredProviders(ctx context.Context) ([]*a
Logo: p.Logo(),
DefaultConfig: p.DefaultConfig(),
Status: ai_dto.ToProviderStatus(l.Status),
Priority: l.Priority,
Configured: true,
Model: &ai_dto.BasicInfo{
Id: model.ID(),
Name: model.ID(),
},
}
if item.Status == ai_dto.ProviderEnabled {
healthProvider[l.Id] = struct{}{}
}
items = append(items, item)
}
sort.Slice(items, func(i, j int) bool {
if items[i].Priority != items[j].Priority {
if items[i].Priority == 0 {
return false
}
if items[j].Priority == 0 {
return true
}
return items[i].Priority < items[j].Priority
}
return items[i].Name < items[j].Name
})
for _, item := range items {
if item.Status == ai_dto.ProviderEnabled {
aiBalanceItems, err := i.aiBalanceService.Search(ctx, "", nil, "priority asc")
if err != nil {
return nil, nil, err
}
for _, item := range aiBalanceItems {
if _, has := healthProvider[item.Provider]; has {
backup = &ai_dto.BackupProvider{
Id: item.Id,
Name: item.Name,
Model: item.Model,
Id: item.Provider,
Name: item.Provider,
Model: &ai_dto.BasicInfo{
Id: item.Model,
Name: item.Model,
},
Type: "local",
}
break
}
@@ -388,13 +386,7 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
if !has {
return nil, fmt.Errorf("ai provider not found")
}
maxPriority, err := i.providerService.MaxPriority(ctx)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
}
maxPriority = maxPriority + 1
info, err := i.providerService.Get(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
@@ -412,7 +404,7 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
DefaultLLM: defaultLLM.ID(),
DefaultLLMConfig: defaultLLM.Logo(),
Status: ai_dto.ProviderDisabled,
Priority: maxPriority,
//Priority: maxPriority,
}, nil
}
defaultLLM, has := p.GetModel(info.DefaultLLM)
@@ -423,9 +415,6 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
}
defaultLLM = model
}
if info.Priority == 0 {
info.Priority = maxPriority
}
return &ai_dto.Provider{
Id: info.Id,
@@ -434,9 +423,9 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
GetAPIKeyUrl: p.HelpUrl(),
DefaultLLM: defaultLLM.ID(),
DefaultLLMConfig: defaultLLM.DefaultConfig(),
Priority: info.Priority,
Status: ai_dto.ToProviderStatus(info.Status),
Configured: true,
//Priority: info.Priority,
Status: ai_dto.ToProviderStatus(info.Status),
Configured: true,
}, nil
}
@@ -492,38 +481,48 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
if !has {
return fmt.Errorf("ai provider not found")
}
info, err := i.providerService.Get(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
info, err := i.providerService.Get(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if input.DefaultLLM == "" {
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
if !has {
return fmt.Errorf("ai provider default llm not found")
}
input.DefaultLLM = defaultLLM.ID()
}
info = &ai.Provider{
Id: id,
Name: p.Name(),
DefaultLLM: input.DefaultLLM,
Config: input.Config,
}
err = i.providerService.Create(ctx, &ai.CreateProvider{
Id: info.Id,
Name: info.Name,
DefaultLLM: input.DefaultLLM,
Config: input.Config,
})
if err != nil {
return err
}
}
model, has := p.GetModel(input.DefaultLLM)
if !has {
return fmt.Errorf("ai provider model not found")
}
err = p.Check(input.Config)
if err != nil {
return err
}
if input.DefaultLLM == "" {
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
if !has {
return fmt.Errorf("ai provider default llm not found")
}
input.DefaultLLM = defaultLLM.ID()
input.Config, err = p.GenConfig(input.Config, info.Config)
if err != nil {
return err
}
info = &ai.Provider{
Id: id,
Name: p.Name(),
DefaultLLM: input.DefaultLLM,
Config: input.Config,
}
}
model, has := p.GetModel(input.DefaultLLM)
if !has {
return fmt.Errorf("ai provider model not found")
}
err = p.Check(input.Config)
if err != nil {
return err
}
input.Config, err = p.GenConfig(input.Config, info.Config)
if err != nil {
return err
}
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
status := 0
if input.Enable != nil && *input.Enable {
status = 1
@@ -532,15 +531,14 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
Name: &info.Name,
DefaultLLM: &input.DefaultLLM,
Config: &input.Config,
Priority: input.Priority,
Status: &status,
}
_, err = i.aiKeyService.DefaultKey(txCtx, id)
_, err = i.aiKeyService.DefaultKey(ctx, id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
err = i.aiKeyService.Create(txCtx, &ai_key.Create{
err = i.aiKeyService.Create(ctx, &ai_key.Create{
ID: id,
Name: info.Name,
Config: input.Config,
@@ -551,7 +549,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
Priority: 1,
})
} else {
err = i.aiKeyService.Save(txCtx, id, &ai_key.Edit{
err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{
Config: &input.Config,
Status: &status,
})
@@ -559,13 +557,13 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
if err != nil {
return err
}
err = i.providerService.Save(txCtx, id, pInfo)
err = i.providerService.Save(ctx, id, pInfo)
if err != nil {
return err
}
if *pInfo.Status == 0 {
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,
@@ -575,7 +573,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
}, false)
}
// 获取当前供应商默认Key信息
defaultKey, err := i.aiKeyService.DefaultKey(txCtx, id)
defaultKey, err := i.aiKeyService.DefaultKey(ctx, id)
if err != nil {
return err
}
@@ -583,9 +581,8 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
cfg["provider"] = info.Id
cfg["model"] = info.DefaultLLM
cfg["model_config"] = model.DefaultConfig()
cfg["priority"] = info.Priority
cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: id,
@@ -624,7 +621,6 @@ func (i *imlProviderModule) getAiProviders(ctx context.Context) ([]*gateway.Dyna
cfg["provider"] = l.Id
cfg["model"] = l.DefaultLLM
cfg["model_config"] = model.DefaultConfig()
cfg["priority"] = l.Priority
providers = append(providers, &gateway.DynamicRelease{
BasicItem: &gateway.BasicItem{
ID: l.Id,
@@ -694,16 +690,38 @@ func (i *imlProviderModule) syncGateway(ctx context.Context, clusterId string, r
var _ IAIAPIModule = (*imlAIApiModule)(nil)
type imlAIApiModule struct {
aiAPIService ai_api.IAPIService `autowired:""`
aiAPIUseService ai_api.IAPIUseService `autowired:""`
serviceService service.IServiceService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
aiAPIUseService ai_api.IAPIUseService `autowired:""`
serviceService service.IServiceService `autowired:""`
aiLocalModelService ai_local.ILocalModelService `autowired:""`
}
func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId string, start int64, end int64, page int, pageSize int, sortCondition string, asc bool, models []string, serviceIds []string) ([]*ai_dto.APIItem, *ai_dto.Condition, int64, error) {
p, has := model_runtime.GetProvider(providerId)
if !has {
return nil, nil, 0, fmt.Errorf("ai provider not found")
modelItems := make([]*ai_dto.BasicInfo, 0)
if providerId == "ollama" {
items, err := i.aiLocalModelService.Search(ctx, "", nil, "update_at desc")
if err != nil {
return nil, nil, 0, err
}
modelItems = utils.SliceToSlice(items, func(e *ai_local.LocalModel) *ai_dto.BasicInfo {
return &ai_dto.BasicInfo{
Id: e.Id,
Name: e.Name,
}
})
} else {
p, has := model_runtime.GetProvider(providerId)
if !has {
return nil, nil, 0, fmt.Errorf("ai provider not found")
}
modelItems = utils.SliceToSlice(p.Models(), func(e model_runtime.IModel) *ai_dto.BasicInfo {
return &ai_dto.BasicInfo{
Id: e.ID(),
Name: e.ID(),
}
})
}
sortRule := "desc"
if asc {
sortRule = "asc"
@@ -723,12 +741,6 @@ func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId st
}
modelItems := utils.SliceToSlice(p.Models(), func(e model_runtime.IModel) *ai_dto.BasicInfo {
return &ai_dto.BasicInfo{
Id: e.ID(),
Name: e.ID(),
}
})
condition := &ai_dto.Condition{Services: serviceItems, Models: modelItems}
switch sortCondition {
default:
+6 -5
View File
@@ -10,23 +10,24 @@ import (
)
type IProviderModule interface {
ConfiguredProviders(ctx context.Context) ([]*ai_dto.ConfiguredProviderItem, *ai_dto.BackupProvider, error)
ConfiguredProviders(ctx context.Context, keyword string) ([]*ai_dto.ConfiguredProviderItem, error)
UnConfiguredProviders(ctx context.Context) ([]*ai_dto.ProviderItem, error)
SimpleProviders(ctx context.Context) ([]*ai_dto.SimpleProviderItem, error)
SimpleConfiguredProviders(ctx context.Context) ([]*ai_dto.SimpleProviderItem, *ai_dto.BackupProvider, error)
SimpleConfiguredProviders(ctx context.Context, all bool) ([]*ai_dto.SimpleProviderItem, *ai_dto.BackupProvider, error)
Provider(ctx context.Context, id string) (*ai_dto.Provider, error)
SimpleProvider(ctx context.Context, id string) (*ai_dto.SimpleProvider, error)
LLMs(ctx context.Context, driver string) ([]*ai_dto.LLMItem, *ai_dto.ProviderItem, error)
//UpdateProviderStatus(ctx context.Context, id string, enable bool) error
UpdateProviderConfig(ctx context.Context, id string, input *ai_dto.UpdateConfig) error
//UpdateProviderDefaultLLM(ctx context.Context, id string, input *ai_dto.UpdateLLM) error
Sort(ctx context.Context, input *ai_dto.Sort) error
Delete(ctx context.Context, id string) error
}
type IAIAPIModule interface {
APIs(ctx context.Context, keyword string, providerId string, start int64, end int64, page int, pageSize int, sortCondition string, asc bool, models []string, services []string) ([]*ai_dto.APIItem, *ai_dto.Condition, int64, error)
}
type ILocalModelModule interface {
}
func init() {
autowire.Auto[IProviderModule](func() reflect.Value {
module := new(imlProviderModule)
+3 -1
View File
@@ -2,9 +2,10 @@ package catalogue
import (
"context"
"github.com/APIParkLab/APIPark/module/system"
"reflect"
"github.com/APIParkLab/APIPark/module/system"
"github.com/eolinker/go-common/autowire"
catalogue_dto "github.com/APIParkLab/APIPark/module/catalogue/dto"
@@ -28,6 +29,7 @@ type ICatalogueModule interface {
// Subscribe 订阅服务
Subscribe(ctx context.Context, subscribeInfo *catalogue_dto.SubscribeService) error
Sort(ctx context.Context, sorts []*catalogue_dto.SortItem) error
DefaultCatalogue(ctx context.Context) (*catalogue_dto.Catalogue, error)
//ExportAll(ctx context.Context) ([]*catalogue_dto.ExportCatalogue, error)
}
+18
View File
@@ -66,6 +66,24 @@ type imlCatalogueModule struct {
root *Root
}
func (i *imlCatalogueModule) DefaultCatalogue(ctx context.Context) (*catalogue_dto.Catalogue, error) {
catalogues, err := i.catalogueService.List(ctx)
if err != nil {
return nil, err
}
for _, v := range catalogues {
if v.Parent == "" {
return &catalogue_dto.Catalogue{
Id: v.Id,
Name: v.Name,
Parent: v.Parent,
Sort: v.Sort,
}, nil
}
}
return nil, errors.New("no default catalogue")
}
func (i *imlCatalogueModule) onlineSubscriber(ctx context.Context, clusterId string, sub *gateway.SubscribeRelease) error {
client, err := i.clusterService.GatewayClient(ctx, clusterId)
if err != nil {
+13 -2
View File
@@ -1,5 +1,12 @@
package service_dto
type QuickCreateAIService struct {
Provider string `json:"provider"`
Model string `json:"model"`
Config string `json:"config"`
Team string `json:"team"`
}
type CreateService struct {
Id string `json:"id"`
Name string `json:"name"`
@@ -11,7 +18,9 @@ type CreateService struct {
Catalogue string `json:"catalogue"`
ApprovalType string `json:"approval_type"`
Kind string `json:"service_kind"`
Provider *string `json:"provider" aocheck:"ai_provider"`
State string `json:"state"`
Provider *string `json:"provider"`
Model *string `json:"model"`
AsApp *bool `json:"as_app"`
AsServer *bool `json:"as_server"`
ModelMapping string `json:"model_mapping"`
@@ -24,8 +33,10 @@ type EditService struct {
Catalogue *string `json:"catalogue"`
Logo *string `json:"logo"`
Tags *[]string `json:"tags"`
Provider *string `json:"provider" aocheck:"ai_provider"`
Provider *string `json:"provider"`
Model *string `json:"model"`
ApprovalType *string `json:"approval_type"`
State *string `json:"state"`
ModelMapping string `json:"model_mapping"`
}
+58
View File
@@ -5,6 +5,44 @@ import (
"github.com/eolinker/go-common/auto"
)
type ServiceState string
const (
ServiceStateNormal ServiceState = "normal"
ServiceStateDeploying ServiceState = "deploying"
ServiceStateDeployError ServiceState = "error"
)
func (s ServiceState) String() string {
return string(s)
}
func (s ServiceState) Int() int {
switch s {
case ServiceStateNormal:
return 0
case ServiceStateDeploying:
return 1
case ServiceStateDeployError:
return 2
default:
return 0
}
}
func FromServiceState(s int) ServiceState {
switch s {
case 0:
return ServiceStateNormal
case 1:
return ServiceStateDeploying
case 2:
return ServiceStateDeployError
default:
return ""
}
}
type ServiceItem struct {
Id string `json:"id"`
Name string `json:"name"`
@@ -15,6 +53,7 @@ type ServiceItem struct {
CreateTime auto.TimeLabel `json:"create_time"`
UpdateTime auto.TimeLabel `json:"update_time"`
Provider *auto.Label `json:"provider,omitempty" aolabel:"ai_provider"`
State string `json:"state"`
CanDelete bool `json:"can_delete"`
}
@@ -58,10 +97,13 @@ type Service struct {
Tags []auto.Label `json:"tags" aolabel:"tag"`
Logo string `json:"logo"`
Provider *auto.Label `json:"provider,omitempty" aolabel:"ai_provider"`
ProviderType string `json:"provider_type,omitempty"`
Model string `json:"model,omitempty"`
ApprovalType string `json:"approval_type"`
AsServer bool `json:"as_server"`
AsApp bool `json:"as_app"`
ServiceKind string `json:"service_kind"`
State string `json:"state"`
ModelMapping string `json:"model_mapping"`
}
@@ -80,6 +122,7 @@ func ToService(model *service.Service) *Service {
if model.Prefix != "" {
prefix = model.Prefix
}
s := &Service{
Id: model.Id,
Name: model.Name,
@@ -96,10 +139,25 @@ func ToService(model *service.Service) *Service {
AsApp: model.AsApp,
ServiceKind: model.Kind.String(),
}
state := FromServiceState(model.State)
if state == ServiceStateNormal {
s.State = model.ServiceType.String()
} else {
s.State = state.String()
}
switch model.Kind {
case service.AIService:
provider := auto.UUID(model.AdditionalConfig["provider"])
s.Provider = &provider
s.ProviderType = "local"
if provider.Id != "ollama" {
s.ProviderType = "online"
}
modelId := model.AdditionalConfig["model"]
if modelId != "" {
s.Model = modelId
}
}
return s
}
+60 -7
View File
@@ -8,6 +8,10 @@ import (
"strings"
"time"
ai_local "github.com/APIParkLab/APIPark/service/ai-local"
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
"github.com/eolinker/eosc/log"
"github.com/APIParkLab/APIPark/resources/access"
@@ -59,6 +63,7 @@ type imlServiceModule struct {
teamService team.ITeamService `autowired:""`
teamMemberService team_member.ITeamMemberService `autowired:""`
tagService tag.ITagService `autowired:""`
localModelService ai_local.ILocalModelService `autowired:""`
serviceTagService service_tag.ITagService `autowired:""`
apiService api.IAPIService `autowired:""`
@@ -126,7 +131,7 @@ func (i *imlServiceModule) searchMyServices(ctx context.Context, teamId string,
return nil, err
}
condition["team"] = teamId
return i.serviceService.Search(ctx, keyword, condition, "update_at desc")
return i.serviceService.Search(ctx, keyword, condition, "create_at desc")
} else {
membersForUser, err := i.teamMemberService.FilterMembersForUser(ctx, userID)
if err != nil {
@@ -134,7 +139,7 @@ func (i *imlServiceModule) searchMyServices(ctx context.Context, teamId string,
}
teamIds := membersForUser[userID]
condition["team"] = teamIds
return i.serviceService.Search(ctx, keyword, condition, "update_at desc")
return i.serviceService.Search(ctx, keyword, condition, "create_at desc")
}
}
@@ -220,6 +225,25 @@ func (i *imlServiceModule) Get(ctx context.Context, id string) (*service_dto.Ser
s.Tags = auto.List(utils.SliceToSlice(tags, func(p *service_tag.Tag) string {
return p.Tid
}))
if s.Model == "" {
switch s.ProviderType {
case "online":
p, has := model_runtime.GetProvider(s.Provider.Id)
if has {
m, has := p.DefaultModel(model_runtime.ModelTypeLLM)
if has {
s.Model = m.ID()
}
}
case "local":
info, err := i.localModelService.DefaultModel(ctx)
if err != nil {
return nil, err
}
s.Model = info.Id
}
}
serviceModelMapping, err := i.serviceModelMappingService.GetByService(ctx, id)
if err != nil {
@@ -239,9 +263,9 @@ func (i *imlServiceModule) Search(ctx context.Context, teamID string, keyword st
if err != nil {
return nil, err
}
list, err = i.serviceService.Search(ctx, keyword, map[string]interface{}{"team": teamID, "as_server": true}, "update_at desc")
list, err = i.serviceService.Search(ctx, keyword, map[string]interface{}{"team": teamID, "as_server": true}, "create_at desc")
} else {
list, err = i.serviceService.Search(ctx, keyword, map[string]interface{}{"as_server": true}, "update_at desc")
list, err = i.serviceService.Search(ctx, keyword, map[string]interface{}{"as_server": true}, "create_at desc")
}
if err != nil {
return nil, err
@@ -277,8 +301,16 @@ func toServiceItem(model *service.Service) *service_dto.ServiceItem {
Team: auto.UUID(model.Team),
ServiceKind: model.Kind.String(),
}
state := service_dto.FromServiceState(model.State)
if state == service_dto.ServiceStateNormal {
item.State = model.ServiceType.String()
} else {
item.State = state.String()
}
switch model.Kind {
case service.RestService:
item.State = model.ServiceType.String()
return item
case service.AIService:
provider := auto.UUID(model.AdditionalConfig["provider"])
@@ -293,6 +325,13 @@ func (i *imlServiceModule) Create(ctx context.Context, teamID string, input *ser
if input.Id == "" {
input.Id = uuid.New().String()
}
if teamID == "" {
item, err := i.teamService.DefaultTeam(ctx)
if err != nil {
return nil, err
}
teamID = item.Id
}
mo := &service.Create{
Id: input.Id,
Name: input.Name,
@@ -302,6 +341,7 @@ func (i *imlServiceModule) Create(ctx context.Context, teamID string, input *ser
Catalogue: input.Catalogue,
Prefix: input.Prefix,
Logo: input.Logo,
State: service_dto.ServiceState(input.State).Int(),
ApprovalType: service.ApprovalType(input.ApprovalType),
AdditionalConfig: make(map[string]string),
Kind: service.Kind(input.Kind),
@@ -315,6 +355,11 @@ func (i *imlServiceModule) Create(ctx context.Context, teamID string, input *ser
return nil, fmt.Errorf("ai service: provider can not be empty")
}
mo.AdditionalConfig["provider"] = *input.Provider
if input.Model == nil {
return nil, fmt.Errorf("ai service: model can not be empty")
}
mo.AdditionalConfig["model"] = *input.Model
}
if input.AsApp == nil {
// 默认值为false
@@ -374,6 +419,9 @@ func (i *imlServiceModule) Edit(ctx context.Context, id string, input *service_d
if input.Provider != nil {
info.AdditionalConfig["provider"] = *input.Provider
}
if input.Model != nil {
info.AdditionalConfig["model"] = *input.Model
}
}
err = i.transaction.Transaction(ctx, func(ctx context.Context) error {
serviceType := (*service.ServiceType)(input.ServiceType)
@@ -386,8 +434,7 @@ func (i *imlServiceModule) Edit(ctx context.Context, id string, input *service_d
if input.ApprovalType != nil {
approvalType = service.ApprovalType(*input.ApprovalType)
}
err = i.serviceService.Save(ctx, id, &service.Edit{
editCfg := &service.Edit{
Name: input.Name,
Description: input.Description,
Logo: input.Logo,
@@ -395,7 +442,13 @@ func (i *imlServiceModule) Edit(ctx context.Context, id string, input *service_d
Catalogue: input.Catalogue,
AdditionalConfig: &info.AdditionalConfig,
ApprovalType: &approvalType,
})
}
if input.State != nil {
state := service_dto.ServiceState(*input.State).Int()
editCfg.State = &state
}
err = i.serviceService.Save(ctx, id, editCfg)
if err != nil {
return err
}
+26 -7
View File
@@ -6,14 +6,24 @@ import (
)
type InputSetting struct {
InvokeAddress string `json:"invoke_address" key:"system.node.invoke_address"`
SitePrefix string `json:"site_prefix" key:"system.setting.site_prefix"`
InvokeAddress *string `json:"invoke_address" key:"system.node.invoke_address"`
SitePrefix *string `json:"site_prefix" key:"system.setting.site_prefix"`
OllamaAddress *string `json:"ollama_address" key:"system.ai_model.ollama_address"`
}
func (i *InputSetting) Validate() error {
_, err := url.Parse(i.InvokeAddress)
if err != nil {
return err
if i.InvokeAddress != nil {
_, err := url.Parse(*i.InvokeAddress)
if err != nil {
return err
}
}
if i.OllamaAddress != nil {
_, err := url.Parse(*i.OllamaAddress)
if err != nil {
return err
}
}
return nil
}
@@ -31,9 +41,18 @@ func ToKeyMap(i interface{}) map[string]string {
{
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
if f.Tag.Get("key") != "" {
result[f.Tag.Get("key")] = val.Field(i).String()
v := val.Field(i)
if f.Type.Kind() == reflect.Ptr {
if v.IsNil() {
continue
}
v = v.Elem()
}
if f.Tag.Get("key") != "" {
result[f.Tag.Get("key")] = v.String()
}
}
}
}
+4 -2
View File
@@ -6,9 +6,11 @@ import (
)
func TestMap(t *testing.T) {
invokeAddress := "http://127.0.0.1:8080"
ollamaAddress := "http://127.0.0.1:8081"
input := &InputSetting{
InvokeAddress: "http://127.0.0.1:8080",
InvokeAddress: &invokeAddress,
OllamaAddress: &ollamaAddress,
}
err := input.Validate()
if err != nil {
+1
View File
@@ -8,6 +8,7 @@ import (
type Setting struct {
InvokeAddress string `json:"invoke_address" key:"system.node.invoke_address"`
SitePrefix string `json:"site_prefix" key:"system.setting.site_prefix"`
OllamaAddress string `json:"ollama_address" key:"system.ai_model.ollama_address"`
}
func MapStringToStruct[T any](m map[string]string) *T {
+20
View File
@@ -3,6 +3,11 @@ package system
import (
"context"
"github.com/eolinker/go-common/server"
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
"github.com/eolinker/go-common/register"
"github.com/eolinker/go-common/store"
"github.com/eolinker/go-common/utils"
@@ -43,6 +48,21 @@ func (i *imlSettingModule) Set(ctx context.Context, input *system_dto.InputSetti
return err
}
}
if input.OllamaAddress != nil {
ai_provider_local.ResetOllamaAddress(*input.OllamaAddress)
}
return nil
})
}
func (i *imlSettingModule) OnInit() {
register.Handle(func(v server.Server) {
ctx := context.Background()
address, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
if has {
ai_provider_local.ResetOllamaAddress(address)
}
})
}