mirror of
https://github.com/APIParkLab/APIPark.git
synced 2026-06-04 10:13:53 +08:00
finish ai balance
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
package ai_balance
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
|
||||
"github.com/eolinker/go-common/autowire"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type IBalanceController interface {
|
||||
List(ctx *gin.Context, keyword string) ([]*ai_balance_dto.Item, error)
|
||||
Sort(ctx *gin.Context, input *ai_balance_dto.Sort) error
|
||||
Create(ctx *gin.Context, input *ai_balance_dto.Create) error
|
||||
Delete(ctx *gin.Context, id string) error
|
||||
}
|
||||
|
||||
func init() {
|
||||
autowire.Auto[IBalanceController](func() reflect.Value {
|
||||
return reflect.ValueOf(new(imlBalanceController))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package ai_balance
|
||||
|
||||
import (
|
||||
ai_balance "github.com/APIParkLab/APIPark/module/ai-balance"
|
||||
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var _ IBalanceController = (*imlBalanceController)(nil)
|
||||
|
||||
type imlBalanceController struct {
|
||||
module ai_balance.IBalanceModule `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlBalanceController) List(ctx *gin.Context, keyword string) ([]*ai_balance_dto.Item, error) {
|
||||
return i.module.List(ctx, keyword)
|
||||
}
|
||||
|
||||
func (i *imlBalanceController) Sort(ctx *gin.Context, input *ai_balance_dto.Sort) error {
|
||||
return i.module.Sort(ctx, input)
|
||||
}
|
||||
|
||||
func (i *imlBalanceController) Create(ctx *gin.Context, input *ai_balance_dto.Create) error {
|
||||
return i.module.Create(ctx, input)
|
||||
}
|
||||
|
||||
func (i *imlBalanceController) Delete(ctx *gin.Context, id string) error {
|
||||
return i.module.Delete(ctx, id)
|
||||
}
|
||||
@@ -17,6 +17,7 @@ type ILocalModelController interface {
|
||||
RemoveModel(ctx *gin.Context, model string) error
|
||||
Update(ctx *gin.Context, model string, input *ai_local_dto.Update) error
|
||||
State(ctx *gin.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
|
||||
SimpleList(ctx *gin.Context) ([]*ai_local_dto.SimpleItem, error)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -53,6 +53,10 @@ type imlLocalModelController struct {
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlLocalModelController) SimpleList(ctx *gin.Context) ([]*ai_local_dto.SimpleItem, error) {
|
||||
return i.module.SimpleList(ctx)
|
||||
}
|
||||
|
||||
func (i *imlLocalModelController) State(ctx *gin.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) {
|
||||
return i.module.ModelState(ctx, model)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ require (
|
||||
github.com/eolinker/go-common v1.1.4
|
||||
github.com/gabriel-vasile/mimetype v1.4.4
|
||||
github.com/getkin/kin-openapi v0.127.0
|
||||
github.com/gin-contrib/gzip v1.0.1
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/go-sql-driver/mysql v1.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
@@ -35,7 +36,6 @@ require (
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.0 // indirect
|
||||
github.com/ghodss/yaml v1.0.0 // indirect
|
||||
github.com/gin-contrib/gzip v1.0.1 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-openapi/jsonpointer v0.21.0 // indirect
|
||||
github.com/go-openapi/swag v0.23.0 // indirect
|
||||
|
||||
@@ -131,7 +131,7 @@ func genMessageSchema() *openapi3.Schema {
|
||||
result.Description = "Chat Message"
|
||||
roleSchema := openapi3.NewStringSchema()
|
||||
roleSchema.Description = "Role of the message sender"
|
||||
roleSchema.Example = "assistant"
|
||||
roleSchema.Example = "user"
|
||||
contentSchema := openapi3.NewStringSchema()
|
||||
contentSchema.Description = "The message content"
|
||||
contentSchema.Example = "Hello, how can I help you?"
|
||||
|
||||
@@ -2,9 +2,14 @@ package ai_balance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
ai_key "github.com/APIParkLab/APIPark/service/ai-key"
|
||||
|
||||
ai_api "github.com/APIParkLab/APIPark/service/ai-api"
|
||||
@@ -36,18 +41,28 @@ type imlBalanceModule struct {
|
||||
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
|
||||
priority, err := i.balanceService.MaxPriority(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
priority = 0
|
||||
}
|
||||
if input.Id == "" {
|
||||
input.Id = uuid.NewString()
|
||||
}
|
||||
providerName := ""
|
||||
modelName := ""
|
||||
// TODO: 名称进行优化
|
||||
switch input.Type {
|
||||
case ai_balance_dto.ModelTypeOnline:
|
||||
p, has := model_runtime.GetProvider(input.Provider)
|
||||
if !has {
|
||||
return fmt.Errorf("provider not found")
|
||||
}
|
||||
providerName = p.Name()
|
||||
modelName = input.Model
|
||||
case ai_balance_dto.ModelTypeLocal:
|
||||
|
||||
input.Provider = "ollama"
|
||||
providerName = "Ollama"
|
||||
modelName = input.Model
|
||||
}
|
||||
return i.balanceService.Create(ctx, &ai_balance.Create{
|
||||
Id: input.Id,
|
||||
@@ -85,8 +100,8 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) List(ctx context.Context) ([]*ai_balance_dto.Item, error) {
|
||||
list, err := i.balanceService.List(ctx)
|
||||
func (i *imlBalanceModule) List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error) {
|
||||
list, err := i.balanceService.Search(ctx, keyword, nil, "priority asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
type IBalanceModule interface {
|
||||
Create(ctx context.Context, input *ai_balance_dto.Create) error
|
||||
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
|
||||
List(ctx context.Context) ([]*ai_balance_dto.Item, error)
|
||||
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
|
||||
@@ -58,12 +58,19 @@ func FromLocalModelState(state int) LocalModelState {
|
||||
}
|
||||
}
|
||||
|
||||
type SimpleItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type LocalModelItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
State LocalModelState `json:"state"`
|
||||
APICount int64 `json:"api_count"`
|
||||
UpdateTime auto.TimeLabel `json:"update_time"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
State LocalModelState `json:"state"`
|
||||
APICount int64 `json:"api_count"`
|
||||
|
||||
UpdateTime auto.TimeLabel `json:"update_time"`
|
||||
CanDelete bool `json:"can_delete"`
|
||||
}
|
||||
|
||||
type LocalModelPackageItem struct {
|
||||
|
||||
@@ -37,6 +37,24 @@ type imlLocalModel struct {
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) {
|
||||
list, err := i.localModelService.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return utils.SliceToSlice(list, func(s *ai_local.LocalModel) *ai_local_dto.SimpleItem {
|
||||
return &ai_local_dto.SimpleItem{
|
||||
Id: s.Id,
|
||||
Name: s.Name,
|
||||
}
|
||||
}, func(l *ai_local.LocalModel) bool {
|
||||
if l.State != ai_local_dto.LocalModelStateNormal.Int() && l.State != ai_local_dto.LocalModelStateDisable.Int() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) {
|
||||
info, err := i.localModelStateService.Get(ctx, model)
|
||||
if err != nil {
|
||||
@@ -68,6 +86,7 @@ func (i *imlLocalModel) Search(ctx context.Context, keyword string) ([]*ai_local
|
||||
Name: s.Name,
|
||||
State: ai_local_dto.FromLocalModelState(s.State),
|
||||
APICount: apiCountMap[s.Id],
|
||||
CanDelete: true,
|
||||
UpdateTime: auto.TimeLabel(s.UpdateAt),
|
||||
}
|
||||
}), nil
|
||||
|
||||
@@ -20,6 +20,7 @@ type ILocalModelModule interface {
|
||||
Enable(ctx context.Context, model string) error
|
||||
Disable(ctx context.Context, model string) error
|
||||
ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
|
||||
SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -16,5 +16,6 @@ func (p *plugin) aiLocalApis() []pm3.Api {
|
||||
pm3.CreateApiWidthDoc(http.MethodDelete, "/api/v1/model/local", []string{"context", "query:model"}, nil, p.aiLocalController.RemoveModel),
|
||||
pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/model/local/info", []string{"context", "query:model", "body"}, nil, p.aiLocalController.Update),
|
||||
pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/model/local/state", []string{"context", "query:model"}, []string{"state", "info"}, p.aiLocalController.State),
|
||||
pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/simple/ai/models/local/configured", []string{"context"}, []string{"models"}, p.aiLocalController.SimpleList),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,3 +36,12 @@ func (p *plugin) aiKeyApis() []pm3.Api {
|
||||
pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/ai/resource/key/sort", []string{"context", "query:provider", "body"}, nil, p.aiKeyController.Sort),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *plugin) aiBalanceAPIs() []pm3.Api {
|
||||
return []pm3.Api{
|
||||
pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/ai/balances", []string{"context", "query:keyword"}, []string{"list"}, p.aiBalanceController.List),
|
||||
pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/ai/balance/sort", []string{"context", "body"}, nil, p.aiBalanceController.Sort),
|
||||
pm3.CreateApiWidthDoc(http.MethodPost, "/api/v1/ai/balance", []string{"context", "body"}, nil, p.aiBalanceController.Create),
|
||||
pm3.CreateApiWidthDoc(http.MethodDelete, "/api/v1/ai/balance", []string{"context", "query:id"}, nil, p.aiBalanceController.Delete),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package core
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
ai_balance "github.com/APIParkLab/APIPark/controller/ai-balance"
|
||||
|
||||
ai_local "github.com/APIParkLab/APIPark/controller/ai-local"
|
||||
|
||||
ai_key "github.com/APIParkLab/APIPark/controller/ai-key"
|
||||
@@ -80,6 +82,7 @@ type plugin struct {
|
||||
aiAPIController ai_api.IAPIController `autowired:""`
|
||||
aiStatisticController ai.IStatisticController `autowired:""`
|
||||
aiKeyController ai_key.IKeyController `autowired:""`
|
||||
aiBalanceController ai_balance.IBalanceController `autowired:""`
|
||||
aiLocalController ai_local.ILocalModelController `autowired:""`
|
||||
apiDocController router.IAPIDocController `autowired:""`
|
||||
subscribeController subscribe.ISubscribeController `autowired:""`
|
||||
@@ -122,6 +125,7 @@ func (p *plugin) OnComplete() {
|
||||
p.apis = append(p.apis, p.strategyApis()...)
|
||||
p.apis = append(p.apis, p.logApis()...)
|
||||
p.apis = append(p.apis, p.aiLocalApis()...)
|
||||
p.apis = append(p.apis, p.aiBalanceAPIs()...)
|
||||
}
|
||||
|
||||
func (p *plugin) Name() string {
|
||||
|
||||
@@ -53,16 +53,16 @@ func (i *imlBalanceService) SortBefore(ctx context.Context, originID string, tar
|
||||
fn := func(priority int) int {
|
||||
return priority + 1
|
||||
}
|
||||
sql := "sort < ? and sort >= ?"
|
||||
sql := "priority < ? and priority >= ?"
|
||||
if originKeySort < targetKeySort {
|
||||
// 如果原始Key在目标Key之前,中间的key往前移动,原始Key移动到`targetKeySort - 1`位置
|
||||
sql = "sort > ? and sort < ?"
|
||||
sql = "priority > ? and priority < ?"
|
||||
originKey.Priority = targetKeySort - 1
|
||||
fn = func(priority int) int {
|
||||
return priority - 1
|
||||
}
|
||||
}
|
||||
list, err := i.store.ListQuery(ctx, sql, []interface{}{originKeySort, targetKeySort}, "sort asc")
|
||||
list, err := i.store.ListQuery(ctx, sql, []interface{}{originKeySort, targetKeySort}, "priority asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -102,16 +102,16 @@ func (i *imlBalanceService) SortAfter(ctx context.Context, originID string, targ
|
||||
fn := func(priority int) int {
|
||||
return priority + 1
|
||||
}
|
||||
sql := "sort < ? and sort > ?"
|
||||
sql := "priority < ? and priority > ?"
|
||||
if originKeySort < targetKeySort {
|
||||
// 如果原始Key在目标Key之前,中间的Key往前移动,原始Key移动到`targetKeySort`位置
|
||||
sql = "sort > ? and sort <= ?"
|
||||
sql = "priority > ? and priority <= ?"
|
||||
originKey.Priority = targetKeySort
|
||||
fn = func(priority int) int {
|
||||
return priority - 1
|
||||
}
|
||||
}
|
||||
list, err := i.store.ListQuery(ctx, sql, []interface{}{originKeySort, targetKeySort}, "sort asc")
|
||||
list, err := i.store.ListQuery(ctx, sql, []interface{}{originKeySort, targetKeySort}, "priority asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user