finish ai balance

This commit is contained in:
Liujian
2025-02-14 16:29:56 +08:00
parent edfb2006b2
commit ee06368a4e
15 changed files with 131 additions and 19 deletions
+1 -1
View File
@@ -131,7 +131,7 @@ func genMessageSchema() *openapi3.Schema {
result.Description = "Chat Message"
roleSchema := openapi3.NewStringSchema()
roleSchema.Description = "Role of the message sender"
roleSchema.Example = "assistant"
roleSchema.Example = "user"
contentSchema := openapi3.NewStringSchema()
contentSchema.Description = "The message content"
contentSchema.Example = "Hello, how can I help you?"
+20 -5
View File
@@ -2,9 +2,14 @@ package ai_balance
import (
"context"
"errors"
"fmt"
"sort"
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
"gorm.io/gorm"
ai_key "github.com/APIParkLab/APIPark/service/ai-key"
ai_api "github.com/APIParkLab/APIPark/service/ai-api"
@@ -36,18 +41,28 @@ type imlBalanceModule struct {
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
priority, err := i.balanceService.MaxPriority(ctx)
if err != nil {
return err
if !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
priority = 0
}
if input.Id == "" {
input.Id = uuid.NewString()
}
providerName := ""
modelName := ""
// TODO: 名称进行优化
switch input.Type {
case ai_balance_dto.ModelTypeOnline:
p, has := model_runtime.GetProvider(input.Provider)
if !has {
return fmt.Errorf("provider not found")
}
providerName = p.Name()
modelName = input.Model
case ai_balance_dto.ModelTypeLocal:
input.Provider = "ollama"
providerName = "Ollama"
modelName = input.Model
}
return i.balanceService.Create(ctx, &ai_balance.Create{
Id: input.Id,
@@ -85,8 +100,8 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort)
return nil
}
func (i *imlBalanceModule) List(ctx context.Context) ([]*ai_balance_dto.Item, error) {
list, err := i.balanceService.List(ctx)
func (i *imlBalanceModule) List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error) {
list, err := i.balanceService.Search(ctx, keyword, nil, "priority asc")
if err != nil {
return nil, err
}
+1 -1
View File
@@ -12,7 +12,7 @@ import (
type IBalanceModule interface {
Create(ctx context.Context, input *ai_balance_dto.Create) error
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
List(ctx context.Context) ([]*ai_balance_dto.Item, error)
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
Delete(ctx context.Context, id string) error
}
+12 -5
View File
@@ -58,12 +58,19 @@ func FromLocalModelState(state int) LocalModelState {
}
}
type SimpleItem struct {
Id string `json:"id"`
Name string `json:"name"`
}
type LocalModelItem struct {
Id string `json:"id"`
Name string `json:"name"`
State LocalModelState `json:"state"`
APICount int64 `json:"api_count"`
UpdateTime auto.TimeLabel `json:"update_time"`
Id string `json:"id"`
Name string `json:"name"`
State LocalModelState `json:"state"`
APICount int64 `json:"api_count"`
UpdateTime auto.TimeLabel `json:"update_time"`
CanDelete bool `json:"can_delete"`
}
type LocalModelPackageItem struct {
+19
View File
@@ -37,6 +37,24 @@ type imlLocalModel struct {
transaction store.ITransaction `autowired:""`
}
func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) {
list, err := i.localModelService.List(ctx)
if err != nil {
return nil, err
}
return utils.SliceToSlice(list, func(s *ai_local.LocalModel) *ai_local_dto.SimpleItem {
return &ai_local_dto.SimpleItem{
Id: s.Id,
Name: s.Name,
}
}, func(l *ai_local.LocalModel) bool {
if l.State != ai_local_dto.LocalModelStateNormal.Int() && l.State != ai_local_dto.LocalModelStateDisable.Int() {
return false
}
return true
}), nil
}
func (i *imlLocalModel) ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) {
info, err := i.localModelStateService.Get(ctx, model)
if err != nil {
@@ -68,6 +86,7 @@ func (i *imlLocalModel) Search(ctx context.Context, keyword string) ([]*ai_local
Name: s.Name,
State: ai_local_dto.FromLocalModelState(s.State),
APICount: apiCountMap[s.Id],
CanDelete: true,
UpdateTime: auto.TimeLabel(s.UpdateAt),
}
}), nil
+1
View File
@@ -20,6 +20,7 @@ type ILocalModelModule interface {
Enable(ctx context.Context, model string) error
Disable(ctx context.Context, model string) error
ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error)
}
func init() {