From bc8870a73579f39adff724723b6abe234e712821 Mon Sep 17 00:00:00 2001 From: Liujian <824010343@qq.com> Date: Mon, 17 Feb 2025 17:21:48 +0800 Subject: [PATCH] fix: bug --- controller/ai-local/iml.go | 180 +++++++++++++++++++------------------ module/ai-local/iml.go | 12 ++- module/ai-local/module.go | 2 +- 3 files changed, 104 insertions(+), 90 deletions(-) diff --git a/controller/ai-local/iml.go b/controller/ai-local/iml.go index f6a4c825..1966bf92 100644 --- a/controller/ai-local/iml.go +++ b/controller/ai-local/iml.go @@ -6,6 +6,14 @@ import ( "fmt" "io" "math" + "net/http" + "strings" + + "github.com/APIParkLab/APIPark/model/plugin_model" + "github.com/APIParkLab/APIPark/service/api" + + ai_api_dto "github.com/APIParkLab/APIPark/module/ai-api/dto" + router_dto "github.com/APIParkLab/APIPark/module/router/dto" "github.com/APIParkLab/APIPark/module/router" @@ -163,12 +171,12 @@ func (i *imlLocalModelController) Deploy(ctx *gin.Context) { } func (i *imlLocalModelController) DeployStart(ctx *gin.Context, input *ai_local_dto.DeployInput) error { - err := i.initAILocalService(ctx, input.Model, input.Team) + fn, err := i.initAILocalService(ctx, input.Model, input.Team) if err != nil { return err } id := uuid.NewString() - _, err = i.module.Deploy(ctx, input.Model, id) + _, err = i.module.Deploy(ctx, input.Model, id, fn) if err != nil { return err } @@ -176,15 +184,15 @@ func (i *imlLocalModelController) DeployStart(ctx *gin.Context, input *ai_local_ return nil } -func (i *imlLocalModelController) initAILocalService(ctx context.Context, model string, teamID string) error { - err := i.transaction.Transaction(ctx, func(ctx context.Context) error { - catalogueInfo, err := i.catalogueModule.DefaultCatalogue(ctx) - if err != nil { - return err - } - serviceId := uuid.NewString() - prefix := fmt.Sprintf("/%s", serviceId[:8]) - providerId := "ollama" +func (i *imlLocalModelController) initAILocalService(ctx context.Context, model string, teamID string) (func() error, error) { + catalogueInfo, err := i.catalogueModule.DefaultCatalogue(ctx) + if err != nil { + return nil, err + } + serviceId := uuid.NewString() + prefix := fmt.Sprintf("/%s", serviceId[:8]) + providerId := "ollama" + err = i.transaction.Transaction(ctx, func(ctx context.Context) error { _, err = i.serviceModule.Create(ctx, teamID, &service_dto.CreateService{ Id: serviceId, Name: model, @@ -201,83 +209,83 @@ func (i *imlLocalModelController) initAILocalService(ctx context.Context, model return err } return i.module.SaveCache(ctx, model, serviceId) - - //path := fmt.Sprintf("/%s/chat", strings.Trim(prefix, "/")) - //timeout := 300000 - //retry := 0 - //aiPrompt := &ai_api_dto.AiPrompt{ - // Variables: []*ai_api_dto.AiPromptVariable{}, - // Prompt: "", - //} - //aiModel := &ai_api_dto.AiModel{ - // Id: model, - // Config: ai_provider_local.OllamaConfig, - // Provider: providerId, - // Type: "local", - //} - //name := "Demo AI API" - //description := "A demo that shows you how to use a e a Chat API." - //apiId := uuid.NewString() - //err = i.aiAPIModule.Create( - // ctx, - // info.Id, - // &ai_api_dto.CreateAPI{ - // Id: apiId, - // Name: name, - // Path: path, - // Description: description, - // Disable: false, - // AiPrompt: aiPrompt, - // AiModel: aiModel, - // Timeout: timeout, - // Retry: retry, - // }, - //) - //if err != nil { - // return err - //} - //plugins := make(map[string]api.PluginSetting) - //plugins["ai_prompt"] = api.PluginSetting{ - // Config: plugin_model.ConfigType{ - // "prompt": aiPrompt.Prompt, - // "variables": aiPrompt.Variables, - // }, - //} - //plugins["ai_formatter"] = api.PluginSetting{ - // Config: plugin_model.ConfigType{ - // "model": aiModel.Id, - // "provider": info.Provider.Id, - // "config": aiModel.Config, - // }, - //} - //_, err = i.routerModule.Create(ctx, info.Id, &router_dto.Create{ - // Id: apiId, - // Name: name, - // Path: path, - // Methods: []string{ - // http.MethodPost, - // }, - // Description: description, - // Protocols: []string{"http", "https"}, - // MatchRules: nil, - // Proxy: &router_dto.InputProxy{ - // Path: path, - // Timeout: timeout, - // Retry: retry, - // Plugins: plugins, - // }, - // Disable: false, - //}) - //if err != nil { - // return err - //} - // - //return i.docModule.SaveServiceDoc(ctx, info.Id, &service_dto.SaveServiceDoc{ - // Doc: "", - //}) }) - return err + return func() error { + path := fmt.Sprintf("/%s/chat", strings.Trim(prefix, "/")) + timeout := 300000 + retry := 0 + aiPrompt := &ai_api_dto.AiPrompt{ + Variables: []*ai_api_dto.AiPromptVariable{}, + Prompt: "", + } + aiModel := &ai_api_dto.AiModel{ + Id: model, + Config: ai_provider_local.OllamaConfig, + Provider: providerId, + Type: "local", + } + name := "Demo AI API" + description := "A demo that shows you how to use a e a Chat API." + apiId := uuid.NewString() + err = i.aiAPIModule.Create( + ctx, + serviceId, + &ai_api_dto.CreateAPI{ + Id: apiId, + Name: name, + Path: path, + Description: description, + Disable: false, + AiPrompt: aiPrompt, + AiModel: aiModel, + Timeout: timeout, + Retry: retry, + }, + ) + if err != nil { + return err + } + plugins := make(map[string]api.PluginSetting) + plugins["ai_prompt"] = api.PluginSetting{ + Config: plugin_model.ConfigType{ + "prompt": aiPrompt.Prompt, + "variables": aiPrompt.Variables, + }, + } + plugins["ai_formatter"] = api.PluginSetting{ + Config: plugin_model.ConfigType{ + "model": aiModel.Id, + "provider": providerId, + "config": aiModel.Config, + }, + } + _, err = i.routerModule.Create(ctx, serviceId, &router_dto.Create{ + Id: apiId, + Name: name, + Path: path, + Methods: []string{ + http.MethodPost, + }, + Description: description, + Protocols: []string{"http", "https"}, + MatchRules: nil, + Proxy: &router_dto.InputProxy{ + Path: path, + Timeout: timeout, + Retry: retry, + Plugins: plugins, + }, + Disable: false, + }) + if err != nil { + return err + } + + return i.docModule.SaveServiceDoc(ctx, serviceId, &service_dto.SaveServiceDoc{ + Doc: "", + }) + }, err } func (i *imlLocalModelController) Search(ctx *gin.Context, keyword string) ([]*ai_local_dto.LocalModelItem, error) { diff --git a/module/ai-local/iml.go b/module/ai-local/iml.go index da22e6b6..d3dcabeb 100644 --- a/module/ai-local/iml.go +++ b/module/ai-local/iml.go @@ -167,7 +167,7 @@ func (i *imlLocalModel) ListCanInstall(ctx context.Context, keyword string) ([]* } -func (i *imlLocalModel) pullHook() func(msg ai_provider_local.PullMessage) error { +func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local.PullMessage) error { return func(msg ai_provider_local.PullMessage) error { return i.transaction.Transaction(context.Background(), func(ctx context.Context) error { @@ -240,6 +240,12 @@ func (i *imlLocalModel) pullHook() func(msg ai_provider_local.PullMessage) error return err } if state == ai_local_dto.DeployStateFinish.Int() { + for _, f := range fn { + err = f() + if err != nil { + return err + } + } cfg := make(map[string]interface{}) cfg["provider"] = "ollama" cfg["model"] = msg.Model @@ -296,7 +302,7 @@ func (i *imlLocalModel) syncGateway(ctx context.Context, clusterId string, relea return nil } -func (i *imlLocalModel) Deploy(ctx context.Context, model string, session string) (*ai_provider_local.Pipeline, error) { +func (i *imlLocalModel) Deploy(ctx context.Context, model string, session string, fn ...func() error) (*ai_provider_local.Pipeline, error) { var p *ai_provider_local.Pipeline err := i.transaction.Transaction(ctx, func(txCtx context.Context) error { item, err := i.localModelCacheService.GetByTarget(ctx, ai_local.CacheTypeService, model) @@ -329,7 +335,7 @@ func (i *imlLocalModel) Deploy(ctx context.Context, model string, session string if err != nil { return err } - p, err = ai_provider_local.PullModel(model, session, i.pullHook()) + p, err = ai_provider_local.PullModel(model, session, i.pullHook(fn...)) if err != nil { return err } diff --git a/module/ai-local/module.go b/module/ai-local/module.go index 58362d16..38330029 100644 --- a/module/ai-local/module.go +++ b/module/ai-local/module.go @@ -16,7 +16,7 @@ import ( type ILocalModelModule interface { Search(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelItem, error) ListCanInstall(ctx context.Context, keyword string) ([]*ai_local_dto.LocalModelPackageItem, error) - Deploy(ctx context.Context, model string, session string) (*ai_provider_local.Pipeline, error) + Deploy(ctx context.Context, model string, session string, fn ...func() error) (*ai_provider_local.Pipeline, error) CancelDeploy(ctx context.Context, model string) error RemoveModel(ctx context.Context, model string) error Enable(ctx context.Context, model string) error