From a9d5585ef9fb5fb9ddf57cd47bd73d587f3d7734 Mon Sep 17 00:00:00 2001 From: Liujian <824010343@qq.com> Date: Tue, 11 Mar 2025 19:44:55 +0800 Subject: [PATCH] Local model compatibility testing completed --- ai-provider/local/local.go | 14 +++- ai-provider/local/ollama.go | 102 -------------------------- controller/ai-api/iml.go | 6 +- controller/ai-local/iml.go | 4 +- controller/service/iml.go | 8 +- module/ai-api/iml.go | 54 +++++++++----- module/ai-balance/iml.go | 16 ++-- module/ai-local/iml.go | 88 ++++++++++++++-------- module/ai/iml.go | 10 +-- module/service/dto/output.go | 3 +- module/system/iml.go | 4 +- service/ai-api/iml.go | 5 ++ service/ai-api/service.go | 1 + service/ai-local/iml.go | 5 ++ service/ai-local/service.go | 1 + service/cluster/cluster.go | 2 +- service/universally/get-softdelete.go | 2 +- service/universally/get.go | 2 +- 18 files changed, 147 insertions(+), 180 deletions(-) delete mode 100644 ai-provider/local/ollama.go diff --git a/ai-provider/local/local.go b/ai-provider/local/local.go index 5e084eeb..bcea9e84 100644 --- a/ai-provider/local/local.go +++ b/ai-provider/local/local.go @@ -8,10 +8,11 @@ import ( ) var ( - client *api.Client + client *api.Client + ProviderLocal = "LocalModel" ) -func ResetOllamaAddress(address string) error { +func ResetLocalAddress(address string) error { u, err := url.Parse(address) if err != nil { return err @@ -19,3 +20,12 @@ func ResetOllamaAddress(address string) error { client = api.NewClient(u, http.DefaultClient) return nil } + +var ( + LocalConfig = "{\n \"temperature\": \"\",\n \"top_p\": \"\",\n \"max_tokens\": \"\"\n}" + LocalSvg = ` + + + +` +) diff --git a/ai-provider/local/ollama.go b/ai-provider/local/ollama.go deleted file mode 100644 index be765848..00000000 --- a/ai-provider/local/ollama.go +++ /dev/null @@ -1,102 +0,0 @@ -package ai_provider_local - -var ( - OllamaConfig = "{\n \"mirostat\": 0,\n \"mirostat_eta\": 0.1,\n \"mirostat_tau\": 5.0,\n \"num_ctx\": 4096,\n \"repeat_last_n\":64,\n \"repeat_penalty\": 1.1,\n \"temperature\": 0.7,\n \"seed\": 42,\n \"num_predict\": 42,\n \"top_k\": 40,\n \"top_p\": 0.9,\n \"min_p\": 0.5\n}\n" - OllamaSvg = ` - - - - - - - - - - - -` -) diff --git a/controller/ai-api/iml.go b/controller/ai-api/iml.go index d83439e7..83292cbe 100644 --- a/controller/ai-api/iml.go +++ b/controller/ai-api/iml.go @@ -4,6 +4,8 @@ import ( "context" "net/http" + ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local" + "github.com/APIParkLab/APIPark/model/plugin_model" ai_api "github.com/APIParkLab/APIPark/module/ai-api" ai_api_dto "github.com/APIParkLab/APIPark/module/ai-api/dto" @@ -48,7 +50,7 @@ func (i *imlAPIController) Create(ctx *gin.Context, serviceId string, input *ai_ } } if input.AiModel != nil { - provider := "ollama" + provider := ai_provider_local.ProviderLocal if input.AiModel.Type != "local" { provider = input.AiModel.Provider } @@ -107,7 +109,7 @@ func (i *imlAPIController) Edit(ctx *gin.Context, serviceId string, apiId string } //var upstream *string if input.AiModel != nil { - provider := "ollama" + provider := ai_provider_local.ProviderLocal if input.AiModel.Type != "local" { provider = input.AiModel.Provider } diff --git a/controller/ai-local/iml.go b/controller/ai-local/iml.go index e436609a..f2c13487 100644 --- a/controller/ai-local/iml.go +++ b/controller/ai-local/iml.go @@ -245,7 +245,7 @@ func (i *imlLocalModelController) initAILocalService(ctx context.Context, model } serviceId := uuid.NewString() prefix := fmt.Sprintf("/%s", serviceId[:8]) - providerId := "ollama" + providerId := ai_provider_local.ProviderLocal err = i.transaction.Transaction(ctx, func(ctx context.Context) error { _, err = i.serviceModule.Create(ctx, teamID, &service_dto.CreateService{ Id: serviceId, @@ -276,7 +276,7 @@ func (i *imlLocalModelController) initAILocalService(ctx context.Context, model } aiModel := &ai_api_dto.AiModel{ Id: model, - Config: ai_provider_local.OllamaConfig, + Config: ai_provider_local.LocalConfig, Provider: providerId, Type: "local", } diff --git a/controller/service/iml.go b/controller/service/iml.go index e2f0bf32..b391241a 100644 --- a/controller/service/iml.go +++ b/controller/service/iml.go @@ -296,7 +296,7 @@ func (i *imlServiceController) editAIService(ctx *gin.Context, id string, input if input.Provider == nil { return nil, fmt.Errorf("provider is required") } - if *input.Provider != "ollama" { + if *input.Provider != ai_provider_local.ProviderLocal { _, has := model_runtime.GetProvider(*input.Provider) if !has { return nil, fmt.Errorf("provider not found") @@ -330,7 +330,7 @@ func (i *imlServiceController) createAIService(ctx *gin.Context, teamID string, modelId := "" modelCfg := "" modelType := "online" - if *input.Provider == "ollama" { + if *input.Provider == ai_provider_local.ProviderLocal { modelType = "local" list, err := i.aiLocalModel.SimpleList(ctx) if err != nil { @@ -340,7 +340,7 @@ func (i *imlServiceController) createAIService(ctx *gin.Context, teamID string, return nil, fmt.Errorf("no local model") } modelId = list[0].Id - modelCfg = ai_provider_local.OllamaConfig + modelCfg = ai_provider_local.LocalConfig } else { pv, err := i.providerModule.Provider(ctx, *input.Provider) if err != nil { @@ -367,7 +367,7 @@ func (i *imlServiceController) createAIService(ctx *gin.Context, teamID string, return err } prefix := strings.Replace(input.Prefix, ":", "_", -1) - path := fmt.Sprintf("/%s/chat", strings.Trim(prefix, "/")) + path := fmt.Sprintf("/%s/chat/completions", strings.Trim(prefix, "/")) timeout := 300000 retry := 0 aiPrompt := &ai_api_dto.AiPrompt{ diff --git a/module/ai-api/iml.go b/module/ai-api/iml.go index 4bea9514..25eda7f7 100644 --- a/module/ai-api/iml.go +++ b/module/ai-api/iml.go @@ -5,13 +5,15 @@ import ( "encoding/json" "errors" "fmt" - ai_model "github.com/APIParkLab/APIPark/service/ai-model" "net/http" "strings" - "github.com/eolinker/eosc/log" + ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local" model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime" + ai_model "github.com/APIParkLab/APIPark/service/ai-model" + + "github.com/eolinker/eosc/log" ai_api_dto "github.com/APIParkLab/APIPark/module/ai-api/dto" ai_api "github.com/APIParkLab/APIPark/service/ai-api" @@ -239,27 +241,41 @@ func (i *imlAPIModule) List(ctx context.Context, keyword string, serviceId strin if err != nil { return item } - p, has := model_runtime.GetProvider(aiModel.Provider) - if has { - item.Provider = ai_api_dto.ProviderItem{ - Id: p.ID(), - Name: p.Name(), - Logo: "", - } - m, has := p.GetModel(t.Model) - if has { - item.Model = ai_api_dto.ModelItem{ - Id: m.ID(), - Name: m.Name(), - Logo: "", - } - } - } else { + item.ModelType = ai_api_dto.ModelType(aiModel.Type) + if item.ModelType == ai_api_dto.ModelTypeLocal { item.Model = ai_api_dto.ModelItem{ Id: aiModel.Id, - Name: "unknown", + Name: aiModel.Id, + } + item.Provider = ai_api_dto.ProviderItem{ + Id: ai_provider_local.ProviderLocal, + Name: ai_provider_local.ProviderLocal, + Logo: "", + } + } else { + p, has := model_runtime.GetProvider(aiModel.Provider) + if has { + item.Provider = ai_api_dto.ProviderItem{ + Id: p.ID(), + Name: p.Name(), + Logo: "", + } + m, has := p.GetModel(t.Model) + if has { + item.Model = ai_api_dto.ModelItem{ + Id: m.ID(), + Name: m.Name(), + Logo: "", + } + } + } else { + item.Model = ai_api_dto.ModelItem{ + Id: aiModel.Id, + Name: "unknown", + } } } + return item }), nil } diff --git a/module/ai-balance/iml.go b/module/ai-balance/iml.go index f361d724..2d6c0b41 100644 --- a/module/ai-balance/iml.go +++ b/module/ai-balance/iml.go @@ -82,8 +82,8 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre modelName = input.Model base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) case ai_balance_dto.ModelTypeLocal: - input.Provider = "ollama" - providerName = "Ollama" + input.Provider = ai_provider_local.ProviderLocal + providerName = ai_provider_local.ProviderLocal modelName = input.Model v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address") if !has { @@ -119,7 +119,7 @@ func newRelease(item *ai_balance.Balance, base string) *gateway.DynamicRelease { cfg := make(map[string]interface{}) cfg["provider"] = item.Provider cfg["model"] = item.Model - cfg["model_config"] = ai_provider_local.OllamaConfig + cfg["model_config"] = ai_provider_local.LocalConfig cfg["base"] = base cfg["priority"] = item.Priority return &gateway.DynamicRelease{ @@ -155,7 +155,7 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort) releases := make([]*gateway.DynamicRelease, 0, len(list)) for _, item := range list { base := v - if item.Provider != "ollama" { + if item.Provider != ai_provider_local.ProviderLocal { p, has := model_runtime.GetProvider(item.Provider) if !has { continue @@ -259,7 +259,7 @@ func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, re } func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) { - balances, err := i.balanceService.Search(ctx, "", map[string]interface{}{"provider": "ollama"}, "priority asc") + balances, err := i.balanceService.Search(ctx, "", map[string]interface{}{"provider": ai_provider_local.ProviderLocal}, "priority asc") if err != nil { return nil, err } @@ -274,7 +274,7 @@ func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*g releases := make([]*gateway.DynamicRelease, 0, len(balances)) for _, item := range balances { base := v - if item.Provider != "ollama" { + if item.Provider != ai_provider_local.ProviderLocal { p, has := model_runtime.GetProvider(item.Provider) if !has { continue @@ -298,12 +298,12 @@ func (i *imlBalanceModule) getBalances(ctx context.Context) ([]*gateway.DynamicR releases := make([]*gateway.DynamicRelease, 0, len(balances)) for _, item := range balances { base := v - if item.Provider != "ollama" { + if item.Provider != ai_provider_local.ProviderLocal { p, has := model_runtime.GetProvider(item.Provider) if !has { continue } - base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) + base = fmt.Sprintf("%s://%s%s", p.URI().Scheme(), p.URI().Host(), p.URI().Path()) } releases = append(releases, newRelease(item, base)) } diff --git a/module/ai-local/iml.go b/module/ai-local/iml.go index 28577413..cf886e44 100644 --- a/module/ai-local/iml.go +++ b/module/ai-local/iml.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "strings" ai_balance "github.com/APIParkLab/APIPark/service/ai-balance" @@ -73,8 +72,8 @@ func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleI return &ai_local_dto.SimpleItem{ Id: s.Id, Name: s.Name, - DefaultConfig: ai_provider_local.OllamaConfig, - Logo: ai_provider_local.OllamaSvg, + DefaultConfig: ai_provider_local.LocalConfig, + Logo: ai_provider_local.LocalSvg, } }, func(l *ai_local.LocalModel) bool { if l.State != ai_local_dto.LocalModelStateNormal.Int() && l.State != ai_local_dto.LocalModelStateDisable.Int() { @@ -118,7 +117,7 @@ func (i *imlLocalModel) Search(ctx context.Context, keyword string) ([]*ai_local APICount: count, CanDelete: count < 1 && s.State != ai_local_dto.LocalModelStateDeploying.Int(), UpdateTime: auto.TimeLabel(s.UpdateAt), - Provider: "ollama", + Provider: ai_provider_local.ProviderLocal, } }), nil } @@ -249,7 +248,7 @@ func (i *imlLocalModel) pullHook(fn ...func() error) func(msg ai_provider_local. cfg := make(map[string]interface{}) cfg["provider"] = "ollama" cfg["model"] = msg.Model - cfg["model_config"] = ai_provider_local.OllamaConfig + cfg["model_config"] = ai_provider_local.LocalConfig cfg["priority"] = 0 cfg["base"] = v @@ -322,7 +321,7 @@ func (i *imlLocalModel) Deploy(ctx context.Context, model string, session string err = i.localModelService.Create(ctx, &ai_local.CreateLocalModel{ Id: model, Name: model, - Provider: "ollama", + Provider: ai_provider_local.ProviderLocal, State: ai_local_dto.LocalModelStateDeploying.Int(), }) @@ -451,7 +450,7 @@ func (i *imlLocalModel) Enable(ctx context.Context, model string) error { cfg := make(map[string]interface{}) cfg["provider"] = "ollama" cfg["model"] = info.Id - cfg["model_config"] = ai_provider_local.OllamaConfig + cfg["model_config"] = ai_provider_local.LocalConfig cfg["priority"] = 0 cfg["base"] = v @@ -513,7 +512,7 @@ func (i *imlLocalModel) OnInit() { }) models, version := ai_provider_local.ModelsCanInstall() for _, model := range models { - delete(oldModels, model.Id) + if v, ok := oldModels[model.Id]; ok { if v.Version == version { continue @@ -542,6 +541,7 @@ func (i *imlLocalModel) OnInit() { return } } + delete(oldModels, model.Id) } for id := range oldModels { err = i.localModelPackageService.Delete(ctx, id) @@ -549,29 +549,57 @@ func (i *imlLocalModel) OnInit() { return } } - installModels, err := ai_provider_local.ModelsInstalled() - if err != nil { - return - } - for _, model := range installModels { - - id := strings.TrimSuffix(model.Name, ":latest") - name := strings.TrimSuffix(model.Name, ":latest") - _, err = i.localModelService.Get(ctx, id) + //installModels, err := ai_provider_local.ModelsInstalled() + //if err != nil { + // return + //} + //for _, model := range installModels { + // + // id := strings.TrimSuffix(model.Name, ":latest") + // name := strings.TrimSuffix(model.Name, ":latest") + // _, err = i.localModelService.Get(ctx, id) + // if err != nil { + // if !errors.Is(err, gorm.ErrRecordNotFound) { + // return + // } + // err = i.localModelService.Create(ctx, &ai_local.CreateLocalModel{ + // Id: id, + // Name: name, + // State: 1, + // }) + // if err != nil { + // return + // } + // } + //} + i.transaction.Transaction(ctx, func(ctx context.Context) error { + localModels, err := i.localModelService.Search(ctx, "", map[string]interface{}{ + "provider": "ollama", + }) if err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { - return - } - err = i.localModelService.Create(ctx, &ai_local.CreateLocalModel{ - Id: id, - Name: name, - State: 1, - }) - if err != nil { - return - } + return err } - } + if len(localModels) == 0 { + return nil + } + err = i.localModelService.UpdateProvider(ctx, ai_provider_local.ProviderLocal, utils.SliceToSlice(localModels, func(s *ai_local.LocalModel) string { + return s.Id + })...) + if err != nil { + return err + } + + apis, err := i.aiAPIService.Search(ctx, "", map[string]interface{}{ + "provider": "ollama", + }) + if err != nil { + return err + } + return i.aiAPIService.UpdateAIProvider(ctx, ai_provider_local.ProviderLocal, utils.SliceToSlice(apis, func(s *ai_api.API) string { + return s.ID + })...) + }) + }) } @@ -596,7 +624,7 @@ func (i *imlLocalModel) getLocalModels(ctx context.Context, v string) ([]*gatewa cfg := make(map[string]interface{}) cfg["provider"] = "ollama" cfg["model"] = l.Id - cfg["model_config"] = ai_provider_local.OllamaConfig + cfg["model_config"] = ai_provider_local.LocalConfig cfg["base"] = v releases = append(releases, &gateway.DynamicRelease{ BasicItem: &gateway.BasicItem{ diff --git a/module/ai/iml.go b/module/ai/iml.go index e0f0ae65..595c7f73 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -337,11 +337,11 @@ func (i *imlProviderModule) SimpleConfiguredProviders(ctx context.Context, all b healthProvider := make(map[string]struct{}) if all { - healthProvider["ollama"] = struct{}{} + healthProvider[ai_provider_local.ProviderLocal] = struct{}{} items = append(items, &ai_dto.SimpleProviderItem{ - Id: "ollama", - Name: "Ollama", - Logo: ai_provider_local.OllamaSvg, + Id: ai_provider_local.ProviderLocal, + Name: ai_provider_local.ProviderLocal, + Logo: ai_provider_local.LocalSvg, Configured: true, DefaultConfig: "", Status: ai_dto.ProviderEnabled, @@ -792,7 +792,7 @@ type imlAIApiModule struct { func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId string, start int64, end int64, page int, pageSize int, sortCondition string, asc bool, models []string, serviceIds []string) ([]*ai_dto.APIItem, *ai_dto.Condition, int64, error) { modelItems := make([]*ai_dto.BasicInfo, 0) - if providerId == "ollama" { + if providerId == ai_provider_local.ProviderLocal { items, err := i.aiLocalModelService.Search(ctx, "", nil, "update_at desc") if err != nil { return nil, nil, 0, err diff --git a/module/service/dto/output.go b/module/service/dto/output.go index 636420f0..0bd258c3 100644 --- a/module/service/dto/output.go +++ b/module/service/dto/output.go @@ -1,6 +1,7 @@ package service_dto import ( + ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local" "github.com/APIParkLab/APIPark/service/service" "github.com/eolinker/go-common/auto" ) @@ -151,7 +152,7 @@ func ToService(model *service.Service) *Service { provider := auto.UUID(model.AdditionalConfig["provider"]) s.Provider = &provider s.ProviderType = "local" - if provider.Id != "ollama" { + if provider.Id != ai_provider_local.ProviderLocal { s.ProviderType = "online" } modelId := model.AdditionalConfig["model"] diff --git a/module/system/iml.go b/module/system/iml.go index b74a231b..20e6b6b5 100644 --- a/module/system/iml.go +++ b/module/system/iml.go @@ -49,7 +49,7 @@ func (i *imlSettingModule) Set(ctx context.Context, input *system_dto.InputSetti } } if input.OllamaAddress != nil { - ai_provider_local.ResetOllamaAddress(*input.OllamaAddress) + ai_provider_local.ResetLocalAddress(*input.OllamaAddress) } return nil }) @@ -61,7 +61,7 @@ func (i *imlSettingModule) OnInit() { address, has := i.settingService.Get(ctx, "system.ai_model.ollama_address") if has { - ai_provider_local.ResetOllamaAddress(address) + ai_provider_local.ResetLocalAddress(address) } }) diff --git a/service/ai-api/iml.go b/service/ai-api/iml.go index 358e99f5..8a8126bd 100644 --- a/service/ai-api/iml.go +++ b/service/ai-api/iml.go @@ -23,6 +23,11 @@ type imlAPIService struct { universally.IServiceDelete } +func (i *imlAPIService) UpdateAIProvider(ctx context.Context, providerId string, ids ...string) error { + _, err := i.store.UpdateField(ctx, "provider", providerId, "uuid in (?)", ids) + return err +} + func (i *imlAPIService) CountByProvider(ctx context.Context, provider string) (int64, error) { return i.store.Count(ctx, "", map[string]interface{}{"provider": provider}) } diff --git a/service/ai-api/service.go b/service/ai-api/service.go index 62ef5389..55385a08 100644 --- a/service/ai-api/service.go +++ b/service/ai-api/service.go @@ -17,6 +17,7 @@ type IAPIService interface { CountMapByModel(ctx context.Context, keyword string, conditions map[string]interface{}) (map[string]int64, error) CountByModel(ctx context.Context, model string) (int64, error) CountByProvider(ctx context.Context, provider string) (int64, error) + UpdateAIProvider(ctx context.Context, providerId string, ids ...string) error DeleteByService(ctx context.Context, serviceId string) error } diff --git a/service/ai-local/iml.go b/service/ai-local/iml.go index 5d21cc2b..7256c5f9 100644 --- a/service/ai-local/iml.go +++ b/service/ai-local/iml.go @@ -20,6 +20,11 @@ type imlLocalModelService struct { universally.IServiceDelete } +func (i *imlLocalModelService) UpdateProvider(ctx context.Context, provider string, ids ...string) error { + _, err := i.store.UpdateWhere(ctx, map[string]interface{}{"provider": provider}, map[string]interface{}{"uuid": ids}) + return err +} + func (i *imlLocalModelService) DefaultModel(ctx context.Context) (*LocalModel, error) { info, err := i.store.First(ctx, map[string]interface{}{"state": 1}) if err != nil { diff --git a/service/ai-local/service.go b/service/ai-local/service.go index 2c493b36..61e67d1b 100644 --- a/service/ai-local/service.go +++ b/service/ai-local/service.go @@ -14,6 +14,7 @@ type ILocalModelService interface { universally.IServiceEdit[EditLocalModel] universally.IServiceDelete DefaultModel(ctx context.Context) (*LocalModel, error) + UpdateProvider(ctx context.Context, provider string, ids ...string) error } type ILocalModelPackageService interface { diff --git a/service/cluster/cluster.go b/service/cluster/cluster.go index 0550c748..490f5d49 100644 --- a/service/cluster/cluster.go +++ b/service/cluster/cluster.go @@ -80,7 +80,7 @@ func (s *imlClusterService) GetLabels(ctx context.Context, ids ...string) map[st } return map[string]string{o.UUID: o.Name} } - list, err := s.store.ListQuery(ctx, "uuid in ?", []interface{}{ids}, "id") + list, err := s.store.ListQuery(ctx, "uuid in (?)", []interface{}{ids}, "id") if err != nil { return nil } diff --git a/service/universally/get-softdelete.go b/service/universally/get-softdelete.go index e191fcf6..2f83ab2b 100644 --- a/service/universally/get-softdelete.go +++ b/service/universally/get-softdelete.go @@ -46,7 +46,7 @@ func (s *imlServiceGetSoftDelete[T, E]) List(ctx context.Context, uuids ...strin where = append(where, "uuid = ?") args = append(args, uuids[0]) } else { - where = append(where, "uuid in ?") + where = append(where, "uuid in (?)") args = append(args, uuids) } } diff --git a/service/universally/get.go b/service/universally/get.go index a9802b2f..db417f1c 100644 --- a/service/universally/get.go +++ b/service/universally/get.go @@ -59,7 +59,7 @@ func (s *imlServiceGet[T, E]) List(ctx context.Context, uuids ...string) ([]*T, where = append(where, "uuid = ?") args = append(args, uuids[0]) } else { - where = append(where, "uuid in ?") + where = append(where, "uuid in (?)") args = append(args, uuids) } }