From 9871e252bc3bbca54748dced2c2727ab89b52a15 Mon Sep 17 00:00:00 2001 From: Liujian <824010343@qq.com> Date: Mon, 6 Jan 2025 09:47:23 +0800 Subject: [PATCH] ai balance finish --- controller/ai-api/iml.go | 15 +++++++------ controller/system/iml.go | 6 +++--- gateway/apinto/entity/router.go | 7 +++++-- module/ai-key/iml.go | 37 +++++++++++++++++++++++++++------ module/ai/iml.go | 35 ++++++++++++++++++++++++++++++- module/publish/iml.go | 6 +++++- 6 files changed, 85 insertions(+), 21 deletions(-) diff --git a/controller/ai-api/iml.go b/controller/ai-api/iml.go index fc59cc5f..d66a77a3 100644 --- a/controller/ai-api/iml.go +++ b/controller/ai-api/iml.go @@ -2,7 +2,6 @@ package ai_api import ( "context" - "fmt" "net/http" "github.com/APIParkLab/APIPark/model/plugin_model" @@ -52,7 +51,7 @@ func (i *imlAPIController) Create(ctx *gin.Context, serviceId string, input *ai_ plugins["ai_formatter"] = api.PluginSetting{ Config: plugin_model.ConfigType{ "model": input.AiModel.Id, - "provider": fmt.Sprintf("%s@ai-provider", input.AiModel.Provider), + "provider": input.AiModel.Provider, "config": input.AiModel.Config, }, } @@ -73,8 +72,8 @@ func (i *imlAPIController) Create(ctx *gin.Context, serviceId string, input *ai_ Retry: input.Retry, Plugins: plugins, }, - Upstream: input.AiModel.Provider, - Disable: false, + //Upstream: input.AiModel.Provider, + Disable: false, }) return err @@ -101,16 +100,16 @@ func (i *imlAPIController) Edit(ctx *gin.Context, serviceId string, apiId string Retry: apiInfo.Proxy.Retry, Plugins: apiInfo.Proxy.Plugins, } - var upstream *string + //var upstream *string if input.AiModel != nil { proxy.Plugins["ai_formatter"] = api.PluginSetting{ Config: plugin_model.ConfigType{ "model": input.AiModel.Id, - "provider": fmt.Sprintf("%s@ai-provider", input.AiModel.Provider), + "provider": input.AiModel.Provider, "config": input.AiModel.Config, }, } - upstream = &input.AiModel.Provider + //upstream = &input.AiModel.Provider } if input.AiPrompt != nil { @@ -128,7 +127,7 @@ func (i *imlAPIController) Edit(ctx *gin.Context, serviceId string, apiId string Path: input.Path, Disable: input.Disable, Methods: &apiInfo.Methods, - Upstream: upstream, + //Upstream: upstream, }) if err != nil { return err diff --git a/controller/system/iml.go b/controller/system/iml.go index 1b9c0434..f54e05b4 100644 --- a/controller/system/iml.go +++ b/controller/system/iml.go @@ -437,7 +437,7 @@ func (i *imlInitController) createAIService(ctx context.Context, teamID string, plugins["ai_formatter"] = api.PluginSetting{ Config: plugin_model.ConfigType{ "model": aiModel.Id, - "provider": fmt.Sprintf("%s@ai-provider", info.Provider.Id), + "provider": info.Provider.Id, "config": aiModel.Config, }, } @@ -457,8 +457,8 @@ func (i *imlInitController) createAIService(ctx context.Context, teamID string, Retry: retry, Plugins: plugins, }, - Disable: false, - Upstream: info.Provider.Id, + Disable: false, + //Upstream: info.Provider.Id, }) if err != nil { return err diff --git a/gateway/apinto/entity/router.go b/gateway/apinto/entity/router.go index 45406b28..664f0eb3 100644 --- a/gateway/apinto/entity/router.go +++ b/gateway/apinto/entity/router.go @@ -161,7 +161,7 @@ func ToRouter(r *gateway.ApiRelease, version string, matches map[string]string) labels = r.Labels } - return &Router{ + router := &Router{ BasicInfo: &BasicInfo{ ID: fmt.Sprintf("%s@router", r.ID), Name: r.ID, @@ -174,13 +174,16 @@ func ToRouter(r *gateway.ApiRelease, version string, matches map[string]string) Method: r.Methods, Location: r.Path, Rules: rules, - Service: fmt.Sprintf("%s@service", r.Service), Plugins: plugin, Retry: r.Retry, TimeOut: r.Timeout, Labels: labels, Protocols: []string{"http", "https"}, } + if r.Service != "" { + router.Service = fmt.Sprintf("%s@service", r.Service) + } + return router } // formatProxyPath 格式化转发路径上,用于转发重写插件正则替换 比如 请求路径 /path/{A}/{B} 原转发路径:/path/{B} 格式化后 新转发路径: /path/$2 diff --git a/module/ai-key/iml.go b/module/ai-key/iml.go index a9b4d6a9..80a3991b 100644 --- a/module/ai-key/iml.go +++ b/module/ai-key/iml.go @@ -228,8 +228,18 @@ func (i *imlKeyModule) Delete(ctx context.Context, providerId string, id string) } } - // TODO: 操作网关下线Key - return i.aiKeyService.Delete(ctx, id) + err = i.aiKeyService.Delete(ctx, id) + if err != nil { + return err + } + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{{ + BasicItem: &gateway.BasicItem{ + ID: id, + Resource: "ai-key", + }, + Attr: nil, + }, + }, false) }) } @@ -372,9 +382,18 @@ func (i *imlKeyModule) UpdateKeyStatus(ctx context.Context, providerId string, i } // TODO:发布Key到网关 status := ai_key_dto.KeyNormal.Int() - return i.aiKeyService.Save(ctx, id, &ai_key.Edit{ + err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{ Status: &status, }) + if err != nil { + return err + } + info, err = i.aiKeyService.Get(ctx, id) + if err != nil { + return err + } + releases := []*gateway.DynamicRelease{newKey(info)} + return i.syncGateway(ctx, providerId, releases, true) } return nil }) @@ -397,8 +416,14 @@ func (i *imlKeyModule) Sort(ctx context.Context, providerId string, input *ai_ke if err != nil { return err } - // TODO: 全量更新key配置到网关 - - return nil + list, err := i.aiKeyService.List(ctx) + if err != nil { + return err + } + releases := make([]*gateway.DynamicRelease, 0, len(list)) + for _, info := range list { + releases = append(releases, newKey(info)) + } + return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) }) } diff --git a/module/ai/iml.go b/module/ai/iml.go index 5bb7f3cc..2dae83c3 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -84,8 +84,18 @@ func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error providerMap := utils.SliceToMap(list, func(e *ai.Provider) string { return e.Id }) + releases := make([]*gateway.DynamicRelease, 0, len(list)) for index, id := range input.Providers { - _, has := providerMap[id] + p, has := model_runtime.GetProvider(id) + if !has { + continue + } + + l, has := providerMap[id] + if !has { + continue + } + model, has := p.GetModel(l.DefaultLLM) if !has { continue } @@ -96,6 +106,28 @@ func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error if err != nil { return err } + cfg := make(map[string]interface{}) + cfg["provider"] = l.Id + cfg["model"] = l.DefaultLLM + cfg["model_config"] = model.DefaultConfig() + cfg["priority"] = l.Priority + cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) + releases = append(releases, &gateway.DynamicRelease{ + BasicItem: &gateway.BasicItem{ + ID: l.Id, + Description: l.Name, + Resource: "ai-provider", + Version: l.UpdateAt.Format("20060102150405"), + MatchLabels: map[string]string{ + "module": "ai-provider", + }, + }, + Attr: cfg, + }) + err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true) + if err != nil { + return err + } } return nil }) @@ -531,6 +563,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, cfg["provider"] = info.Id cfg["model"] = info.DefaultLLM cfg["model_config"] = model.DefaultConfig() + cfg["priority"] = info.Priority cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ { diff --git a/module/publish/iml.go b/module/publish/iml.go index 0eb17754..4226ef1f 100644 --- a/module/publish/iml.go +++ b/module/publish/iml.go @@ -124,6 +124,7 @@ func (m *imlPublishModule) getProjectRelease(ctx context.Context, projectID stri Version: version, } apis := make([]*gateway.ApiRelease, 0, len(apiInfos)) + hasUpstream := len(upstreamCommitIds) > 0 for _, a := range apiInfos { apiInfo := &gateway.ApiRelease{ BasicItem: &gateway.BasicItem{ @@ -133,7 +134,10 @@ func (m *imlPublishModule) getProjectRelease(ctx context.Context, projectID stri }, Path: a.Path, Methods: a.Methods, - Service: a.Upstream, + //Service: a.Upstream, + } + if hasUpstream { + apiInfo.Service = a.Upstream } proxy, ok := proxyCommitMap[a.UUID] if ok {