diff --git a/ai-provider/local/executor.go b/ai-provider/local/executor.go index b6648d23..56959ceb 100644 --- a/ai-provider/local/executor.go +++ b/ai-provider/local/executor.go @@ -210,6 +210,9 @@ func (e *AsyncExecutor) DistributeToModelPipelines(model string, msg PullMessage type PullCallback func(msg PullMessage) error func PullModel(model string, id string, fn PullCallback) (*Pipeline, error) { + if client == nil { + return nil, fmt.Errorf("client not initialized") + } mp, has := taskExecutor.GetModelPipeline(model) if !has { mp = newModelPipeline(taskExecutor.ctx, 100) @@ -279,6 +282,9 @@ func PullModel(model string, id string, fn PullCallback) (*Pipeline, error) { } func StopPull(model string) { + if client == nil { + return + } taskExecutor.CloseModelPipeline(model) } @@ -287,6 +293,9 @@ func CancelPipeline(model string, id string) { } func RemoveModel(model string) error { + if client == nil { + return fmt.Errorf("client not initialized") + } taskExecutor.CloseModelPipeline(model) err := client.Delete(context.Background(), &api.DeleteRequest{Model: model}) if err != nil { @@ -298,6 +307,9 @@ func RemoveModel(model string) error { } func ModelsInstalled() ([]Model, error) { + if client == nil { + return nil, fmt.Errorf("client not initialized") + } result, err := client.List(context.Background()) if err != nil { return nil, err diff --git a/ai-provider/local/local.go b/ai-provider/local/local.go index 93659245..5e084eeb 100644 --- a/ai-provider/local/local.go +++ b/ai-provider/local/local.go @@ -4,27 +4,18 @@ import ( "net/http" "net/url" - "github.com/eolinker/eosc/env" "github.com/ollama/ollama/api" ) var ( - ollamaAddress = "http://127.0.0.1:11434" - EnvOllamaAddress = "OLLAMA_ADDRESS" - client *api.Client + client *api.Client ) -func init() { - address, has := env.GetEnv(EnvOllamaAddress) - if !has { - address = ollamaAddress - } +func ResetOllamaAddress(address string) error { u, err := url.Parse(address) if err != nil { - u, err = url.Parse(ollamaAddress) - if err != nil { - panic(err) - } + return err } client = api.NewClient(u, http.DefaultClient) + return nil } diff --git a/ai-provider/local/ollama.go b/ai-provider/local/ollama.go index 31660541..be765848 100644 --- a/ai-provider/local/ollama.go +++ b/ai-provider/local/ollama.go @@ -1,7 +1,6 @@ package ai_provider_local var ( - OllamaBase = "http://apipark-ollama:11434" OllamaConfig = "{\n \"mirostat\": 0,\n \"mirostat_eta\": 0.1,\n \"mirostat_tau\": 5.0,\n \"num_ctx\": 4096,\n \"repeat_last_n\":64,\n \"repeat_penalty\": 1.1,\n \"temperature\": 0.7,\n \"seed\": 42,\n \"num_predict\": 42,\n \"top_k\": 40,\n \"top_p\": 0.9,\n \"min_p\": 0.5\n}\n" OllamaSvg = ` 0 { return fmt.Errorf("model %s has api, can not remove", model) } + info, err := i.localModelService.Get(ctx, model) + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + return ai_provider_local.RemoveModel(model) + } + if info.State == ai_local_dto.LocalModelStateDeploying.Int() { + return fmt.Errorf("model %s is deploying, can not remove", model) + } return i.transaction.Transaction(ctx, func(txCtx context.Context) error { err = i.localModelService.Delete(ctx, model) if err != nil { @@ -430,8 +440,36 @@ func (i *imlLocalModel) Enable(ctx context.Context, model string) error { return err } if info.State == ai_local_dto.LocalModelStateDisable.Int() || info.State == ai_local_dto.LocalModelStateError.Int() { - status := ai_local_dto.LocalModelStateNormal.Int() - return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status}) + + return i.transaction.Transaction(ctx, func(ctx context.Context) error { + status := ai_local_dto.LocalModelStateNormal.Int() + err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &status}) + if err != nil { + return err + } + v, _ := i.settingService.Get(ctx, "system.ai_model.ollama_address") + cfg := make(map[string]interface{}) + cfg["provider"] = "ollama" + cfg["model"] = info.Id + cfg["model_config"] = ai_provider_local.OllamaConfig + cfg["priority"] = 0 + cfg["base"] = v + + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + { + BasicItem: &gateway.BasicItem{ + ID: info.Id, + Description: info.Id, + Resource: "ai-provider", + Version: info.UpdateAt.Format("20060102150405"), + MatchLabels: map[string]string{ + "module": "ai-provider", + }, + }, + Attr: cfg, + }}, true) + }) + } return fmt.Errorf("model %s is not disabled state,can not enable", model) } @@ -443,7 +481,21 @@ func (i *imlLocalModel) Disable(ctx context.Context, model string) error { } if info.State == ai_local_dto.LocalModelStateNormal.Int() { disable := ai_local_dto.LocalModelStateDisable.Int() - return i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable}) + return i.transaction.Transaction(ctx, func(ctx context.Context) error { + err = i.localModelService.Save(ctx, model, &ai_local.EditLocalModel{State: &disable}) + if err != nil { + return err + } + + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + { + BasicItem: &gateway.BasicItem{ + ID: info.Id, + Resource: "ai-provider", + }, + }}, false) + }) + } return fmt.Errorf("model %s is not enabled state,can not disable", model) } @@ -523,18 +575,29 @@ func (i *imlLocalModel) OnInit() { }) } -func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicRelease, error) { +func (i *imlLocalModel) getLocalModels(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) { list, err := i.localModelService.List(ctx) if err != nil { return nil, err } + if v == "" { + var has bool + v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address") + if !has { + return nil, errors.New("ollama_address not set") + } + } + releases := make([]*gateway.DynamicRelease, 0, len(list)) for _, l := range list { + if l.State != ai_local_dto.LocalModelStateNormal.Int() { + continue + } cfg := make(map[string]interface{}) cfg["provider"] = "ollama" cfg["model"] = l.Id - cfg["model_config"] = ai_provider_local.OllamaSvg - cfg["base"] = ollamaBase + cfg["model_config"] = ai_provider_local.OllamaConfig + cfg["base"] = v releases = append(releases, &gateway.DynamicRelease{ BasicItem: &gateway.BasicItem{ ID: l.Id, @@ -552,7 +615,7 @@ func (i *imlLocalModel) getLocalModels(ctx context.Context) ([]*gateway.DynamicR } func (i *imlLocalModel) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error { - releases, err := i.getLocalModels(ctx) + releases, err := i.getLocalModels(ctx, "") if err != nil { return err } diff --git a/module/ai-local/module.go b/module/ai-local/module.go index 38330029..131c72ca 100644 --- a/module/ai-local/module.go +++ b/module/ai-local/module.go @@ -24,6 +24,8 @@ type ILocalModelModule interface { ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) SaveCache(ctx context.Context, model string, target string) error + + SyncLocalModels(ctx context.Context, address string) error } func init() { diff --git a/module/ai/iml.go b/module/ai/iml.go index b3b76186..d4c06d89 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -482,7 +482,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, return fmt.Errorf("ai provider not found") } - return i.transaction.Transaction(ctx, func(txCtx context.Context) error { + return i.transaction.Transaction(ctx, func(ctx context.Context) error { info, err := i.providerService.Get(ctx, id) if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { @@ -533,12 +533,12 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, Config: &input.Config, Status: &status, } - _, err = i.aiKeyService.DefaultKey(txCtx, id) + _, err = i.aiKeyService.DefaultKey(ctx, id) if err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { return err } - err = i.aiKeyService.Create(txCtx, &ai_key.Create{ + err = i.aiKeyService.Create(ctx, &ai_key.Create{ ID: id, Name: info.Name, Config: input.Config, @@ -549,7 +549,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, Priority: 1, }) } else { - err = i.aiKeyService.Save(txCtx, id, &ai_key.Edit{ + err = i.aiKeyService.Save(ctx, id, &ai_key.Edit{ Config: &input.Config, Status: &status, }) @@ -557,13 +557,13 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, if err != nil { return err } - err = i.providerService.Save(txCtx, id, pInfo) + err = i.providerService.Save(ctx, id, pInfo) if err != nil { return err } if *pInfo.Status == 0 { - return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ { BasicItem: &gateway.BasicItem{ ID: id, @@ -573,7 +573,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, }, false) } // 获取当前供应商默认Key信息 - defaultKey, err := i.aiKeyService.DefaultKey(txCtx, id) + defaultKey, err := i.aiKeyService.DefaultKey(ctx, id) if err != nil { return err } @@ -582,7 +582,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, cfg["model"] = info.DefaultLLM cfg["model_config"] = model.DefaultConfig() cfg["base"] = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host()) - return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ + return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ { BasicItem: &gateway.BasicItem{ ID: id, diff --git a/module/service/dto/input.go b/module/service/dto/input.go index 6cfab3da..dcea95b4 100644 --- a/module/service/dto/input.go +++ b/module/service/dto/input.go @@ -20,6 +20,7 @@ type CreateService struct { Kind string `json:"service_kind"` State string `json:"state"` Provider *string `json:"provider"` + Model *string `json:"model"` AsApp *bool `json:"as_app"` AsServer *bool `json:"as_server"` } @@ -32,6 +33,7 @@ type EditService struct { Logo *string `json:"logo"` Tags *[]string `json:"tags"` Provider *string `json:"provider"` + Model *string `json:"model"` ApprovalType *string `json:"approval_type"` State *string `json:"state"` } diff --git a/module/service/dto/output.go b/module/service/dto/output.go index 016c1007..ef4e6820 100644 --- a/module/service/dto/output.go +++ b/module/service/dto/output.go @@ -97,7 +97,8 @@ type Service struct { Tags []auto.Label `json:"tags" aolabel:"tag"` Logo string `json:"logo"` Provider *auto.Label `json:"provider,omitempty" aolabel:"ai_provider"` - ProviderType string `json:"provider_type"` + ProviderType string `json:"provider_type,omitempty"` + Model string `json:"model,omitempty"` ApprovalType string `json:"approval_type"` AsServer bool `json:"as_server"` AsApp bool `json:"as_app"` @@ -152,6 +153,10 @@ func ToService(model *service.Service) *Service { if provider.Id != "ollama" { s.ProviderType = "online" } + modelId := model.AdditionalConfig["model"] + if modelId != "" { + s.Model = modelId + } } return s } diff --git a/module/service/iml.go b/module/service/iml.go index 60324a44..a36197cd 100644 --- a/module/service/iml.go +++ b/module/service/iml.go @@ -8,6 +8,10 @@ import ( "strings" "time" + ai_local "github.com/APIParkLab/APIPark/service/ai-local" + + model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime" + "github.com/eolinker/eosc/log" "github.com/APIParkLab/APIPark/resources/access" @@ -58,6 +62,7 @@ type imlServiceModule struct { teamService team.ITeamService `autowired:""` teamMemberService team_member.ITeamMemberService `autowired:""` tagService tag.ITagService `autowired:""` + localModelService ai_local.ILocalModelService `autowired:""` serviceTagService service_tag.ITagService `autowired:""` apiService api.IAPIService `autowired:""` @@ -223,6 +228,25 @@ func (i *imlServiceModule) Get(ctx context.Context, id string) (*service_dto.Ser s.Tags = auto.List(utils.SliceToSlice(tags, func(p *service_tag.Tag) string { return p.Tid })) + if s.Model == "" { + switch s.ProviderType { + case "online": + p, has := model_runtime.GetProvider(s.Provider.Id) + if has { + m, has := p.DefaultModel(model_runtime.ModelTypeLLM) + if has { + s.Model = m.ID() + } + } + case "local": + info, err := i.localModelService.DefaultModel(ctx) + if err != nil { + return nil, err + } + s.Model = info.Id + + } + } log.Infof("get service cost %d ms", time.Since(now).Milliseconds()) return s, nil } @@ -328,6 +352,11 @@ func (i *imlServiceModule) Create(ctx context.Context, teamID string, input *ser return nil, fmt.Errorf("ai service: provider can not be empty") } mo.AdditionalConfig["provider"] = *input.Provider + if input.Model == nil { + return nil, fmt.Errorf("ai service: model can not be empty") + } + mo.AdditionalConfig["model"] = *input.Model + } if input.AsApp == nil { // 默认值为false @@ -378,6 +407,9 @@ func (i *imlServiceModule) Edit(ctx context.Context, id string, input *service_d if input.Provider != nil { info.AdditionalConfig["provider"] = *input.Provider } + if input.Model != nil { + info.AdditionalConfig["model"] = *input.Model + } } err = i.transaction.Transaction(ctx, func(ctx context.Context) error { diff --git a/module/system/dto/input.go b/module/system/dto/input.go index 1b7d8f35..ddb8c615 100644 --- a/module/system/dto/input.go +++ b/module/system/dto/input.go @@ -6,14 +6,24 @@ import ( ) type InputSetting struct { - InvokeAddress string `json:"invoke_address" key:"system.node.invoke_address"` - SitePrefix string `json:"site_prefix" key:"system.setting.site_prefix"` + InvokeAddress *string `json:"invoke_address" key:"system.node.invoke_address"` + SitePrefix *string `json:"site_prefix" key:"system.setting.site_prefix"` + OllamaAddress *string `json:"ollama_address" key:"system.ai_model.ollama_address"` } func (i *InputSetting) Validate() error { - _, err := url.Parse(i.InvokeAddress) - if err != nil { - return err + if i.InvokeAddress != nil { + _, err := url.Parse(*i.InvokeAddress) + if err != nil { + return err + } + } + + if i.OllamaAddress != nil { + _, err := url.Parse(*i.OllamaAddress) + if err != nil { + return err + } } return nil } @@ -31,9 +41,18 @@ func ToKeyMap(i interface{}) map[string]string { { for i := 0; i < typ.NumField(); i++ { f := typ.Field(i) - if f.Tag.Get("key") != "" { - result[f.Tag.Get("key")] = val.Field(i).String() + v := val.Field(i) + if f.Type.Kind() == reflect.Ptr { + if v.IsNil() { + continue + } + v = v.Elem() } + + if f.Tag.Get("key") != "" { + result[f.Tag.Get("key")] = v.String() + } + } } } diff --git a/module/system/dto/input_test.go b/module/system/dto/input_test.go index e33908ed..1b3838de 100644 --- a/module/system/dto/input_test.go +++ b/module/system/dto/input_test.go @@ -6,9 +6,11 @@ import ( ) func TestMap(t *testing.T) { - + invokeAddress := "http://127.0.0.1:8080" + ollamaAddress := "http://127.0.0.1:8081" input := &InputSetting{ - InvokeAddress: "http://127.0.0.1:8080", + InvokeAddress: &invokeAddress, + OllamaAddress: &ollamaAddress, } err := input.Validate() if err != nil { diff --git a/module/system/dto/output.go b/module/system/dto/output.go index d6ea63d3..2185f57a 100644 --- a/module/system/dto/output.go +++ b/module/system/dto/output.go @@ -8,6 +8,7 @@ import ( type Setting struct { InvokeAddress string `json:"invoke_address" key:"system.node.invoke_address"` SitePrefix string `json:"site_prefix" key:"system.setting.site_prefix"` + OllamaAddress string `json:"ollama_address" key:"system.ai_model.ollama_address"` } func MapStringToStruct[T any](m map[string]string) *T { diff --git a/module/system/iml.go b/module/system/iml.go index b83cefe5..b74a231b 100644 --- a/module/system/iml.go +++ b/module/system/iml.go @@ -3,6 +3,11 @@ package system import ( "context" + "github.com/eolinker/go-common/server" + + ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local" + "github.com/eolinker/go-common/register" + "github.com/eolinker/go-common/store" "github.com/eolinker/go-common/utils" @@ -43,6 +48,21 @@ func (i *imlSettingModule) Set(ctx context.Context, input *system_dto.InputSetti return err } } + if input.OllamaAddress != nil { + ai_provider_local.ResetOllamaAddress(*input.OllamaAddress) + } return nil }) } + +func (i *imlSettingModule) OnInit() { + register.Handle(func(v server.Server) { + ctx := context.Background() + + address, has := i.settingService.Get(ctx, "system.ai_model.ollama_address") + if has { + ai_provider_local.ResetOllamaAddress(address) + } + + }) +} diff --git a/plugins/core/ai-local.go b/plugins/core/ai-local.go index 3671ad86..3c8f3205 100644 --- a/plugins/core/ai-local.go +++ b/plugins/core/ai-local.go @@ -17,5 +17,8 @@ func (p *plugin) aiLocalApis() []pm3.Api { pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/model/local/info", []string{"context", "query:model", "body"}, nil, p.aiLocalController.Update), pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/model/local/state", []string{"context", "query:model"}, []string{"state", "info"}, p.aiLocalController.State), pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/simple/ai/models/local/configured", []string{"context"}, []string{"models"}, p.aiLocalController.SimpleList), + + pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/model/local/source/ollama", []string{"context"}, []string{"config"}, p.aiLocalController.OllamaConfig), + pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/model/local/source/ollama", []string{"context", "body"}, nil, p.aiLocalController.OllamaConfigUpdate), } } diff --git a/service/ai-local/iml.go b/service/ai-local/iml.go index 7c992c36..5d21cc2b 100644 --- a/service/ai-local/iml.go +++ b/service/ai-local/iml.go @@ -20,6 +20,14 @@ type imlLocalModelService struct { universally.IServiceDelete } +func (i *imlLocalModelService) DefaultModel(ctx context.Context) (*LocalModel, error) { + info, err := i.store.First(ctx, map[string]interface{}{"state": 1}) + if err != nil { + return nil, err + } + return i.fromEntity(info), nil +} + func (i *imlLocalModelService) OnComplete() { i.IServiceGet = universally.NewGet[LocalModel, ai.LocalModel](i.store, i.fromEntity) i.IServiceCreate = universally.NewCreator[CreateLocalModel, ai.LocalModel](i.store, "ai_local_model", i.createEntityHandler, i.uniquestHandler, i.labelHandler) diff --git a/service/ai-local/service.go b/service/ai-local/service.go index e32458ea..2c493b36 100644 --- a/service/ai-local/service.go +++ b/service/ai-local/service.go @@ -13,6 +13,7 @@ type ILocalModelService interface { universally.IServiceCreate[CreateLocalModel] universally.IServiceEdit[EditLocalModel] universally.IServiceDelete + DefaultModel(ctx context.Context) (*LocalModel, error) } type ILocalModelPackageService interface {