diff --git a/module/ai-local/iml.go b/module/ai-local/iml.go index ebc08b5d..f6606b75 100644 --- a/module/ai-local/iml.go +++ b/module/ai-local/iml.go @@ -116,12 +116,13 @@ func (i *imlLocalModel) Search(ctx context.Context, keyword string) ([]*ai_local } return utils.SliceToSlice(list, func(s *ai_local.LocalModel) *ai_local_dto.LocalModelItem { + count := apiCountMap[s.Id] return &ai_local_dto.LocalModelItem{ Id: s.Id, Name: s.Name, State: ai_local_dto.FromLocalModelState(s.State), - APICount: apiCountMap[s.Id], - CanDelete: true, + APICount: count, + CanDelete: count < 1, UpdateTime: auto.TimeLabel(s.UpdateAt), Provider: "ollama", } @@ -372,12 +373,22 @@ func (i *imlLocalModel) CancelDeploy(ctx context.Context, model string) error { } func (i *imlLocalModel) RemoveModel(ctx context.Context, model string) error { - - err := ai_provider_local.RemoveModel(model) + // 判断是否有api + count, err := i.aiAPIService.CountByModel(ctx, model) if err != nil { return err } - return i.localModelService.Delete(ctx, model) + if count > 0 { + return fmt.Errorf("model %s has api, can not remove", model) + } + return i.transaction.Transaction(ctx, func(txCtx context.Context) error { + err = i.localModelService.Delete(ctx, model) + if err != nil { + return err + } + return ai_provider_local.RemoveModel(model) + }) + } func (i *imlLocalModel) Enable(ctx context.Context, model string) error { diff --git a/module/ai/iml.go b/module/ai/iml.go index 54145790..b3b76186 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -8,6 +8,8 @@ import ( "sort" "time" + "github.com/google/uuid" + ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local" "github.com/eolinker/go-common/register" @@ -77,14 +79,59 @@ func (i *imlProviderModule) OnInit() { if err != nil { return } - for _, l := range list { - i.providerService.Save(ctx, l.Id, &ai.SetProvider{}) - } + i.transaction.Transaction(ctx, func(ctx context.Context) error { + for _, l := range list { + if l.Priority < 1 { + continue + } + has, err := i.aiBalanceService.Exist(ctx, l.Id, l.DefaultLLM) + if err != nil { + return err + } + if has { + continue + } + + p, has := model_runtime.GetProvider(l.Id) + if !has { + continue + } + err = i.aiBalanceService.Create(ctx, &ai_balance.Create{ + Id: uuid.NewString(), + Priority: l.Priority, + Provider: l.Id, + ProviderName: p.Name(), + Model: l.DefaultLLM, + ModelName: l.DefaultLLM, + Type: 0, + }) + if err != nil { + return err + } + priority := 0 + err = i.providerService.Save(ctx, l.Id, &ai.SetProvider{ + Priority: &priority, + }) + if err != nil { + return err + } + } + return nil + }) + }) } func (i *imlProviderModule) Delete(ctx context.Context, id string) error { return i.transaction.Transaction(ctx, func(ctx context.Context) error { + // 判断是否有api + count, err := i.aiAPIService.CountByProvider(ctx, id) + if err != nil { + return err + } + if count > 0 { + return fmt.Errorf("provider has api") + } keys, err := i.aiKeyService.KeysByProvider(ctx, id) if err != nil { return err @@ -173,6 +220,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context, keyword str if !has { continue } + apiCount := aiAPIMap[l.Id] providers = append(providers, &ai_dto.ConfiguredProviderItem{ Id: l.Id, @@ -180,9 +228,9 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context, keyword str Logo: p.Logo(), DefaultLLM: l.DefaultLLM, Status: ai_dto.ToProviderStatus(l.Status), - APICount: aiAPIMap[l.Id], + APICount: apiCount, KeyCount: keyMap[l.Id], - CanDelete: len(list) > 1, + CanDelete: apiCount < 1, }) } @@ -216,18 +264,6 @@ func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.Simp items = append(items, item) } - //sort.Slice(items, func(i, j int) bool { - // if items[i].Priority != items[j].Priority { - // if items[i].Priority == 0 { - // return false - // } - // if items[j].Priority == 0 { - // return true - // } - // return items[i].Priority < items[j].Priority - // } - // return items[i].Name < items[j].Name - //}) return items, nil } diff --git a/service/ai-api/iml.go b/service/ai-api/iml.go index f31716c2..358e99f5 100644 --- a/service/ai-api/iml.go +++ b/service/ai-api/iml.go @@ -23,6 +23,14 @@ type imlAPIService struct { universally.IServiceDelete } +func (i *imlAPIService) CountByProvider(ctx context.Context, provider string) (int64, error) { + return i.store.Count(ctx, "", map[string]interface{}{"provider": provider}) +} + +func (i *imlAPIService) CountByModel(ctx context.Context, model string) (int64, error) { + return i.store.Count(ctx, "", map[string]interface{}{"model": model}) +} + func (i *imlAPIService) DeleteByService(ctx context.Context, serviceId string) error { _, err := i.store.DeleteWhere(ctx, map[string]interface{}{"service": serviceId}) if err != nil { diff --git a/service/ai-api/service.go b/service/ai-api/service.go index d6f3f6da..62ef5389 100644 --- a/service/ai-api/service.go +++ b/service/ai-api/service.go @@ -15,6 +15,8 @@ type IAPIService interface { universally.IServiceDelete CountMapByProvider(ctx context.Context, keyword string, conditions map[string]interface{}) (map[string]int64, error) CountMapByModel(ctx context.Context, keyword string, conditions map[string]interface{}) (map[string]int64, error) + CountByModel(ctx context.Context, model string) (int64, error) + CountByProvider(ctx context.Context, provider string) (int64, error) DeleteByService(ctx context.Context, serviceId string) error } diff --git a/service/ai/iml.go b/service/ai/iml.go index b09b55c1..6eda08a9 100644 --- a/service/ai/iml.go +++ b/service/ai/iml.go @@ -126,5 +126,8 @@ func updateHandler(e *ai.Provider, i *SetProvider) { if i.Status != nil { e.Status = *i.Status } + if i.Priority != nil { + e.Priority = *i.Priority + } e.UpdateAt = time.Now() } diff --git a/service/ai/model.go b/service/ai/model.go index 2740ed88..263a79fd 100644 --- a/service/ai/model.go +++ b/service/ai/model.go @@ -32,6 +32,7 @@ type SetProvider struct { Name *string DefaultLLM *string Config *string + Priority *int Status *int }