mirror of
https://github.com/APIParkLab/APIPark.git
synced 2026-06-04 10:13:53 +08:00
Fix: AI model list keyword query failure issue
This commit is contained in:
+56
-33
@@ -8,6 +8,9 @@ import (
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/eolinker/go-common/register"
|
||||
"github.com/eolinker/go-common/server"
|
||||
|
||||
ai_local "github.com/APIParkLab/APIPark/service/ai-local"
|
||||
|
||||
ai_balance "github.com/APIParkLab/APIPark/service/ai-balance"
|
||||
@@ -64,6 +67,20 @@ type imlProviderModule struct {
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlProviderModule) OnInit() {
|
||||
register.Handle(func(v server.Server) {
|
||||
ctx := context.Background()
|
||||
|
||||
list, err := i.providerService.List(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, l := range list {
|
||||
i.providerService.Save(ctx, l.Id, &ai.SetProvider{})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (i *imlProviderModule) Delete(ctx context.Context, id string) error {
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
keys, err := i.aiKeyService.KeysByProvider(ctx, id)
|
||||
@@ -428,9 +445,6 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
|
||||
}
|
||||
defaultLLM = model
|
||||
}
|
||||
//if info.Priority == 0 {
|
||||
// info.Priority = maxPriority
|
||||
//}
|
||||
|
||||
return &ai_dto.Provider{
|
||||
Id: info.Id,
|
||||
@@ -497,38 +511,48 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider not found")
|
||||
}
|
||||
info, err := i.providerService.Get(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
|
||||
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
||||
info, err := i.providerService.Get(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
if input.DefaultLLM == "" {
|
||||
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider default llm not found")
|
||||
}
|
||||
input.DefaultLLM = defaultLLM.ID()
|
||||
}
|
||||
info = &ai.Provider{
|
||||
Id: id,
|
||||
Name: p.Name(),
|
||||
DefaultLLM: input.DefaultLLM,
|
||||
Config: input.Config,
|
||||
}
|
||||
err = i.providerService.Create(ctx, &ai.CreateProvider{
|
||||
Id: info.Id,
|
||||
Name: info.Name,
|
||||
DefaultLLM: input.DefaultLLM,
|
||||
Config: input.Config,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
model, has := p.GetModel(input.DefaultLLM)
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider model not found")
|
||||
}
|
||||
err = p.Check(input.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if input.DefaultLLM == "" {
|
||||
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider default llm not found")
|
||||
}
|
||||
input.DefaultLLM = defaultLLM.ID()
|
||||
input.Config, err = p.GenConfig(input.Config, info.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
info = &ai.Provider{
|
||||
Id: id,
|
||||
Name: p.Name(),
|
||||
DefaultLLM: input.DefaultLLM,
|
||||
Config: input.Config,
|
||||
}
|
||||
}
|
||||
model, has := p.GetModel(input.DefaultLLM)
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider model not found")
|
||||
}
|
||||
err = p.Check(input.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input.Config, err = p.GenConfig(input.Config, info.Config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
||||
status := 0
|
||||
if input.Enable != nil && *input.Enable {
|
||||
status = 1
|
||||
@@ -537,7 +561,6 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
|
||||
Name: &info.Name,
|
||||
DefaultLLM: &input.DefaultLLM,
|
||||
Config: &input.Config,
|
||||
Priority: input.Priority,
|
||||
Status: &status,
|
||||
}
|
||||
_, err = i.aiKeyService.DefaultKey(txCtx, id)
|
||||
|
||||
+91
-69
@@ -2,91 +2,73 @@ package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/universally"
|
||||
"github.com/APIParkLab/APIPark/stores/ai"
|
||||
"github.com/eolinker/go-common/auto"
|
||||
"github.com/eolinker/go-common/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ IProviderService = (*imlProviderService)(nil)
|
||||
|
||||
type imlProviderService struct {
|
||||
universally.IServiceGet[Provider]
|
||||
universally.IServiceCreate[CreateProvider]
|
||||
universally.IServiceEdit[SetProvider]
|
||||
universally.IServiceDelete
|
||||
store ai.IProviderStore `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlProviderService) MaxPriority(ctx context.Context) (int, error) {
|
||||
t, err := i.store.First(ctx, nil, "priority desc")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return t.Priority, nil
|
||||
}
|
||||
|
||||
func (i *imlProviderService) Save(ctx context.Context, id string, cfg *SetProvider) error {
|
||||
userId := utils.UserId(ctx)
|
||||
now := time.Now()
|
||||
info, err := i.store.First(ctx, map[string]interface{}{"uuid": id})
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
if cfg.Name == nil || cfg.Config == nil || cfg.DefaultLLM == nil {
|
||||
return errors.New("invalid params")
|
||||
}
|
||||
status := 1
|
||||
if cfg.Status != nil {
|
||||
status = *cfg.Status
|
||||
}
|
||||
priority := 1
|
||||
if cfg.Priority == nil {
|
||||
count, err := i.store.Count(ctx, "", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
priority = int(count) + 1
|
||||
} else {
|
||||
priority = *cfg.Priority
|
||||
}
|
||||
info = &ai.Provider{
|
||||
UUID: id,
|
||||
Name: *cfg.Name,
|
||||
DefaultLLM: *cfg.DefaultLLM,
|
||||
Config: base64.RawStdEncoding.EncodeToString([]byte(*cfg.Config)),
|
||||
Status: status,
|
||||
Creator: userId,
|
||||
Updater: userId,
|
||||
Priority: priority,
|
||||
CreateAt: now,
|
||||
UpdateAt: now,
|
||||
}
|
||||
} else {
|
||||
if cfg.Name != nil {
|
||||
info.Name = *cfg.Name
|
||||
}
|
||||
if cfg.Config != nil {
|
||||
info.Config = base64.RawStdEncoding.EncodeToString([]byte(*cfg.Config))
|
||||
}
|
||||
if cfg.DefaultLLM != nil {
|
||||
info.DefaultLLM = *cfg.DefaultLLM
|
||||
}
|
||||
if cfg.Status != nil {
|
||||
info.Status = *cfg.Status
|
||||
}
|
||||
if cfg.Priority != nil {
|
||||
info.Priority = *cfg.Priority
|
||||
}
|
||||
info.Updater = userId
|
||||
info.UpdateAt = now
|
||||
}
|
||||
return i.store.Save(ctx, info)
|
||||
}
|
||||
//func (i *imlProviderService) Save(ctx context.Context, id string, cfg *SetProvider) error {
|
||||
// userId := utils.UserId(ctx)
|
||||
// now := time.Now()
|
||||
// info, err := i.store.First(ctx, map[string]interface{}{"uuid": id})
|
||||
// if err != nil {
|
||||
// if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// return err
|
||||
// }
|
||||
// if cfg.Name == nil || cfg.Config == nil || cfg.DefaultLLM == nil {
|
||||
// return errors.New("invalid params")
|
||||
// }
|
||||
// status := 1
|
||||
// if cfg.Status != nil {
|
||||
// status = *cfg.Status
|
||||
// }
|
||||
//
|
||||
// info = &ai.Provider{
|
||||
// UUID: id,
|
||||
// Name: *cfg.Name,
|
||||
// DefaultLLM: *cfg.DefaultLLM,
|
||||
// Config: base64.RawStdEncoding.EncodeToString([]byte(*cfg.Config)),
|
||||
// Status: status,
|
||||
// Creator: userId,
|
||||
// Updater: userId,
|
||||
// //Priority: priority,
|
||||
// CreateAt: now,
|
||||
// UpdateAt: now,
|
||||
// }
|
||||
// } else {
|
||||
// if cfg.Name != nil {
|
||||
// info.Name = *cfg.Name
|
||||
// }
|
||||
// if cfg.Config != nil {
|
||||
// info.Config = base64.RawStdEncoding.EncodeToString([]byte(*cfg.Config))
|
||||
// }
|
||||
// if cfg.DefaultLLM != nil {
|
||||
// info.DefaultLLM = *cfg.DefaultLLM
|
||||
// }
|
||||
// if cfg.Status != nil {
|
||||
// info.Status = *cfg.Status
|
||||
// }
|
||||
// //if cfg.Priority != nil {
|
||||
// // info.Priority = *cfg.Priority
|
||||
// //}
|
||||
// info.Updater = userId
|
||||
// info.UpdateAt = now
|
||||
// }
|
||||
// return i.store.Save(ctx, info)
|
||||
//}
|
||||
|
||||
func (i *imlProviderService) GetLabels(ctx context.Context, ids ...string) map[string]string {
|
||||
if len(ids) == 0 {
|
||||
@@ -103,6 +85,46 @@ func (i *imlProviderService) GetLabels(ctx context.Context, ids ...string) map[s
|
||||
|
||||
func (i *imlProviderService) OnComplete() {
|
||||
i.IServiceGet = universally.NewGet[Provider, ai.Provider](i.store, FromEntity)
|
||||
i.IServiceCreate = universally.NewCreator[CreateProvider, ai.Provider](i.store, "ai_provider", createEntityHandler, uniquestHandler, labelHandler)
|
||||
i.IServiceEdit = universally.NewEdit[SetProvider, ai.Provider](i.store, updateHandler, labelHandler)
|
||||
i.IServiceDelete = universally.NewDelete[ai.Provider](i.store)
|
||||
auto.RegisterService("ai_provider", i)
|
||||
}
|
||||
|
||||
func labelHandler(e *ai.Provider) []string {
|
||||
return []string{e.Name, e.UUID}
|
||||
}
|
||||
|
||||
func uniquestHandler(i *CreateProvider) []map[string]interface{} {
|
||||
return []map[string]interface{}{{"uuid": i.Id}}
|
||||
}
|
||||
|
||||
func createEntityHandler(i *CreateProvider) *ai.Provider {
|
||||
//cfg, _ := json.Marshal(i.Config)
|
||||
now := time.Now()
|
||||
return &ai.Provider{
|
||||
UUID: i.Id,
|
||||
Name: i.Name,
|
||||
DefaultLLM: i.DefaultLLM,
|
||||
Config: i.Config,
|
||||
Status: i.Status,
|
||||
CreateAt: now,
|
||||
UpdateAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
func updateHandler(e *ai.Provider, i *SetProvider) {
|
||||
if i.Name != nil {
|
||||
e.Name = *i.Name
|
||||
}
|
||||
if i.DefaultLLM != nil {
|
||||
e.DefaultLLM = *i.DefaultLLM
|
||||
}
|
||||
if i.Config != nil {
|
||||
e.Config = *i.Config
|
||||
}
|
||||
if i.Status != nil {
|
||||
e.Status = *i.Status
|
||||
}
|
||||
e.UpdateAt = time.Now()
|
||||
}
|
||||
|
||||
+8
-1
@@ -20,12 +20,19 @@ type Provider struct {
|
||||
UpdateAt time.Time
|
||||
}
|
||||
|
||||
type CreateProvider struct {
|
||||
Id string
|
||||
Name string
|
||||
DefaultLLM string
|
||||
Config string
|
||||
Status int
|
||||
}
|
||||
|
||||
type SetProvider struct {
|
||||
Name *string
|
||||
DefaultLLM *string
|
||||
Config *string
|
||||
Status *int
|
||||
Priority *int
|
||||
}
|
||||
|
||||
func FromEntity(e *ai.Provider) *Provider {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/universally"
|
||||
@@ -10,9 +9,11 @@ import (
|
||||
|
||||
type IProviderService interface {
|
||||
universally.IServiceGet[Provider]
|
||||
universally.IServiceCreate[CreateProvider]
|
||||
universally.IServiceEdit[SetProvider]
|
||||
universally.IServiceDelete
|
||||
Save(ctx context.Context, id string, cfg *SetProvider) error
|
||||
MaxPriority(ctx context.Context) (int, error)
|
||||
//Save(ctx context.Context, id string, cfg *SetProvider) error
|
||||
//MaxPriority(ctx context.Context) (int, error)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
Reference in New Issue
Block a user