diff --git a/controller/ai-local/iml.go b/controller/ai-local/iml.go index aa722518..e436609a 100644 --- a/controller/ai-local/iml.go +++ b/controller/ai-local/iml.go @@ -7,11 +7,15 @@ import ( "io" "math" "net/http" + "net/url" "strings" + "time" - system_dto "github.com/APIParkLab/APIPark/module/system/dto" + ai_balance "github.com/APIParkLab/APIPark/module/ai-balance" "github.com/APIParkLab/APIPark/module/system" + system_dto "github.com/APIParkLab/APIPark/module/system/dto" + ollama_api "github.com/ollama/ollama/api" "github.com/APIParkLab/APIPark/module/subscribe" subscribe_dto "github.com/APIParkLab/APIPark/module/subscribe/dto" @@ -51,6 +55,7 @@ type imlLocalModelController struct { serviceModule service.IServiceModule `autowired:""` catalogueModule catalogue.ICatalogueModule `autowired:""` aiAPIModule ai_api.IAPIModule `autowired:""` + aiBalanceModule ai_balance.IBalanceModule `autowired:""` appModule service.IAppModule `autowired:""` routerModule router.IRouterModule `autowired:""` subscribeModule subscribe.ISubscribeModule `autowired:""` @@ -66,9 +71,35 @@ func (i *imlLocalModelController) OllamaConfig(ctx *gin.Context) (*ai_local_dto. }, nil } +var ( + client = &http.Client{ + Timeout: 2 * time.Second, + } +) + func (i *imlLocalModelController) OllamaConfigUpdate(ctx *gin.Context, input *ai_local_dto.OllamaConfig) error { - return i.settingModule.Set(ctx, &system_dto.InputSetting{ - OllamaAddress: &input.Address, + u, err := url.Parse(input.Address) + if err != nil { + return nil + } + ollamaClient := ollama_api.NewClient(u, client) + _, err = ollamaClient.Version(ctx) + if err != nil { + return err + } + return i.transaction.Transaction(ctx, func(ctx context.Context) error { + err = i.module.SyncLocalModels(ctx, input.Address) + if err != nil { + return err + } + err = i.aiBalanceModule.SyncLocalBalances(ctx, input.Address) + if err != nil { + return err + } + + return i.settingModule.Set(ctx, &system_dto.InputSetting{ + OllamaAddress: &input.Address, + }) }) } diff --git a/controller/service/iml.go b/controller/service/iml.go index 22862da5..e2f0bf32 100644 --- a/controller/service/iml.go +++ b/controller/service/iml.go @@ -477,13 +477,13 @@ func (i *imlServiceController) Create(ctx *gin.Context, teamID string, input *se } var err error var info *service_dto.Service - err = i.transaction.Transaction(ctx, func(txCtx context.Context) error { - info, err = i.module.Create(txCtx, teamID, input) + err = i.transaction.Transaction(ctx, func(ctx context.Context) error { + info, err = i.module.Create(ctx, teamID, input) if err != nil { return err } path := fmt.Sprintf("/%s/", strings.Trim(input.Prefix, "/")) - _, err = i.routerModule.Create(txCtx, info.Id, &router_dto.Create{ + _, err = i.routerModule.Create(ctx, info.Id, &router_dto.Create{ Id: uuid.New().String(), Name: "", Path: path + "*", @@ -499,6 +499,15 @@ func (i *imlServiceController) Create(ctx *gin.Context, teamID string, input *se }, Disable: false, }) + apps, err := i.appModule.Search(ctx, teamID, "") + if err != nil { + return err + } + for _, app := range apps { + i.subscribeModule.AddSubscriber(ctx, info.Id, &subscribe_dto.AddSubscriber{ + Application: app.Id, + }) + } return err }) return info, err diff --git a/controller/system/iml.go b/controller/system/iml.go index 7c020379..44f9e37e 100644 --- a/controller/system/iml.go +++ b/controller/system/iml.go @@ -10,6 +10,8 @@ import ( "strings" "time" + subscribe_dto "github.com/APIParkLab/APIPark/module/subscribe/dto" + "github.com/eolinker/eosc/log" ai_dto "github.com/APIParkLab/APIPark/module/ai/dto" @@ -222,6 +224,7 @@ type imlInitController struct { applicationAuthorizationModule application_authorization.IAuthorizationModule `autowired:""` catalogueModule catalogue.ICatalogueModule `autowired:""` providerModule ai.IProviderModule `autowired:""` + subscribeModule subscribe.ISubscribeModule `autowired:""` transaction store.ITransaction `autowired:""` aiAPIModule ai_api.IAPIModule `autowired:""` docModule service.IServiceDocModule `autowired:""` @@ -264,6 +267,13 @@ func (i *imlInitController) OnInit() { if err != nil { return fmt.Errorf("create default team error: %v", err) } + app, err := i.appModule.CreateApp(ctx, info.Id, &service_dto.CreateApp{ + Name: "Demo Application", + Description: "Auto created By APIPark", + }) + if err != nil { + return fmt.Errorf("create default app error: %v", err) + } // 创建Rest服务 restPath := "/rest-demo" serviceInfo, err := i.serviceModule.Create(ctx, info.Id, &service_dto.CreateService{ @@ -298,6 +308,13 @@ func (i *imlInitController) OnInit() { if err != nil { return fmt.Errorf("create default router error: %v", err) } + err = i.subscribeModule.AddSubscriber(ctx, serviceInfo.Id, &subscribe_dto.AddSubscriber{ + Application: app.Id, + }) + if err != nil { + return err + } + // 创建AI服务 err = i.createAIService(ctx, info.Id, &service_dto.CreateService{ Name: "AI Demo Service", @@ -307,17 +324,11 @@ func (i *imlInitController) OnInit() { Catalogue: catalogueId, ApprovalType: "auto", Kind: "ai", - }) + }, app.Id) if err != nil { return err } - app, err := i.appModule.CreateApp(ctx, info.Id, &service_dto.CreateApp{ - Name: "Demo Application", - Description: "Auto created By APIPark", - }) - if err != nil { - return fmt.Errorf("create default app error: %v", err) - } + _, err = i.applicationAuthorizationModule.AddAuthorization(ctx, app.Id, &application_authorization_dto.CreateAuthorization{ Name: "Default API Key", Driver: "apikey", @@ -338,7 +349,7 @@ func (i *imlInitController) OnInit() { } }) } -func (i *imlInitController) createAIService(ctx context.Context, teamID string, input *service_dto.CreateService) error { +func (i *imlInitController) createAIService(ctx context.Context, teamID string, input *service_dto.CreateService, appId string) error { providerId := "fakegpt" err := i.providerModule.UpdateProviderConfig(ctx, providerId, &ai_dto.UpdateConfig{ @@ -469,6 +480,12 @@ func (i *imlInitController) createAIService(ctx context.Context, teamID string, if err != nil { return err } + err = i.subscribeModule.AddSubscriber(ctx, info.Id, &subscribe_dto.AddSubscriber{ + Application: appId, + }) + if err != nil { + return err + } return i.docModule.SaveServiceDoc(ctx, info.Id, &service_dto.SaveServiceDoc{ Doc: "The Translation API allows developers to translate text from one language to another. It supports multiple languages and enables easy integration of high-quality translation features into applications. With simple API requests, you can quickly translate content into different target languages.", diff --git a/module/ai-balance/iml.go b/module/ai-balance/iml.go index d5dc0340..5ab66bc8 100644 --- a/module/ai-balance/iml.go +++ b/module/ai-balance/iml.go @@ -43,6 +43,14 @@ type imlBalanceModule struct { transaction store.ITransaction `autowired:""` } +func (i *imlBalanceModule) SyncLocalBalances(ctx context.Context, address string) error { + releases, err := i.getLocalBalances(ctx, address) + if err != nil { + return err + } + return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) +} + func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error { has, err := i.balanceService.Exist(ctx, input.Provider, input.Model) if err != nil { @@ -63,6 +71,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre } providerName := "" modelName := "" + base := "" switch input.Type { case ai_balance_dto.ModelTypeOnline: p, has := model_runtime.GetProvider(input.Provider) @@ -71,14 +80,16 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre } providerName = p.Name() modelName = input.Model + base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) case ai_balance_dto.ModelTypeLocal: input.Provider = "ollama" providerName = "Ollama" modelName = input.Model - } - v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address") - if !has { - return fmt.Errorf("ollama address not found") + v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address") + if !has { + return fmt.Errorf("ollama address not found") + } + base = v } return i.transaction.Transaction(ctx, func(ctx context.Context) error { @@ -98,7 +109,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre if err != nil { return err } - return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, v)}, true) + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, base)}, true) }) } @@ -106,10 +117,11 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre func newRelease(item *ai_balance.Balance, base string) *gateway.DynamicRelease { cfg := make(map[string]interface{}) - cfg["provider"] = item.Id + cfg["provider"] = item.Provider cfg["model"] = item.Model cfg["model_config"] = ai_provider_local.OllamaConfig cfg["base"] = base + cfg["priority"] = item.Priority return &gateway.DynamicRelease{ BasicItem: &gateway.BasicItem{ ID: item.Id, @@ -142,7 +154,16 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort) } releases := make([]*gateway.DynamicRelease, 0, len(list)) for _, item := range list { - releases = append(releases, newRelease(item, v)) + base := v + if item.Provider != "ollama" { + p, has := model_runtime.GetProvider(item.Provider) + if !has { + continue + } + base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) + } + + releases = append(releases, newRelease(item, base)) } err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) if err != nil { @@ -236,3 +257,73 @@ func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, re return nil } + +func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) { + balances, err := i.balanceService.Search(ctx, "", map[string]interface{}{"provider": "ollama"}, "priority asc") + if err != nil { + return nil, err + } + if v == "" { + var has bool + v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address") + if !has { + return nil, fmt.Errorf("ollama address not found") + } + } + + releases := make([]*gateway.DynamicRelease, 0, len(balances)) + for _, item := range balances { + base := v + if item.Provider != "ollama" { + p, has := model_runtime.GetProvider(item.Provider) + if !has { + continue + } + base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) + } + releases = append(releases, newRelease(item, base)) + } + return releases, nil +} + +func (i *imlBalanceModule) getBalances(ctx context.Context) ([]*gateway.DynamicRelease, error) { + balances, err := i.balanceService.Search(ctx, "", nil, "priority asc") + if err != nil { + return nil, err + } + v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address") + if !has { + return nil, fmt.Errorf("ollama address not found") + } + releases := make([]*gateway.DynamicRelease, 0, len(balances)) + for _, item := range balances { + base := v + if item.Provider != "ollama" { + p, has := model_runtime.GetProvider(item.Provider) + if !has { + continue + } + base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) + } + releases = append(releases, newRelease(item, base)) + } + return releases, nil +} + +func (i *imlBalanceModule) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error { + releases, err := i.getBalances(ctx) + if err != nil { + return err + } + for _, p := range releases { + client, err := clientDriver.Dynamic(p.Resource) + if err != nil { + return err + } + err = client.Online(ctx, p) + if err != nil { + return err + } + } + return nil +} diff --git a/module/ai-balance/module.go b/module/ai-balance/module.go index 51f99a32..a72e6b7f 100644 --- a/module/ai-balance/module.go +++ b/module/ai-balance/module.go @@ -4,6 +4,8 @@ import ( "context" "reflect" + "github.com/APIParkLab/APIPark/gateway" + "github.com/eolinker/go-common/autowire" ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto" @@ -14,10 +16,13 @@ type IBalanceModule interface { Sort(ctx context.Context, input *ai_balance_dto.Sort) error List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error) Delete(ctx context.Context, id string) error + SyncLocalBalances(ctx context.Context, address string) error } func init() { + balanceModule := new(imlBalanceModule) autowire.Auto[IBalanceModule](func() reflect.Value { - return reflect.ValueOf(new(imlBalanceModule)) + gateway.RegisterInitHandleFunc(balanceModule.initGateway) + return reflect.ValueOf(balanceModule) }) } diff --git a/module/ai-local/iml.go b/module/ai-local/iml.go index 145b62d9..28577413 100644 --- a/module/ai-local/iml.go +++ b/module/ai-local/iml.go @@ -6,6 +6,8 @@ import ( "fmt" "strings" + ai_balance "github.com/APIParkLab/APIPark/service/ai-balance" + "github.com/APIParkLab/APIPark/service/setting" "github.com/APIParkLab/APIPark/service/api" @@ -45,6 +47,7 @@ type imlLocalModel struct { localModelPackageService ai_local.ILocalModelPackageService `autowired:""` localModelStateService ai_local.ILocalModelInstallStateService `autowired:""` localModelCacheService ai_local.ILocalModelCacheService `autowired:""` + balanceService ai_balance.IBalanceService `autowired:""` clusterService cluster.IClusterService `autowired:""` aiAPIService ai_api.IAPIService `autowired:""` routerService api.IAPIService `autowired:""` @@ -53,6 +56,14 @@ type imlLocalModel struct { transaction store.ITransaction `autowired:""` } +func (i *imlLocalModel) SyncLocalModels(ctx context.Context, address string) error { + releases, err := i.getLocalModels(ctx, address) + if err != nil { + return err + } + return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) +} + func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) { list, err := i.localModelService.List(ctx) if err != nil { @@ -177,6 +188,13 @@ func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local. State: state, Msg: msg.Msg, }) + if err != nil { + return err + } + info, err = i.localModelStateService.Get(ctx, msg.Model) + if err != nil { + return err + } } else { if info.Complete < msg.Completed { @@ -422,8 +440,36 @@ func (i *imlLocalModel) Enable(ctx context.Context, model string) error { return err } if info.State == ai_local_dto.LocalModelStateDisable.Int() || info.State == ai_local_dto.LocalModelStateError.Int() { - status := ai_local_dto.LocalModelStateNormal.Int() - return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status}) + + return i.transaction.Transaction(ctx, func(ctx context.Context) error { + status := ai_local_dto.LocalModelStateNormal.Int() + err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status}) + if err != nil { + return err + } + v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address") + cfg := make(map[string]interface{}) + cfg["provider"] = "ollama" + cfg["model"] = info.Id + cfg["model_config"] = ai_provider_local.OllamaConfig + cfg["priority"] = 0 + cfg["base"] = v + + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + { + BasicItem: &gateway.BasicItem{ + ID: info.Id, + Description: info.Id, + Resource: "ai-provider", + Version: info.UpdateAt.Format("20060102150405"), + MatchLabels: map[string]string{ + "module": "ai-provider", + }, + }, + Attr: cfg, + }}, true) + }) + } return fmt.Errorf("model %s is not disabled state,can not enable", model) } @@ -435,7 +481,21 @@ func (i *imlLocalModel) Disable(ctx context.Context, model string) error { } if info.State == ai_local_dto.LocalModelStateNormal.Int() { disable := ai_local_dto.LocalModelStateDisable.Int() - return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable}) + return i.transaction.Transaction(ctx, func(ctx context.Context) error { + err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable}) + if err != nil { + return err + } + + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + { + BasicItem: &gateway.BasicItem{ + ID: info.Id, + Resource: "ai-provider", + }, + }}, false) + }) + } return fmt.Errorf("model %s is not enabled state,can not disable", model) } @@ -515,17 +575,24 @@ func (i *imlLocalModel) OnInit() { }) } -func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicRelease, error) { +func (i *imlLocalModel) getLocalModels(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) { list, err := i.localModelService.List(ctx) if err != nil { return nil, err } - v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address") - if !has { - return nil, errors.New("ollama_address not set") + if v == "" { + var has bool + v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address") + if !has { + return nil, errors.New("ollama_address not set") + } } + releases := make([]*gateway.DynamicRelease, 0, len(list)) for _, l := range list { + if l.State != ai_local_dto.LocalModelStateNormal.Int() { + continue + } cfg := make(map[string]interface{}) cfg["provider"] = "ollama" cfg["model"] = l.Id @@ -548,7 +615,7 @@ func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicR } func (i *imlLocalModel) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error { - releases, err := i.getLocalModels(ctx) + releases, err := i.getLocalModels(ctx, "") if err != nil { return err } diff --git a/module/ai-local/module.go b/module/ai-local/module.go index 38330029..131c72ca 100644 --- a/module/ai-local/module.go +++ b/module/ai-local/module.go @@ -24,6 +24,8 @@ type ILocalModelModule interface { ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) SaveCache(ctx context.Context, model string, target string) error + + SyncLocalModels(ctx context.Context, address string) error } func init() { diff --git a/module/ai/iml.go b/module/ai/iml.go index b3b76186..d4c06d89 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -482,7 +482,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, return fmt.Errorf("ai provider not found") } - return i.transaction.Transaction(ctx, func(txCtx context.Context) error { + return i.transaction.Transaction(ctx, func(ctx context.Context) error { info, err := i.providerService.Get(ctx, id) if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { @@ -533,12 +533,12 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, Config: &input.Config, Status: &status, } - _, err = i.aiKeyService.DefaultKey(txCtx, id) + _, err = i.aiKeyService.DefaultKey(ctx, id) if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return err } - err = i.aiKeyService.Create(txCtx, &ai_key.Create{ + err = i.aiKeyService.Create(ctx, &ai_key.Create{ ID: id, Name: info.Name, Config: input.Config, @@ -549,7 +549,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, Priority: 1, }) } else { - err = i.aiKeyService.Save(txCtx, id, &ai_key.Edit{ + err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{ Config: &input.Config, Status: &status, }) @@ -557,13 +557,13 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, if err != nil { return err } - err = i.providerService.Save(txCtx, id, pInfo) + err = i.providerService.Save(ctx, id, pInfo) if err != nil { return err } if *pInfo.Status == 0 { - return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ { BasicItem: &gateway.BasicItem{ ID: id, @@ -573,7 +573,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, }, false) } // 获取当前供应商默认Key信息 - defaultKey, err := i.aiKeyService.DefaultKey(txCtx, id) + defaultKey, err := i.aiKeyService.DefaultKey(ctx, id) if err != nil { return err } @@ -582,7 +582,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, cfg["model"] = info.DefaultLLM cfg["model_config"] = model.DefaultConfig() cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) - return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ { BasicItem: &gateway.BasicItem{ ID: id,