diff --git a/controller/ai-api/iml.go b/controller/ai-api/iml.go index fc59cc5f..d66a77a3 100644 --- a/controller/ai-api/iml.go +++ b/controller/ai-api/iml.go @@ -2,7 +2,6 @@ package ai_api import ( "context" - "fmt" "net/http" "github.com/APIParkLab/APIPark/model/plugin_model" @@ -52,7 +51,7 @@ func (i *imlAPIController) Create(ctx *gin.Context, serviceId string, input *ai_ plugins["ai_formatter"] = api.PluginSetting{ Config: plugin_model.ConfigType{ "model": input.AiModel.Id, - "provider": fmt.Sprintf("%s@ai-provider", input.AiModel.Provider), + "provider": input.AiModel.Provider, "config": input.AiModel.Config, }, } @@ -73,8 +72,8 @@ func (i *imlAPIController) Create(ctx *gin.Context, serviceId string, input *ai_ Retry: input.Retry, Plugins: plugins, }, - Upstream: input.AiModel.Provider, - Disable: false, + //Upstream: input.AiModel.Provider, + Disable: false, }) return err @@ -101,16 +100,16 @@ func (i *imlAPIController) Edit(ctx *gin.Context, serviceId string, apiId string Retry: apiInfo.Proxy.Retry, Plugins: apiInfo.Proxy.Plugins, } - var upstream *string + //var upstream *string if input.AiModel != nil { proxy.Plugins["ai_formatter"] = api.PluginSetting{ Config: plugin_model.ConfigType{ "model": input.AiModel.Id, - "provider": fmt.Sprintf("%s@ai-provider", input.AiModel.Provider), + "provider": input.AiModel.Provider, "config": input.AiModel.Config, }, } - upstream = &input.AiModel.Provider + //upstream = &input.AiModel.Provider } if input.AiPrompt != nil { @@ -128,7 +127,7 @@ func (i *imlAPIController) Edit(ctx *gin.Context, serviceId string, apiId string Path: input.Path, Disable: input.Disable, Methods: &apiInfo.Methods, - Upstream: upstream, + //Upstream: upstream, }) if err != nil { return err diff --git a/controller/ai/iml.go b/controller/ai/iml.go index 5c526785..78b51087 100644 --- a/controller/ai/iml.go +++ b/controller/ai/iml.go @@ -50,11 +50,13 @@ func (i *imlProviderController) LLMs(ctx *gin.Context, driver string) ([]*ai_dto } func (i *imlProviderController) Enable(ctx *gin.Context, id string) error { - return i.module.UpdateProviderStatus(ctx, id, true) + //return i.module.UpdateProviderStatus(ctx, id, true) + return nil } func (i *imlProviderController) Disable(ctx *gin.Context, id string) error { - return i.module.UpdateProviderStatus(ctx, id, false) + //return i.module.UpdateProviderStatus(ctx, id, false) + return nil } func (i *imlProviderController) UpdateProviderConfig(ctx *gin.Context, id string, input *ai_dto.UpdateConfig) error { @@ -62,7 +64,8 @@ func (i *imlProviderController) UpdateProviderConfig(ctx *gin.Context, id string } func (i *imlProviderController) UpdateProviderDefaultLLM(ctx *gin.Context, id string, input *ai_dto.UpdateLLM) error { - return i.module.UpdateProviderDefaultLLM(ctx, id, input) + //return i.module.UpdateProviderDefaultLLM(ctx, id, input) + return nil } var _ IStatisticController = (*imlStatisticController)(nil) diff --git a/controller/system/iml.go b/controller/system/iml.go index 1b9c0434..f54e05b4 100644 --- a/controller/system/iml.go +++ b/controller/system/iml.go @@ -437,7 +437,7 @@ func (i *imlInitController) createAIService(ctx context.Context, teamID string, plugins["ai_formatter"] = api.PluginSetting{ Config: plugin_model.ConfigType{ "model": aiModel.Id, - "provider": fmt.Sprintf("%s@ai-provider", info.Provider.Id), + "provider": info.Provider.Id, "config": aiModel.Config, }, } @@ -457,8 +457,8 @@ func (i *imlInitController) createAIService(ctx context.Context, teamID string, Retry: retry, Plugins: plugins, }, - Disable: false, - Upstream: info.Provider.Id, + Disable: false, + //Upstream: info.Provider.Id, }) if err != nil { return err diff --git a/gateway/apinto/entity/router.go b/gateway/apinto/entity/router.go index 45406b28..664f0eb3 100644 --- a/gateway/apinto/entity/router.go +++ b/gateway/apinto/entity/router.go @@ -161,7 +161,7 @@ func ToRouter(r *gateway.ApiRelease, version string, matches map[string]string) labels = r.Labels } - return &Router{ + router := &Router{ BasicInfo: &BasicInfo{ ID: fmt.Sprintf("%s@router", r.ID), Name: r.ID, @@ -174,13 +174,16 @@ func ToRouter(r *gateway.ApiRelease, version string, matches map[string]string) Method: r.Methods, Location: r.Path, Rules: rules, - Service: fmt.Sprintf("%s@service", r.Service), Plugins: plugin, Retry: r.Retry, TimeOut: r.Timeout, Labels: labels, Protocols: []string{"http", "https"}, } + if r.Service != "" { + router.Service = fmt.Sprintf("%s@service", r.Service) + } + return router } // formatProxyPath 格式化转发路径上,用于转发重写插件正则替换 比如 请求路径 /path/{A}/{B} 原转发路径:/path/{B} 格式化后 新转发路径: /path/$2 diff --git a/gateway/profession.go b/gateway/profession.go index 606a4904..4c812778 100644 --- a/gateway/profession.go +++ b/gateway/profession.go @@ -8,6 +8,7 @@ const ( ProfessionStrategy = "strategy" ProfessionService = "service" ProfessionAIProvider = "ai-provider" + ProfessionAIResource = "ai-resource" ) func RegisterDynamicResourceDriver(key string, worker Worker) { @@ -61,6 +62,14 @@ var dynamicResourceMap = map[string]Worker{ Profession: ProfessionOutput, Driver: "loki", }, + "ai-provider": { + Profession: ProfessionAIProvider, + Driver: "ai-provider", + }, + "ai-key": { + Profession: ProfessionAIResource, + Driver: "ai-key", + }, } type Worker struct { diff --git a/module/ai-key/iml.go b/module/ai-key/iml.go index caa1c964..80a3991b 100644 --- a/module/ai-key/iml.go +++ b/module/ai-key/iml.go @@ -6,6 +6,11 @@ import ( "fmt" "time" + "github.com/APIParkLab/APIPark/service/cluster" + "github.com/eolinker/eosc/log" + + "github.com/APIParkLab/APIPark/gateway" + "github.com/eolinker/go-common/utils" "gorm.io/gorm" @@ -27,9 +32,32 @@ import ( var _ IKeyModule = &imlKeyModule{} type imlKeyModule struct { - providerService ai.IProviderService `autowired:""` - aiKeyService ai_key.IKeyService `autowired:""` - transaction store.ITransaction `autowired:""` + providerService ai.IProviderService `autowired:""` + aiKeyService ai_key.IKeyService `autowired:""` + clusterService cluster.IClusterService `autowired:""` + transaction store.ITransaction `autowired:""` +} + +func newKey(key *ai_key.Key) *gateway.DynamicRelease { + + return &gateway.DynamicRelease{ + BasicItem: &gateway.BasicItem{ + ID: key.ID, + Description: key.Name, + Resource: "ai-key", + Version: time.Now().Format("20060102150405"), + MatchLabels: map[string]string{ + "module": "ai-key", + }, + }, + Attr: map[string]interface{}{ + "expired": key.ExpireTime, + "config": key.Config, + "provider": key.Provider, + "priority": key.Priority, + "disabled": key.Status == 1, + }, + } } func (i *imlKeyModule) Create(ctx context.Context, providerId string, input *ai_key_dto.Create) error { @@ -57,11 +85,9 @@ func (i *imlKeyModule) Create(ctx context.Context, providerId string, input *ai_ status := ai_key_dto.KeyNormal.Int() if input.ExpireTime > 0 && time.Unix(int64(input.ExpireTime), 0).Before(time.Now()) { status = ai_key_dto.KeyExpired.Int() - } else { - // TODO: 发布Key到网关 } - return i.aiKeyService.Create(ctx, &ai_key.Create{ + err = i.aiKeyService.Create(ctx, &ai_key.Create{ ID: input.Id, Name: input.Name, Config: input.Config, @@ -70,9 +96,43 @@ func (i *imlKeyModule) Create(ctx context.Context, providerId string, input *ai_ ExpireTime: input.ExpireTime, Priority: priority + 1, }) + + info, _ := i.aiKeyService.Get(ctx, input.Id) + releases := []*gateway.DynamicRelease{newKey(info)} + return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) }) } +func (i *imlKeyModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error { + client, err := i.clusterService.GatewayClient(ctx, clusterId) + if err != nil { + log.Errorf("get apinto client error: %v", err) + return nil + } + defer func() { + err := client.Close(ctx) + if err != nil { + log.Warn("close apinto client:", err) + } + }() + for _, releaseInfo := range releases { + dynamicClient, err := client.Dynamic(releaseInfo.Resource) + if err != nil { + return err + } + if online { + err = dynamicClient.Online(ctx, releaseInfo) + } else { + err = dynamicClient.Offline(ctx, releaseInfo) + } + if err != nil { + return err + } + } + + return nil +} + func (i *imlKeyModule) Edit(ctx context.Context, providerId string, id string, input *ai_key_dto.Edit) error { p, has := model_runtime.GetProvider(providerId) if !has { @@ -168,8 +228,18 @@ func (i *imlKeyModule) Delete(ctx context.Context, providerId string, id string) } } - // TODO: 操作网关下线Key - return i.aiKeyService.Delete(ctx, id) + err = i.aiKeyService.Delete(ctx, id) + if err != nil { + return err + } + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{{ + BasicItem: &gateway.BasicItem{ + ID: id, + Resource: "ai-key", + }, + Attr: nil, + }, + }, false) }) } @@ -312,9 +382,18 @@ func (i *imlKeyModule) UpdateKeyStatus(ctx context.Context, providerId string, i } // TODO:发布Key到网关 status := ai_key_dto.KeyNormal.Int() - return i.aiKeyService.Save(ctx, id, &ai_key.Edit{ + err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{ Status: &status, }) + if err != nil { + return err + } + info, err = i.aiKeyService.Get(ctx, id) + if err != nil { + return err + } + releases := []*gateway.DynamicRelease{newKey(info)} + return i.syncGateway(ctx, providerId, releases, true) } return nil }) @@ -337,8 +416,14 @@ func (i *imlKeyModule) Sort(ctx context.Context, providerId string, input *ai_ke if err != nil { return err } - // TODO: 全量更新key配置到网关 - - return nil + list, err := i.aiKeyService.List(ctx) + if err != nil { + return err + } + releases := make([]*gateway.DynamicRelease, 0, len(list)) + for _, info := range list { + releases = append(releases, newKey(info)) + } + return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) }) } diff --git a/module/ai/iml.go b/module/ai/iml.go index aefb7f0f..2dae83c3 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -2,7 +2,6 @@ package ai import ( "context" - "encoding/json" "errors" "fmt" "net/http" @@ -30,24 +29,24 @@ import ( "gorm.io/gorm" ) -func newAIUpstream(provider string, uri model_runtime.IProviderURI) *gateway.DynamicRelease { +func newKey(key *ai_key.Key) *gateway.DynamicRelease { + return &gateway.DynamicRelease{ BasicItem: &gateway.BasicItem{ - ID: provider, - Description: fmt.Sprintf("auto create by ai provider %s", provider), - Resource: "service", + ID: key.ID, + Description: key.Name, + Resource: "ai-key", Version: time.Now().Format("20060102150405"), MatchLabels: map[string]string{ - "module": "service", + "module": "ai-key", }, }, Attr: map[string]interface{}{ - "driver": "http", - "balance": "round-robin", - "nodes": []string{fmt.Sprintf("%s weight=100", uri.Host())}, - "pass_host": "node", - "scheme": uri.Scheme(), - "timeout": 300000, + "expired": key.ExpireTime, + "config": key.Config, + "provider": key.Provider, + "priority": key.Priority, + "disabled": key.Status == 1, }, } } @@ -85,8 +84,18 @@ func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error providerMap := utils.SliceToMap(list, func(e *ai.Provider) string { return e.Id }) + releases := make([]*gateway.DynamicRelease, 0, len(list)) for index, id := range input.Providers { - _, has := providerMap[id] + p, has := model_runtime.GetProvider(id) + if !has { + continue + } + + l, has := providerMap[id] + if !has { + continue + } + model, has := p.GetModel(l.DefaultLLM) if !has { continue } @@ -97,6 +106,28 @@ func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error if err != nil { return err } + cfg := make(map[string]interface{}) + cfg["provider"] = l.Id + cfg["model"] = l.DefaultLLM + cfg["model_config"] = model.DefaultConfig() + cfg["priority"] = l.Priority + cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) + releases = append(releases, &gateway.DynamicRelease{ + BasicItem: &gateway.BasicItem{ + ID: l.Id, + Description: l.Name, + Resource: "ai-provider", + Version: l.UpdateAt.Format("20060102150405"), + MatchLabels: map[string]string{ + "module": "ai-provider", + }, + }, + Attr: cfg, + }) + err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) + if err != nil { + return err + } } return nil }) @@ -442,72 +473,6 @@ func (i *imlProviderModule) LLMs(ctx context.Context, driver string) ([]*ai_dto. }, nil } -func (i *imlProviderModule) UpdateProviderStatus(ctx context.Context, id string, enable bool) error { - driver, has := model_runtime.GetProvider(id) - if !has { - return fmt.Errorf("ai provider not found") - } - info, err := i.providerService.Get(ctx, id) - if err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - return fmt.Errorf("ai provider not found") - } - - return i.transaction.Transaction(ctx, func(txCtx context.Context) error { - status := 0 - if enable { - status = 1 - } - err = i.providerService.Save(txCtx, id, &ai.SetProvider{ - Status: &status, - }) - if err != nil { - return err - } - if enable { - cfg := make(map[string]interface{}) - err = json.Unmarshal([]byte(info.Config), &cfg) - if err != nil { - log.Errorf("unmarshal ai provider config error,id is %s,err is %v", info.Id, err) - return err - } - cfg["driver"] = info.Id - - return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{{ - BasicItem: &gateway.BasicItem{ - ID: info.Id, - Description: info.Name, - Version: info.UpdateAt.Format("20060102150405"), - MatchLabels: map[string]string{ - "module": "ai-provider", - }, - }, - Attr: cfg, - }, newAIUpstream(info.Id, driver.URI()), - }, enable) - } else { - return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ - { - BasicItem: &gateway.BasicItem{ - ID: info.Id, - Resource: info.Id, - }, - }, - { - BasicItem: &gateway.BasicItem{ - ID: info.Id, - Resource: "service", - }, - }, - }, enable) - } - - }) - -} - func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, input *ai_dto.UpdateConfig) error { p, has := model_runtime.GetProvider(id) if !has { @@ -532,6 +497,10 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, Config: input.Config, } } + model, has := p.GetModel(input.DefaultLLM) + if !has { + return fmt.Errorf("ai provider model not found") + } err = p.Check(input.Config) if err != nil { return err @@ -573,6 +542,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, if err != nil { return err } + if input.Enable != nil { status = 0 if *input.Enable { @@ -584,40 +554,31 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, if err != nil { return err } - cfg := make(map[string]interface{}) - err = json.Unmarshal([]byte(input.Config), &cfg) + // 获取当前供应商所有Key信息 + defaultKey, err := i.aiKeyService.DefaultKey(ctx, id) if err != nil { - log.Errorf("unmarshal ai provider config error,id is %s,err is %v", id, err) return err } - return nil - //return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ - // { - // BasicItem: &gateway.BasicItem{ - // ID: id, - // Description: info.Name, - // Resource: id, - // Version: info.UpdateAt.Format("20060102150405"), - // MatchLabels: map[string]string{ - // "module": "ai-provider", - // }, - // }, - // Attr: cfg, - // }, newAIUpstream(id, p.URI()), - //}, true) - }) -} - -func (i *imlProviderModule) UpdateProviderDefaultLLM(ctx context.Context, id string, input *ai_dto.UpdateLLM) error { - _, err := i.providerService.Get(ctx, id) - if err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - return fmt.Errorf("ai provider not found") - } - return i.providerService.Save(ctx, id, &ai.SetProvider{ - DefaultLLM: &input.LLM, + cfg := make(map[string]interface{}) + cfg["provider"] = info.Id + cfg["model"] = info.DefaultLLM + cfg["model_config"] = model.DefaultConfig() + cfg["priority"] = info.Priority + cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + { + BasicItem: &gateway.BasicItem{ + ID: id, + Description: info.Name, + Resource: "ai-provider", + Version: info.UpdateAt.Format("20060102150405"), + MatchLabels: map[string]string{ + "module": "ai-provider", + }, + }, + Attr: cfg, + }, newKey(defaultKey), + }, true) }) } @@ -626,20 +587,30 @@ func (i *imlProviderModule) getAiProviders(ctx context.Context) ([]*gateway.Dyna if err != nil { return nil, err } + providers := make([]*gateway.DynamicRelease, 0, len(list)) - for _, p := range list { - cfg := make(map[string]interface{}) - err = json.Unmarshal([]byte(p.Config), &cfg) - if err != nil { - log.Errorf("unmarshal ai provider config error,id is %s,err is %v", p.Id, err) - continue + for _, l := range list { + // 获取当前供应商所有Key信息 + + driver, has := model_runtime.GetProvider(l.Id) + if !has { + return nil, fmt.Errorf("provider not found: %s", l.Id) } + model, has := driver.GetModel(l.DefaultLLM) + if !has { + return nil, fmt.Errorf("model not found: %s", l.DefaultLLM) + } + cfg := make(map[string]interface{}) + cfg["provider"] = l.Id + cfg["model"] = l.DefaultLLM + cfg["model_config"] = model.DefaultConfig() + cfg["priority"] = l.Priority providers = append(providers, &gateway.DynamicRelease{ BasicItem: &gateway.BasicItem{ - ID: p.Id, - Description: p.Name, - Resource: p.Id, - Version: p.UpdateAt.Format("20060102150405"), + ID: l.Id, + Description: l.Name, + Resource: "ai-provider", + Version: l.UpdateAt.Format("20060102150405"), MatchLabels: map[string]string{ "module": "ai-provider", }, @@ -655,16 +626,9 @@ func (i *imlProviderModule) initGateway(ctx context.Context, clusterId string, c if err != nil { return err } - serviceClient, err := clientDriver.Dynamic("service") - if err != nil { - return err - } + for _, p := range providers { - driver, has := model_runtime.GetProvider(p.ID) - if !has { - continue - } - client, err := clientDriver.Dynamic(p.ID) + client, err := clientDriver.Dynamic(p.Resource) if err != nil { return err } @@ -672,12 +636,6 @@ func (i *imlProviderModule) initGateway(ctx context.Context, clusterId string, c if err != nil { return err } - - err = serviceClient.Online(ctx, newAIUpstream(p.ID, driver.URI())) - if err != nil { - return err - } - } return nil diff --git a/module/ai/module.go b/module/ai/module.go index f9682700..99b6a56e 100644 --- a/module/ai/module.go +++ b/module/ai/module.go @@ -17,9 +17,9 @@ type IProviderModule interface { Provider(ctx context.Context, id string) (*ai_dto.Provider, error) SimpleProvider(ctx context.Context, id string) (*ai_dto.SimpleProvider, error) LLMs(ctx context.Context, driver string) ([]*ai_dto.LLMItem, *ai_dto.ProviderItem, error) - UpdateProviderStatus(ctx context.Context, id string, enable bool) error + //UpdateProviderStatus(ctx context.Context, id string, enable bool) error UpdateProviderConfig(ctx context.Context, id string, input *ai_dto.UpdateConfig) error - UpdateProviderDefaultLLM(ctx context.Context, id string, input *ai_dto.UpdateLLM) error + //UpdateProviderDefaultLLM(ctx context.Context, id string, input *ai_dto.UpdateLLM) error Sort(ctx context.Context, input *ai_dto.Sort) error } diff --git a/module/publish/iml.go b/module/publish/iml.go index 0eb17754..4226ef1f 100644 --- a/module/publish/iml.go +++ b/module/publish/iml.go @@ -124,6 +124,7 @@ func (m *imlPublishModule) getProjectRelease(ctx context.Context, projectID stri Version: version, } apis := make([]*gateway.ApiRelease, 0, len(apiInfos)) + hasUpstream := len(upstreamCommitIds) > 0 for _, a := range apiInfos { apiInfo := &gateway.ApiRelease{ BasicItem: &gateway.BasicItem{ @@ -133,7 +134,10 @@ func (m *imlPublishModule) getProjectRelease(ctx context.Context, projectID stri }, Path: a.Path, Methods: a.Methods, - Service: a.Upstream, + //Service: a.Upstream, + } + if hasUpstream { + apiInfo.Service = a.Upstream } proxy, ok := proxyCommitMap[a.UUID] if ok {