mirror of
https://github.com/APIParkLab/APIPark.git
synced 2026-06-14 20:41:15 +08:00
finish ai balance
This commit is contained in:
@@ -131,7 +131,7 @@ func genMessageSchema() *openapi3.Schema {
|
||||
result.Description = "Chat Message"
|
||||
roleSchema := openapi3.NewStringSchema()
|
||||
roleSchema.Description = "Role of the message sender"
|
||||
roleSchema.Example = "assistant"
|
||||
roleSchema.Example = "user"
|
||||
contentSchema := openapi3.NewStringSchema()
|
||||
contentSchema.Description = "The message content"
|
||||
contentSchema.Example = "Hello, how can I help you?"
|
||||
|
||||
@@ -2,9 +2,14 @@ package ai_balance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
ai_key "github.com/APIParkLab/APIPark/service/ai-key"
|
||||
|
||||
ai_api "github.com/APIParkLab/APIPark/service/ai-api"
|
||||
@@ -36,18 +41,28 @@ type imlBalanceModule struct {
|
||||
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
|
||||
priority, err := i.balanceService.MaxPriority(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
priority = 0
|
||||
}
|
||||
if input.Id == "" {
|
||||
input.Id = uuid.NewString()
|
||||
}
|
||||
providerName := ""
|
||||
modelName := ""
|
||||
// TODO: 名称进行优化
|
||||
switch input.Type {
|
||||
case ai_balance_dto.ModelTypeOnline:
|
||||
p, has := model_runtime.GetProvider(input.Provider)
|
||||
if !has {
|
||||
return fmt.Errorf("provider not found")
|
||||
}
|
||||
providerName = p.Name()
|
||||
modelName = input.Model
|
||||
case ai_balance_dto.ModelTypeLocal:
|
||||
|
||||
input.Provider = "ollama"
|
||||
providerName = "Ollama"
|
||||
modelName = input.Model
|
||||
}
|
||||
return i.balanceService.Create(ctx, &ai_balance.Create{
|
||||
Id: input.Id,
|
||||
@@ -85,8 +100,8 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imlBalanceModule) List(ctx context.Context) ([]*ai_balance_dto.Item, error) {
|
||||
list, err := i.balanceService.List(ctx)
|
||||
func (i *imlBalanceModule) List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error) {
|
||||
list, err := i.balanceService.Search(ctx, keyword, nil, "priority asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
type IBalanceModule interface {
|
||||
Create(ctx context.Context, input *ai_balance_dto.Create) error
|
||||
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
|
||||
List(ctx context.Context) ([]*ai_balance_dto.Item, error)
|
||||
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
|
||||
@@ -58,12 +58,19 @@ func FromLocalModelState(state int) LocalModelState {
|
||||
}
|
||||
}
|
||||
|
||||
type SimpleItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type LocalModelItem struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
State LocalModelState `json:"state"`
|
||||
APICount int64 `json:"api_count"`
|
||||
UpdateTime auto.TimeLabel `json:"update_time"`
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
State LocalModelState `json:"state"`
|
||||
APICount int64 `json:"api_count"`
|
||||
|
||||
UpdateTime auto.TimeLabel `json:"update_time"`
|
||||
CanDelete bool `json:"can_delete"`
|
||||
}
|
||||
|
||||
type LocalModelPackageItem struct {
|
||||
|
||||
@@ -37,6 +37,24 @@ type imlLocalModel struct {
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error) {
|
||||
list, err := i.localModelService.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return utils.SliceToSlice(list, func(s *ai_local.LocalModel) *ai_local_dto.SimpleItem {
|
||||
return &ai_local_dto.SimpleItem{
|
||||
Id: s.Id,
|
||||
Name: s.Name,
|
||||
}
|
||||
}, func(l *ai_local.LocalModel) bool {
|
||||
if l.State != ai_local_dto.LocalModelStateNormal.Int() && l.State != ai_local_dto.LocalModelStateDisable.Int() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (i *imlLocalModel) ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error) {
|
||||
info, err := i.localModelStateService.Get(ctx, model)
|
||||
if err != nil {
|
||||
@@ -68,6 +86,7 @@ func (i *imlLocalModel) Search(ctx context.Context, keyword string) ([]*ai_local
|
||||
Name: s.Name,
|
||||
State: ai_local_dto.FromLocalModelState(s.State),
|
||||
APICount: apiCountMap[s.Id],
|
||||
CanDelete: true,
|
||||
UpdateTime: auto.TimeLabel(s.UpdateAt),
|
||||
}
|
||||
}), nil
|
||||
|
||||
@@ -20,6 +20,7 @@ type ILocalModelModule interface {
|
||||
Enable(ctx context.Context, model string) error
|
||||
Disable(ctx context.Context, model string) error
|
||||
ModelState(ctx context.Context, model string) (*ai_local_dto.DeployState, *ai_local_dto.ModelInfo, error)
|
||||
SimpleList(ctx context.Context) ([]*ai_local_dto.SimpleItem, error)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
Reference in New Issue
Block a user