diff --git a/controller/ai-balance/controller.go b/controller/ai-balance/controller.go new file mode 100644 index 00000000..96969176 --- /dev/null +++ b/controller/ai-balance/controller.go @@ -0,0 +1,22 @@ +package ai_balance + +import ( + "reflect" + + ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto" + "github.com/eolinker/go-common/autowire" + "github.com/gin-gonic/gin" +) + +type IBalanceController interface { + List(ctx *gin.Context, keyword string) ([]*ai_balance_dto.Item, error) + Sort(ctx *gin.Context, input *ai_balance_dto.Sort) error + Create(ctx *gin.Context, input *ai_balance_dto.Create) error + Delete(ctx *gin.Context, id string) error +} + +func init() { + autowire.Auto[IBalanceController](func() reflect.Value { + return reflect.ValueOf(new(imlBalanceController)) + }) +} diff --git a/controller/ai-balance/iml.go b/controller/ai-balance/iml.go new file mode 100644 index 00000000..f21585d3 --- /dev/null +++ b/controller/ai-balance/iml.go @@ -0,0 +1,29 @@ +package ai_balance + +import ( + ai_balance "github.com/APIParkLab/APIPark/module/ai-balance" + ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto" + "github.com/gin-gonic/gin" +) + +var _ IBalanceController = (*imlBalanceController)(nil) + +type imlBalanceController struct { + module ai_balance.IBalanceModule `autowired:""` +} + +func (i *imlBalanceController) List(ctx *gin.Context, keyword string) ([]*ai_balance_dto.Item, error) { + return i.module.List(ctx, keyword) +} + +func (i *imlBalanceController) Sort(ctx *gin.Context, input *ai_balance_dto.Sort) error { + return i.module.Sort(ctx, input) +} + +func (i *imlBalanceController) Create(ctx *gin.Context, input *ai_balance_dto.Create) error { + return i.module.Create(ctx, input) +} + +func (i *imlBalanceController) Delete(ctx *gin.Context, id string) error { + return i.module.Delete(ctx, id) +} diff --git a/controller/ai-local/controller.go b/controller/ai-local/controller.go index 1af7bb7d..8e79e7eb 100644 --- a/controller/ai-local/controller.go +++ b/controller/ai-local/controller.go @@ -17,6 +17,7 @@ type ILocalModelController interface { RemoveModel(ctx *gin.Context, model string) error Update(ctx *gin.Context, model string, input *ai_local_dto.Update) error State(ctx *gin.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) + SimpleList(ctx *gin.Context) ([]*ai_local_dto.SimpleItem, error) } func init() { diff --git a/controller/ai-local/iml.go b/controller/ai-local/iml.go index 885f0330..40ddf818 100644 --- a/controller/ai-local/iml.go +++ b/controller/ai-local/iml.go @@ -53,6 +53,10 @@ type imlLocalModelController struct { transaction store.ITransaction `autowired:""` } +func (i *imlLocalModelController) SimpleList(ctx *gin.Context) ([]*ai_local_dto.SimpleItem, error) { + return i.module.SimpleList(ctx) +} + func (i *imlLocalModelController) State(ctx *gin.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) { return i.module.ModelState(ctx, model) } diff --git a/go.mod b/go.mod index d89e2625..51c26ca5 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/eolinker/go-common v1.1.4 github.com/gabriel-vasile/mimetype v1.4.4 github.com/getkin/kin-openapi v0.127.0 + github.com/gin-contrib/gzip v1.0.1 github.com/gin-gonic/gin v1.10.0 github.com/go-sql-driver/mysql v1.7.0 github.com/google/uuid v1.6.0 @@ -35,7 +36,6 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect - github.com/gin-contrib/gzip v1.0.1 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/swag v0.23.0 // indirect diff --git a/module/ai-api/schema.go b/module/ai-api/schema.go index 6d6bc4c1..aaa26bee 100644 --- a/module/ai-api/schema.go +++ b/module/ai-api/schema.go @@ -131,7 +131,7 @@ func genMessageSchema() *openapi3.Schema { result.Description = "Chat Message" roleSchema := openapi3.NewStringSchema() roleSchema.Description = "Role of the message sender" - roleSchema.Example = "assistant" + roleSchema.Example = "user" contentSchema := openapi3.NewStringSchema() contentSchema.Description = "The message content" contentSchema.Example = "Hello, how can I help you?" diff --git a/module/ai-balance/iml.go b/module/ai-balance/iml.go index 45631f2f..18414162 100644 --- a/module/ai-balance/iml.go +++ b/module/ai-balance/iml.go @@ -2,9 +2,14 @@ package ai_balance import ( "context" + "errors" "fmt" "sort" + model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime" + + "gorm.io/gorm" + ai_key "github.com/APIParkLab/APIPark/service/ai-key" ai_api "github.com/APIParkLab/APIPark/service/ai-api" @@ -36,18 +41,28 @@ type imlBalanceModule struct { func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error { priority, err := i.balanceService.MaxPriority(ctx) if err != nil { - return err + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + priority = 0 } if input.Id == "" { input.Id = uuid.NewString() } providerName := "" modelName := "" - // TODO: 名称进行优化 switch input.Type { case ai_balance_dto.ModelTypeOnline: + p, has := model_runtime.GetProvider(input.Provider) + if !has { + return fmt.Errorf("provider not found") + } + providerName = p.Name() + modelName = input.Model case ai_balance_dto.ModelTypeLocal: - + input.Provider = "ollama" + providerName = "Ollama" + modelName = input.Model } return i.balanceService.Create(ctx, &ai_balance.Create{ Id: input.Id, @@ -85,8 +100,8 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort) return nil } -func (i *imlBalanceModule) List(ctx context.Context) ([]*ai_balance_dto.Item, error) { - list, err := i.balanceService.List(ctx) +func (i *imlBalanceModule) List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error) { + list, err := i.balanceService.Search(ctx, keyword, nil, "priority asc") if err != nil { return nil, err } diff --git a/module/ai-balance/module.go b/module/ai-balance/module.go index 4604d602..51f99a32 100644 --- a/module/ai-balance/module.go +++ b/module/ai-balance/module.go @@ -12,7 +12,7 @@ import ( type IBalanceModule interface { Create(ctx context.Context, input *ai_balance_dto.Create) error Sort(ctx context.Context, input *ai_balance_dto.Sort) error - List(ctx context.Context) ([]*ai_balance_dto.Item, error) + List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error) Delete(ctx context.Context, id string) error } diff --git a/module/ai-local/dto/output.go b/module/ai-local/dto/output.go index f4db3e44..3486907f 100644 --- a/module/ai-local/dto/output.go +++ b/module/ai-local/dto/output.go @@ -58,12 +58,19 @@ func FromLocalModelState(state int) LocalModelState { } } +type SimpleItem struct { + Id string `json:"id"` + Name string `json:"name"` +} + type LocalModelItem struct { - Id string `json:"id"` - Name string `json:"name"` - State LocalModelState `json:"state"` - APICount int64 `json:"api_count"` - UpdateTime auto.TimeLabel `json:"update_time"` + Id string `json:"id"` + Name string `json:"name"` + State LocalModelState `json:"state"` + APICount int64 `json:"api_count"` + + UpdateTime auto.TimeLabel `json:"update_time"` + CanDelete bool `json:"can_delete"` } type LocalModelPackageItem struct { diff --git a/module/ai-local/iml.go b/module/ai-local/iml.go index d4a57bc2..c81f2f11 100644 --- a/module/ai-local/iml.go +++ b/module/ai-local/iml.go @@ -37,6 +37,24 @@ type imlLocalModel struct { transaction store.ITransaction `autowired:""` } +func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) { + list, err := i.localModelService.List(ctx) + if err != nil { + return nil, err + } + return utils.SliceToSlice(list, func(s *ai_local.LocalModel) *ai_local_dto.SimpleItem { + return &ai_local_dto.SimpleItem{ + Id: s.Id, + Name: s.Name, + } + }, func(l *ai_local.LocalModel) bool { + if l.State != ai_local_dto.LocalModelStateNormal.Int() && l.State != ai_local_dto.LocalModelStateDisable.Int() { + return false + } + return true + }), nil +} + func (i *imlLocalModel) ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) { info, err := i.localModelStateService.Get(ctx, model) if err != nil { @@ -68,6 +86,7 @@ func (i *imlLocalModel) Search(ctx context.Context, keyword string) ([]*ai_local Name: s.Name, State: ai_local_dto.FromLocalModelState(s.State), APICount: apiCountMap[s.Id], + CanDelete: true, UpdateTime: auto.TimeLabel(s.UpdateAt), } }), nil diff --git a/module/ai-local/module.go b/module/ai-local/module.go index 5e8f6a7c..bf40ff4a 100644 --- a/module/ai-local/module.go +++ b/module/ai-local/module.go @@ -20,6 +20,7 @@ type ILocalModelModule interface { Enable(ctx context.Context, model string) error Disable(ctx context.Context, model string) error ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) + SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) } func init() { diff --git a/plugins/core/ai-local.go b/plugins/core/ai-local.go index 28bcb3ad..3671ad86 100644 --- a/plugins/core/ai-local.go +++ b/plugins/core/ai-local.go @@ -16,5 +16,6 @@ func (p *plugin) aiLocalApis() []pm3.Api { pm3.CreateApiWidthDoc(http.MethodDelete, "/api/v1/model/local", []string{"context", "query:model"}, nil, p.aiLocalController.RemoveModel), pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/model/local/info", []string{"context", "query:model", "body"}, nil, p.aiLocalController.Update), pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/model/local/state", []string{"context", "query:model"}, []string{"state", "info"}, p.aiLocalController.State), + pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/simple/ai/models/local/configured", []string{"context"}, []string{"models"}, p.aiLocalController.SimpleList), } } diff --git a/plugins/core/ai.go b/plugins/core/ai.go index 9cd04d37..dfc0cb69 100644 --- a/plugins/core/ai.go +++ b/plugins/core/ai.go @@ -36,3 +36,12 @@ func (p *plugin) aiKeyApis() []pm3.Api { pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/ai/resource/key/sort", []string{"context", "query:provider", "body"}, nil, p.aiKeyController.Sort), } } + +func (p *plugin) aiBalanceAPIs() []pm3.Api { + return []pm3.Api{ + pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/ai/balances", []string{"context", "query:keyword"}, []string{"list"}, p.aiBalanceController.List), + pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/ai/balance/sort", []string{"context", "body"}, nil, p.aiBalanceController.Sort), + pm3.CreateApiWidthDoc(http.MethodPost, "/api/v1/ai/balance", []string{"context", "body"}, nil, p.aiBalanceController.Create), + pm3.CreateApiWidthDoc(http.MethodDelete, "/api/v1/ai/balance", []string{"context", "query:id"}, nil, p.aiBalanceController.Delete), + } +} diff --git a/plugins/core/core.go b/plugins/core/core.go index 1d78901f..be4e3f00 100644 --- a/plugins/core/core.go +++ b/plugins/core/core.go @@ -3,6 +3,8 @@ package core import ( "net/http" + ai_balance "github.com/APIParkLab/APIPark/controller/ai-balance" + ai_local "github.com/APIParkLab/APIPark/controller/ai-local" ai_key "github.com/APIParkLab/APIPark/controller/ai-key" @@ -80,6 +82,7 @@ type plugin struct { aiAPIController ai_api.IAPIController `autowired:""` aiStatisticController ai.IStatisticController `autowired:""` aiKeyController ai_key.IKeyController `autowired:""` + aiBalanceController ai_balance.IBalanceController `autowired:""` aiLocalController ai_local.ILocalModelController `autowired:""` apiDocController router.IAPIDocController `autowired:""` subscribeController subscribe.ISubscribeController `autowired:""` @@ -122,6 +125,7 @@ func (p *plugin) OnComplete() { p.apis = append(p.apis, p.strategyApis()...) p.apis = append(p.apis, p.logApis()...) p.apis = append(p.apis, p.aiLocalApis()...) + p.apis = append(p.apis, p.aiBalanceAPIs()...) } func (p *plugin) Name() string { diff --git a/service/ai-balance/iml.go b/service/ai-balance/iml.go index 01db6fba..34f2dd3c 100644 --- a/service/ai-balance/iml.go +++ b/service/ai-balance/iml.go @@ -53,16 +53,16 @@ func (i *imlBalanceService) SortBefore(ctx context.Context, originID string, tar fn := func(priority int) int { return priority + 1 } - sql := "sort < ? and sort >= ?" + sql := "priority < ? and priority >= ?" if originKeySort < targetKeySort { // 如果原始Key在目标Key之前,中间的key往前移动,原始Key移动到`targetKeySort - 1`位置 - sql = "sort > ? and sort < ?" + sql = "priority > ? and priority < ?" originKey.Priority = targetKeySort - 1 fn = func(priority int) int { return priority - 1 } } - list, err := i.store.ListQuery(ctx, sql, []interface{}{originKeySort, targetKeySort}, "sort asc") + list, err := i.store.ListQuery(ctx, sql, []interface{}{originKeySort, targetKeySort}, "priority asc") if err != nil { return nil, err } @@ -102,16 +102,16 @@ func (i *imlBalanceService) SortAfter(ctx context.Context, originID string, targ fn := func(priority int) int { return priority + 1 } - sql := "sort < ? and sort > ?" + sql := "priority < ? and priority > ?" if originKeySort < targetKeySort { // 如果原始Key在目标Key之前,中间的Key往前移动,原始Key移动到`targetKeySort`位置 - sql = "sort > ? and sort <= ?" + sql = "priority > ? and priority <= ?" originKey.Priority = targetKeySort fn = func(priority int) int { return priority - 1 } } - list, err := i.store.ListQuery(ctx, sql, []interface{}{originKeySort, targetKeySort}, "sort asc") + list, err := i.store.ListQuery(ctx, sql, []interface{}{originKeySort, targetKeySort}, "priority asc") if err != nil { return nil, err }