From 34c971ad77009384da9a2ac37941f05a068f0e4c Mon Sep 17 00:00:00 2001 From: Liujian <824010343@qq.com> Date: Fri, 14 Feb 2025 23:46:29 +0800 Subject: [PATCH] Fix: AI model list keyword query failure issue --- module/ai/iml.go | 89 ++++++++++++++--------- service/ai/iml.go | 160 ++++++++++++++++++++++++------------------ service/ai/model.go | 9 ++- service/ai/service.go | 7 +- 4 files changed, 159 insertions(+), 106 deletions(-) diff --git a/module/ai/iml.go b/module/ai/iml.go index dc69b65e..24821923 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -8,6 +8,9 @@ import ( "sort" "time" + "github.com/eolinker/go-common/register" + "github.com/eolinker/go-common/server" + ai_local "github.com/APIParkLab/APIPark/service/ai-local" ai_balance "github.com/APIParkLab/APIPark/service/ai-balance" @@ -64,6 +67,20 @@ type imlProviderModule struct { transaction store.ITransaction `autowired:""` } +func (i *imlProviderModule) OnInit() { + register.Handle(func(v server.Server) { + ctx := context.Background() + + list, err := i.providerService.List(ctx) + if err != nil { + return + } + for _, l := range list { + i.providerService.Save(ctx, l.Id, &ai.SetProvider{}) + } + }) +} + func (i *imlProviderModule) Delete(ctx context.Context, id string) error { return i.transaction.Transaction(ctx, func(ctx context.Context) error { keys, err := i.aiKeyService.KeysByProvider(ctx, id) @@ -428,9 +445,6 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr } defaultLLM = model } - //if info.Priority == 0 { - // info.Priority = maxPriority - //} return &ai_dto.Provider{ Id: info.Id, @@ -497,38 +511,48 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, if !has { return fmt.Errorf("ai provider not found") } - info, err := i.providerService.Get(ctx, id) - if err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { + + return i.transaction.Transaction(ctx, func(txCtx context.Context) error { + info, err := i.providerService.Get(ctx, id) + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if input.DefaultLLM == "" { + defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM) + if !has { + return fmt.Errorf("ai provider default llm not found") + } + input.DefaultLLM = defaultLLM.ID() + } + info = &ai.Provider{ + Id: id, + Name: p.Name(), + DefaultLLM: input.DefaultLLM, + Config: input.Config, + } + err = i.providerService.Create(ctx, &ai.CreateProvider{ + Id: info.Id, + Name: info.Name, + DefaultLLM: input.DefaultLLM, + Config: input.Config, + }) + if err != nil { + return err + } + } + model, has := p.GetModel(input.DefaultLLM) + if !has { + return fmt.Errorf("ai provider model not found") + } + err = p.Check(input.Config) + if err != nil { return err } - if input.DefaultLLM == "" { - defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM) - if !has { - return fmt.Errorf("ai provider default llm not found") - } - input.DefaultLLM = defaultLLM.ID() + input.Config, err = p.GenConfig(input.Config, info.Config) + if err != nil { + return err } - info = &ai.Provider{ - Id: id, - Name: p.Name(), - DefaultLLM: input.DefaultLLM, - Config: input.Config, - } - } - model, has := p.GetModel(input.DefaultLLM) - if !has { - return fmt.Errorf("ai provider model not found") - } - err = p.Check(input.Config) - if err != nil { - return err - } - input.Config, err = p.GenConfig(input.Config, info.Config) - if err != nil { - return err - } - return i.transaction.Transaction(ctx, func(txCtx context.Context) error { status := 0 if input.Enable != nil && *input.Enable { status = 1 @@ -537,7 +561,6 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, Name: &info.Name, DefaultLLM: &input.DefaultLLM, Config: &input.Config, - Priority: input.Priority, Status: &status, } _, err = i.aiKeyService.DefaultKey(txCtx, id) diff --git a/service/ai/iml.go b/service/ai/iml.go index a2217919..b09b55c1 100644 --- a/service/ai/iml.go +++ b/service/ai/iml.go @@ -2,91 +2,73 @@ package ai import ( "context" - "encoding/base64" - "errors" "time" "github.com/APIParkLab/APIPark/service/universally" "github.com/APIParkLab/APIPark/stores/ai" "github.com/eolinker/go-common/auto" "github.com/eolinker/go-common/utils" - "gorm.io/gorm" ) var _ IProviderService = (*imlProviderService)(nil) type imlProviderService struct { universally.IServiceGet[Provider] + universally.IServiceCreate[CreateProvider] + universally.IServiceEdit[SetProvider] universally.IServiceDelete store ai.IProviderStore `autowired:""` } -func (i *imlProviderService) MaxPriority(ctx context.Context) (int, error) { - t, err := i.store.First(ctx, nil, "priority desc") - if err != nil { - return 0, err - } - return t.Priority, nil -} - -func (i *imlProviderService) Save(ctx context.Context, id string, cfg *SetProvider) error { - userId := utils.UserId(ctx) - now := time.Now() - info, err := i.store.First(ctx, map[string]interface{}{"uuid": id}) - if err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { - return err - } - if cfg.Name == nil || cfg.Config == nil || cfg.DefaultLLM == nil { - return errors.New("invalid params") - } - status := 1 - if cfg.Status != nil { - status = *cfg.Status - } - priority := 1 - if cfg.Priority == nil { - count, err := i.store.Count(ctx, "", nil) - if err != nil { - return err - } - priority = int(count) + 1 - } else { - priority = *cfg.Priority - } - info = &ai.Provider{ - UUID: id, - Name: *cfg.Name, - DefaultLLM: *cfg.DefaultLLM, - Config: base64.RawStdEncoding.EncodeToString([]byte(*cfg.Config)), - Status: status, - Creator: userId, - Updater: userId, - Priority: priority, - CreateAt: now, - UpdateAt: now, - } - } else { - if cfg.Name != nil { - info.Name = *cfg.Name - } - if cfg.Config != nil { - info.Config = base64.RawStdEncoding.EncodeToString([]byte(*cfg.Config)) - } - if cfg.DefaultLLM != nil { - info.DefaultLLM = *cfg.DefaultLLM - } - if cfg.Status != nil { - info.Status = *cfg.Status - } - if cfg.Priority != nil { - info.Priority = *cfg.Priority - } - info.Updater = userId - info.UpdateAt = now - } - return i.store.Save(ctx, info) -} +//func (i *imlProviderService) Save(ctx context.Context, id string, cfg *SetProvider) error { +// userId := utils.UserId(ctx) +// now := time.Now() +// info, err := i.store.First(ctx, map[string]interface{}{"uuid": id}) +// if err != nil { +// if !errors.Is(err, gorm.ErrRecordNotFound) { +// return err +// } +// if cfg.Name == nil || cfg.Config == nil || cfg.DefaultLLM == nil { +// return errors.New("invalid params") +// } +// status := 1 +// if cfg.Status != nil { +// status = *cfg.Status +// } +// +// info = &ai.Provider{ +// UUID: id, +// Name: *cfg.Name, +// DefaultLLM: *cfg.DefaultLLM, +// Config: base64.RawStdEncoding.EncodeToString([]byte(*cfg.Config)), +// Status: status, +// Creator: userId, +// Updater: userId, +// //Priority: priority, +// CreateAt: now, +// UpdateAt: now, +// } +// } else { +// if cfg.Name != nil { +// info.Name = *cfg.Name +// } +// if cfg.Config != nil { +// info.Config = base64.RawStdEncoding.EncodeToString([]byte(*cfg.Config)) +// } +// if cfg.DefaultLLM != nil { +// info.DefaultLLM = *cfg.DefaultLLM +// } +// if cfg.Status != nil { +// info.Status = *cfg.Status +// } +// //if cfg.Priority != nil { +// // info.Priority = *cfg.Priority +// //} +// info.Updater = userId +// info.UpdateAt = now +// } +// return i.store.Save(ctx, info) +//} func (i *imlProviderService) GetLabels(ctx context.Context, ids ...string) map[string]string { if len(ids) == 0 { @@ -103,6 +85,46 @@ func (i *imlProviderService) GetLabels(ctx context.Context, ids ...string) map[s func (i *imlProviderService) OnComplete() { i.IServiceGet = universally.NewGet[Provider, ai.Provider](i.store, FromEntity) + i.IServiceCreate = universally.NewCreator[CreateProvider, ai.Provider](i.store, "ai_provider", createEntityHandler, uniquestHandler, labelHandler) + i.IServiceEdit = universally.NewEdit[SetProvider, ai.Provider](i.store, updateHandler, labelHandler) i.IServiceDelete = universally.NewDelete[ai.Provider](i.store) auto.RegisterService("ai_provider", i) } + +func labelHandler(e *ai.Provider) []string { + return []string{e.Name, e.UUID} +} + +func uniquestHandler(i *CreateProvider) []map[string]interface{} { + return []map[string]interface{}{{"uuid": i.Id}} +} + +func createEntityHandler(i *CreateProvider) *ai.Provider { + //cfg, _ := json.Marshal(i.Config) + now := time.Now() + return &ai.Provider{ + UUID: i.Id, + Name: i.Name, + DefaultLLM: i.DefaultLLM, + Config: i.Config, + Status: i.Status, + CreateAt: now, + UpdateAt: now, + } +} + +func updateHandler(e *ai.Provider, i *SetProvider) { + if i.Name != nil { + e.Name = *i.Name + } + if i.DefaultLLM != nil { + e.DefaultLLM = *i.DefaultLLM + } + if i.Config != nil { + e.Config = *i.Config + } + if i.Status != nil { + e.Status = *i.Status + } + e.UpdateAt = time.Now() +} diff --git a/service/ai/model.go b/service/ai/model.go index 80c6b43f..2740ed88 100644 --- a/service/ai/model.go +++ b/service/ai/model.go @@ -20,12 +20,19 @@ type Provider struct { UpdateAt time.Time } +type CreateProvider struct { + Id string + Name string + DefaultLLM string + Config string + Status int +} + type SetProvider struct { Name *string DefaultLLM *string Config *string Status *int - Priority *int } func FromEntity(e *ai.Provider) *Provider { diff --git a/service/ai/service.go b/service/ai/service.go index 7c1f7e27..c4201db6 100644 --- a/service/ai/service.go +++ b/service/ai/service.go @@ -1,7 +1,6 @@ package ai import ( - "context" "reflect" "github.com/APIParkLab/APIPark/service/universally" @@ -10,9 +9,11 @@ import ( type IProviderService interface { universally.IServiceGet[Provider] + universally.IServiceCreate[CreateProvider] + universally.IServiceEdit[SetProvider] universally.IServiceDelete - Save(ctx context.Context, id string, cfg *SetProvider) error - MaxPriority(ctx context.Context) (int, error) + //Save(ctx context.Context, id string, cfg *SetProvider) error + //MaxPriority(ctx context.Context) (int, error) } func init() {