diff --git a/controller/ai-api/iml.go b/controller/ai-api/iml.go index d66a77a3..05e9baca 100644 --- a/controller/ai-api/iml.go +++ b/controller/ai-api/iml.go @@ -48,10 +48,14 @@ func (i *imlAPIController) Create(ctx *gin.Context, serviceId string, input *ai_ } } if input.AiModel != nil { + provider := "ollama" + if input.AiModel.Type != "local" { + provider = input.AiModel.Provider + } plugins["ai_formatter"] = api.PluginSetting{ Config: plugin_model.ConfigType{ "model": input.AiModel.Id, - "provider": input.AiModel.Provider, + "provider": provider, "config": input.AiModel.Config, }, } @@ -102,10 +106,14 @@ func (i *imlAPIController) Edit(ctx *gin.Context, serviceId string, apiId string } //var upstream *string if input.AiModel != nil { + provider := "ollama" + if input.AiModel.Type != "local" { + provider = input.AiModel.Provider + } proxy.Plugins["ai_formatter"] = api.PluginSetting{ Config: plugin_model.ConfigType{ "model": input.AiModel.Id, - "provider": input.AiModel.Provider, + "provider": provider, "config": input.AiModel.Config, }, } diff --git a/controller/service/iml.go b/controller/service/iml.go index 803d0318..eb1e297e 100644 --- a/controller/service/iml.go +++ b/controller/service/iml.go @@ -8,6 +8,10 @@ import ( "strings" "time" + subscribe_dto "github.com/APIParkLab/APIPark/module/subscribe/dto" + + "github.com/APIParkLab/APIPark/module/subscribe" + ai_local "github.com/APIParkLab/APIPark/module/ai-local" ai_dto "github.com/APIParkLab/APIPark/module/ai/dto" @@ -64,11 +68,13 @@ var ( type imlServiceController struct { module service.IServiceModule `autowired:""` docModule service.IServiceDocModule `autowired:""` + subscribeModule subscribe.ISubscribeModule `autowired:""` aiAPIModule ai_api.IAPIModule `autowired:""` routerModule router.IRouterModule `autowired:""` apiDocModule api_doc.IAPIDocModule `autowired:""` providerModule ai.IProviderModule `autowired:""` aiLocalModel ai_local.ILocalModelModule `autowired:""` + appModule service.IAppModule `autowired:""` upstreamModule upstream.IUpstreamModule `autowired:""` settingModule system.ISettingModule `autowired:""` teamModule team.ITeamModule `autowired:""` @@ -86,27 +92,26 @@ func (i *imlServiceController) QuickCreateAIService(ctx *gin.Context, input *ser if err != nil { return err } - pv, err := i.providerModule.Provider(ctx, input.Provider) - if err != nil { - return err - } - p, has := model_runtime.GetProvider(input.Provider) - if !has { - return fmt.Errorf("provider not found") - } - m, has := p.GetModel(pv.DefaultLLM) - if !has { - return fmt.Errorf("model %s not found", pv.DefaultLLM) - } + //pv, err := i.providerModule.Provider(ctx, input.Provider) + //if err != nil { + // return err + //} + //p, has := model_runtime.GetProvider(input.Provider) + //if !has { + // return fmt.Errorf("provider not found") + //} + //m, has := p.GetModel(pv.DefaultLLM) + //if !has { + // return fmt.Errorf("model %s not found", pv.DefaultLLM) + //} - var info *service_dto.Service id := uuid.NewString() prefix := fmt.Sprintf("/%s", id[:8]) catalogueInfo, err := i.catalogueModule.DefaultCatalogue(ctx) if err != nil { return err } - info, err = i.module.Create(ctx, input.Team, &service_dto.CreateService{ + _, err = i.createAIService(ctx, input.Team, &service_dto.CreateService{ Id: uuid.NewString(), Name: input.Provider + " AI Service", Prefix: prefix, @@ -118,82 +123,92 @@ func (i *imlServiceController) QuickCreateAIService(ctx *gin.Context, input *ser Provider: &input.Provider, Kind: "ai", }) - if err != nil { - return err - } + return err + //info, err = i.module.Create(ctx, input.Team, &service_dto.CreateService{ + // Id: uuid.NewString(), + // Name: input.Provider + " AI Service", + // Prefix: prefix, + // Description: "Quick create by AI provider", + // ServiceType: "public", + // State: "normal", + // Catalogue: catalogueInfo.Id, + // ApprovalType: "auto", + // Provider: &input.Provider, + // Kind: "ai", + //}) + //if err != nil { + // return err + //} + // + //path := fmt.Sprintf("%s/chat", prefix) + //timeout := 300000 + //retry := 0 + //aiPrompt := &ai_api_dto.AiPrompt{ + // Variables: []*ai_api_dto.AiPromptVariable{}, + // Prompt: "", + //} + //aiModel := &ai_api_dto.AiModel{ + // Id: m.ID(), + // Config: m.DefaultConfig(), + // Provider: input.Provider, + //} + //name := "Demo AI API" + //description := "A demo that shows you how to use a e a Chat" + //apiId := uuid.New().String() + //err = i.aiAPIModule.Create( + // ctx, + // info.Id, + // &ai_api_dto.CreateAPI{ + // Id: apiId, + // Name: name, + // Path: path, + // Description: description, + // Disable: false, + // AiPrompt: aiPrompt, + // AiModel: aiModel, + // Timeout: timeout, + // Retry: retry, + // }, + //) + //if err != nil { + // return err + //} + //plugins := make(map[string]api.PluginSetting) + //plugins["ai_prompt"] = api.PluginSetting{ + // Config: plugin_model.ConfigType{ + // "prompt": aiPrompt.Prompt, + // "variables": aiPrompt.Variables, + // }, + //} + //plugins["ai_formatter"] = api.PluginSetting{ + // Config: plugin_model.ConfigType{ + // "model": aiModel.Id, + // "provider": info.Provider.Id, + // "config": aiModel.Config, + // }, + //} + //_, err = i.routerModule.Create(ctx, info.Id, &router_dto.Create{ + // Id: apiId, + // Name: name, + // Path: path, + // Methods: []string{ + // http.MethodPost, + // }, + // Description: description, + // Protocols: []string{"http", "https"}, + // MatchRules: nil, + // Proxy: &router_dto.InputProxy{ + // Path: path, + // Timeout: timeout, + // Retry: retry, + // Plugins: plugins, + // }, + // Disable: false, + //}) + //if err != nil { + // return err + //} - path := fmt.Sprintf("%s/chat", prefix) - timeout := 300000 - retry := 0 - aiPrompt := &ai_api_dto.AiPrompt{ - Variables: []*ai_api_dto.AiPromptVariable{}, - Prompt: "", - } - aiModel := &ai_api_dto.AiModel{ - Id: m.ID(), - Config: m.DefaultConfig(), - Provider: input.Provider, - } - name := "Demo AI API" - description := "A demo that shows you how to use a e a Chat" - apiId := uuid.New().String() - err = i.aiAPIModule.Create( - ctx, - info.Id, - &ai_api_dto.CreateAPI{ - Id: apiId, - Name: name, - Path: path, - Description: description, - Disable: false, - AiPrompt: aiPrompt, - AiModel: aiModel, - Timeout: timeout, - Retry: retry, - }, - ) - if err != nil { - return err - } - plugins := make(map[string]api.PluginSetting) - plugins["ai_prompt"] = api.PluginSetting{ - Config: plugin_model.ConfigType{ - "prompt": aiPrompt.Prompt, - "variables": aiPrompt.Variables, - }, - } - plugins["ai_formatter"] = api.PluginSetting{ - Config: plugin_model.ConfigType{ - "model": aiModel.Id, - "provider": info.Provider.Id, - "config": aiModel.Config, - }, - } - _, err = i.routerModule.Create(ctx, info.Id, &router_dto.Create{ - Id: apiId, - Name: name, - Path: path, - Methods: []string{ - http.MethodPost, - }, - Description: description, - Protocols: []string{"http", "https"}, - MatchRules: nil, - Proxy: &router_dto.InputProxy{ - Path: path, - Timeout: timeout, - Retry: retry, - Plugins: plugins, - }, - Disable: false, - }) - if err != nil { - return err - } - - return i.docModule.SaveServiceDoc(ctx, info.Id, &service_dto.SaveServiceDoc{ - Doc: "", - }) }) } @@ -264,6 +279,16 @@ func (i *imlServiceController) QuickCreateRestfulService(ctx *gin.Context) error if err != nil { return err } + apps, err := i.appModule.Search(ctx, teamId, "") + if err != nil { + return err + } + for _, app := range apps { + i.subscribeModule.AddSubscriber(ctx, id, &subscribe_dto.AddSubscriber{ + Application: app.Id, + }) + } + return nil }) } @@ -512,6 +537,15 @@ func (i *imlServiceController) createAIService(ctx *gin.Context, teamID string, if err != nil { return err } + apps, err := i.appModule.Search(ctx, info.Team.Id, "") + if err != nil { + return err + } + for _, app := range apps { + i.subscribeModule.AddSubscriber(ctx, info.Id, &subscribe_dto.AddSubscriber{ + Application: app.Id, + }) + } return i.docModule.SaveServiceDoc(ctx, info.Id, &service_dto.SaveServiceDoc{ Doc: "", diff --git a/module/ai-balance/iml.go b/module/ai-balance/iml.go index 18414162..d2ddd4f5 100644 --- a/module/ai-balance/iml.go +++ b/module/ai-balance/iml.go @@ -39,6 +39,13 @@ type imlBalanceModule struct { } func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error { + has, err := i.balanceService.Exist(ctx, input.Provider, input.Model) + if err != nil { + return err + } + if has { + return fmt.Errorf("model already exists") + } priority, err := i.balanceService.MaxPriority(ctx) if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { @@ -64,19 +71,52 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre providerName = "Ollama" modelName = input.Model } - return i.balanceService.Create(ctx, &ai_balance.Create{ - Id: input.Id, - Priority: priority + 1, - Provider: input.Provider, - ProviderName: providerName, - Model: input.Model, - ModelName: modelName, - Type: ai_balance_dto.ModelType(input.Type).Int(), + return i.transaction.Transaction(ctx, func(ctx context.Context) error { + err = i.balanceService.Create(ctx, &ai_balance.Create{ + Id: input.Id, + Priority: priority + 1, + Provider: input.Provider, + ProviderName: providerName, + Model: input.Model, + ModelName: modelName, + Type: ai_balance_dto.ModelType(input.Type).Int(), + }) + if err != nil { + return err + } + item, err := i.balanceService.Get(ctx, input.Id) + if err != nil { + return err + } + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item)}, true) }) + } +var ( + ollamaConfig = "{\n \"mirostat\": 0,\n \"mirostat_eta\": 0.1,\n \"mirostat_tau\": 5.0,\n \"num_ctx\": 4096,\n \"repeat_last_n\":64,\n \"repeat_penalty\": 1.1,\n \"temperature\": 0.7,\n \"seed\": 42,\n \"num_predict\": 42,\n \"top_k\": 40,\n \"top_p\": 0.9,\n \"min_p\": 0.5\n}\n" + ollamaBase = "http://apipark-ollama:11434" +) + func newRelease(item *ai_balance.Balance) *gateway.DynamicRelease { - return &gateway.DynamicRelease{} + + cfg := make(map[string]interface{}) + cfg["provider"] = item.Id + cfg["model"] = item.Model + cfg["model_config"] = ollamaConfig + cfg["base"] = ollamaBase + return &gateway.DynamicRelease{ + BasicItem: &gateway.BasicItem{ + ID: item.Id, + Description: item.ModelName, + Resource: "ai-provider", + Version: item.UpdateAt.Format("20060102150405"), + MatchLabels: map[string]string{ + "module": "ai-provider", + }, + }, + Attr: cfg, + } } func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort) error { @@ -91,11 +131,13 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort) if err != nil { return err } + releases := make([]*gateway.DynamicRelease, 0, len(list)) for _, item := range list { - err = i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item)}, true) - if err != nil { - return err - } + releases = append(releases, newRelease(item)) + } + err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) + if err != nil { + return err } return nil } @@ -140,11 +182,23 @@ func (i *imlBalanceModule) List(ctx context.Context, keyword string) ([]*ai_bala } func (i *imlBalanceModule) Delete(ctx context.Context, id string) error { - return i.balanceService.Delete(ctx, id) + return i.transaction.Transaction(ctx, func(ctx context.Context) error { + err := i.balanceService.Delete(ctx, id) + if err != nil { + return err + } + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + { + BasicItem: &gateway.BasicItem{ + ID: id, + }, + }, + }, false) + }) + } func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error { - return nil client, err := i.clusterService.GatewayClient(ctx, clusterId) if err != nil { log.Errorf("get apinto client error: %v", err) diff --git a/module/ai-local/iml.go b/module/ai-local/iml.go index 8380a2f3..c02a356f 100644 --- a/module/ai-local/iml.go +++ b/module/ai-local/iml.go @@ -4,8 +4,16 @@ import ( "context" "errors" "fmt" + "net/url" "strings" + "github.com/eolinker/eosc/env" + + "github.com/APIParkLab/APIPark/gateway" + "github.com/eolinker/eosc/log" + + "github.com/APIParkLab/APIPark/service/cluster" + "github.com/APIParkLab/APIPark/service/service" "github.com/eolinker/go-common/auto" @@ -35,6 +43,7 @@ type imlLocalModel struct { localModelService ai_local.ILocalModelService `autowired:""` localModelPackageService ai_local.ILocalModelPackageService `autowired:""` localModelStateService ai_local.ILocalModelInstallStateService `autowired:""` + clusterService cluster.IClusterService `autowired:""` aiAPIService ai_api.IAPIService `autowired:""` serviceService service.IServiceService `autowired:""` transaction store.ITransaction `autowired:""` @@ -42,8 +51,21 @@ type imlLocalModel struct { var ( ollamaConfig = "{\n \"mirostat\": 0,\n \"mirostat_eta\": 0.1,\n \"mirostat_tau\": 5.0,\n \"num_ctx\": 4096,\n \"repeat_last_n\":64,\n \"repeat_penalty\": 1.1,\n \"temperature\": 0.7,\n \"seed\": 42,\n \"num_predict\": 42,\n \"top_k\": 40,\n \"top_p\": 0.9,\n \"min_p\": 0.5\n}\n" + ollamaBase = "http://apipark-ollama:11434" ) +func init() { + base, has := env.GetEnv("OLLAMA_BASE") + if !has { + return + } + _, err := url.Parse(base) + if err == nil { + ollamaBase = base + } + +} + func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) { list, err := i.localModelService.List(ctx) if err != nil { @@ -135,6 +157,7 @@ func (i *imlLocalModel) ListCanInstall(ctx context.Context, keyword string) ([]* func (i *imlLocalModel) pullHook() func(msg ai_provider_local.PullMessage) error { return func(msg ai_provider_local.PullMessage) error { return i.transaction.Transaction(context.Background(), func(ctx context.Context) error { + state := ai_local_dto.DeployStateFinish.Int() modelState := ai_local_dto.LocalModelStateNormal.Int() if msg.Status == "error" { @@ -151,7 +174,7 @@ func (i *imlLocalModel) pullHook() func(msg ai_provider_local.PullMessage) error return err } - return i.localModelStateService.Create(ctx, &ai_local.CreateLocalModelInstallState{ + err = i.localModelStateService.Create(ctx, &ai_local.CreateLocalModelInstallState{ Id: msg.Model, Complete: msg.Completed, Total: msg.Total, @@ -159,30 +182,83 @@ func (i *imlLocalModel) pullHook() func(msg ai_provider_local.PullMessage) error Msg: msg.Msg, }) - } - if info.Complete < msg.Completed { - info.Complete = msg.Completed + } else { + if info.Complete < msg.Completed { + info.Complete = msg.Completed + } + if info.Total < msg.Total { + info.Total = msg.Total + } + if msg.Msg != "" { + info.Msg = msg.Msg + } + err = i.localModelStateService.Save(ctx, msg.Model, &ai_local.EditLocalModelInstallState{State: &state, Complete: &info.Complete, Total: &info.Total, Msg: &info.Msg}) + if err != nil { + return err + } + serviceState := 0 + if msg.Status == "error" { + state = 2 + } + err = i.serviceService.Save(ctx, msg.Model, &service.Edit{State: &serviceState}) } - if info.Total < msg.Total { - info.Total = msg.Total - } - if msg.Msg != "" { - info.Msg = msg.Msg - } - err = i.localModelStateService.Save(ctx, msg.Model, &ai_local.EditLocalModelInstallState{State: &state, Complete: &info.Complete, Total: &info.Total, Msg: &info.Msg}) if err != nil { return err } - serviceState := 0 - if msg.Status == "error" { - state = 2 - } - return i.serviceService.Save(ctx, msg.Model, &service.Edit{State: &serviceState}) + cfg := make(map[string]interface{}) + cfg["provider"] = "ollama" + cfg["model"] = msg.Model + cfg["model_config"] = ollamaConfig + cfg["priority"] = 0 + cfg["base"] = ollamaBase + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + { + BasicItem: &gateway.BasicItem{ + ID: msg.Model, + Description: msg.Model, + Resource: "ai-provider", + Version: info.UpdateAt.Format("20060102150405"), + MatchLabels: map[string]string{ + "module": "ai-provider", + }, + }, + Attr: cfg, + }}, true) }) } } +func (i *imlLocalModel) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error { + client, err := i.clusterService.GatewayClient(ctx, clusterId) + if err != nil { + log.Errorf("get apinto client error: %v", err) + return nil + } + defer func() { + err := client.Close(ctx) + if err != nil { + log.Warn("close apinto client:", err) + } + }() + for _, releaseInfo := range releases { + dynamicClient, err := client.Dynamic(releaseInfo.Resource) + if err != nil { + return err + } + if online { + err = dynamicClient.Online(ctx, releaseInfo) + } else { + err = dynamicClient.Offline(ctx, releaseInfo) + } + if err != nil { + return err + } + } + + return nil +} + func (i *imlLocalModel) Deploy(ctx context.Context, model string, session string) (*ai_provider_local.Pipeline, error) { var p *ai_provider_local.Pipeline err := i.transaction.Transaction(ctx, func(txCtx context.Context) error { @@ -341,3 +417,51 @@ func (i *imlLocalModel) OnInit() { } }) } + +func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicRelease, error) { + list, err := i.localModelService.List(ctx) + if err != nil { + return nil, err + } + releases := make([]*gateway.DynamicRelease, 0, len(list)) + for _, l := range list { + cfg := make(map[string]interface{}) + cfg["provider"] = "ollama" + cfg["model"] = l.Id + cfg["model_config"] = ollamaConfig + cfg["base"] = ollamaBase + releases = append(releases, &gateway.DynamicRelease{ + BasicItem: &gateway.BasicItem{ + ID: l.Id, + Description: l.Name, + Resource: "ai-provider", + Version: l.UpdateAt.Format("20060102150405"), + MatchLabels: map[string]string{ + "module": "ai-provider", + }, + }, + Attr: cfg, + }) + } + return releases, nil +} + +func (i *imlLocalModel) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error { + releases, err := i.getLocalModels(ctx) + if err != nil { + return err + } + + for _, p := range releases { + client, err := clientDriver.Dynamic(p.Resource) + if err != nil { + return err + } + err = client.Online(ctx, p) + if err != nil { + return err + } + } + + return nil +} diff --git a/module/ai-local/module.go b/module/ai-local/module.go index bf40ff4a..a96172a0 100644 --- a/module/ai-local/module.go +++ b/module/ai-local/module.go @@ -4,6 +4,8 @@ import ( "context" "reflect" + "github.com/APIParkLab/APIPark/gateway" + "github.com/eolinker/go-common/autowire" ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local" @@ -24,7 +26,9 @@ type ILocalModelModule interface { } func init() { + localModel := new(imlLocalModel) autowire.Auto[ILocalModelModule](func() reflect.Value { - return reflect.ValueOf(&imlLocalModel{}) + gateway.RegisterInitHandleFunc(localModel.initGateway) + return reflect.ValueOf(localModel) }) } diff --git a/module/ai/iml.go b/module/ai/iml.go index 24821923..ed31d2ee 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -129,74 +129,6 @@ func (i *imlProviderModule) SimpleProvider(ctx context.Context, id string) (*ai_ }, nil } -//func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error { -// return i.transaction.Transaction(ctx, func(txCtx context.Context) error { -// list, err := i.providerService.List(ctx) -// if err != nil { -// return err -// } -// providerMap := utils.SliceToMap(list, func(e *ai.Provider) string { -// return e.Id -// }) -// releases := make([]*gateway.DynamicRelease, 0, len(list)) -// offlineReleases := make([]*gateway.DynamicRelease, 0, len(list)) -// for index, id := range input.Providers { -// p, has := model_runtime.GetProvider(id) -// if !has { -// continue -// } -// -// l, has := providerMap[id] -// if !has { -// continue -// } -// model, has := p.GetModel(l.DefaultLLM) -// if !has { -// continue -// } -// priority := index + 1 -// err = i.providerService.Save(txCtx, id, &ai.SetProvider{ -// Priority: &priority, -// }) -// if err != nil { -// return err -// } -// if ai_dto.ToProviderStatus(l.Status) == ai_dto.ProviderDisabled { -// offlineReleases = append(offlineReleases, &gateway.DynamicRelease{ -// BasicItem: &gateway.BasicItem{ -// ID: l.Id, -// Resource: "ai-provider", -// }}) -// } else { -// cfg := make(map[string]interface{}) -// cfg["provider"] = l.Id -// cfg["model"] = l.DefaultLLM -// cfg["model_config"] = model.DefaultConfig() -// cfg["priority"] = l.Priority -// cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) -// releases = append(releases, &gateway.DynamicRelease{ -// BasicItem: &gateway.BasicItem{ -// ID: l.Id, -// Description: l.Name, -// Resource: "ai-provider", -// Version: l.UpdateAt.Format("20060102150405"), -// MatchLabels: map[string]string{ -// "module": "ai-provider", -// }, -// }, -// Attr: cfg, -// }) -// } -// } -// err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) -// if err != nil { -// return err -// } -// return i.syncGateway(ctx, cluster.DefaultClusterID, offlineReleases, false) -// -// }) -//} - func (i *imlProviderModule) ConfiguredProviders(ctx context.Context, keyword string) ([]*ai_dto.ConfiguredProviderItem, error) { // 获取已配置的AI服务商 list, err := i.providerService.Search(ctx, keyword, nil, "update_at") @@ -611,7 +543,6 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, cfg["provider"] = info.Id cfg["model"] = info.DefaultLLM cfg["model_config"] = model.DefaultConfig() - cfg["priority"] = info.Priority cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ { @@ -652,7 +583,6 @@ func (i *imlProviderModule) getAiProviders(ctx context.Context) ([]*gateway.Dyna cfg["provider"] = l.Id cfg["model"] = l.DefaultLLM cfg["model_config"] = model.DefaultConfig() - cfg["priority"] = l.Priority providers = append(providers, &gateway.DynamicRelease{ BasicItem: &gateway.BasicItem{ ID: l.Id, diff --git a/module/service/dto/input.go b/module/service/dto/input.go index 4a4229ed..6cfab3da 100644 --- a/module/service/dto/input.go +++ b/module/service/dto/input.go @@ -31,7 +31,7 @@ type EditService struct { Catalogue *string `json:"catalogue"` Logo *string `json:"logo"` Tags *[]string `json:"tags"` - Provider *string `json:"provider" aocheck:"ai_provider"` + Provider *string `json:"provider"` ApprovalType *string `json:"approval_type"` State *string `json:"state"` } diff --git a/service/ai-balance/iml.go b/service/ai-balance/iml.go index 34f2dd3c..6adb1839 100644 --- a/service/ai-balance/iml.go +++ b/service/ai-balance/iml.go @@ -2,10 +2,13 @@ package ai_balance import ( "context" + "errors" "fmt" "sort" "time" + "gorm.io/gorm" + "github.com/eolinker/go-common/store" "github.com/APIParkLab/APIPark/service/universally" @@ -23,6 +26,17 @@ type imlBalanceService struct { universally.IServiceDelete } +func (i *imlBalanceService) Exist(ctx context.Context, provider string, model string) (bool, error) { + _, err := i.store.First(ctx, map[string]interface{}{"provider": provider, "model": model}) + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return false, err + } + return false, nil + } + return true, nil +} + func (i *imlBalanceService) OnComplete() { i.IServiceGet = universally.NewGet[Balance, ai.Balance](i.store, FromEntity) i.IServiceCreate = universally.NewCreator[Create, ai.Balance](i.store, "ai_balance", createEntityHandler, uniquestHandler, labelHandler) diff --git a/service/ai-balance/service.go b/service/ai-balance/service.go index 13a52400..b4ea9926 100644 --- a/service/ai-balance/service.go +++ b/service/ai-balance/service.go @@ -17,6 +17,7 @@ type IBalanceService interface { MaxPriority(ctx context.Context) (int, error) SortBefore(ctx context.Context, originID string, targetID string) ([]*Balance, error) SortAfter(ctx context.Context, originID string, targetID string) ([]*Balance, error) + Exist(ctx context.Context, provider string, model string) (bool, error) } func init() {