finish ai apis

This commit is contained in:
Liujian
2024-12-24 18:00:46 +08:00
parent 0be2248f41
commit 7ac8beb161
21 changed files with 409 additions and 24 deletions
+13
View File
@@ -30,6 +30,19 @@ type APIItem struct {
Model ModelItem `json:"model"`
}
type AIAPIItem struct {
Id string `json:"id"`
Name string `json:"name"`
Service auto.Label `json:"service" aolabel:"service"`
Method string `json:"method"`
RequestPath string `json:"request_path"`
Model ModelItem `json:"model"`
Provider ProviderItem `json:"provider"`
UpdateTime auto.TimeLabel `json:"update_time"`
UseToken int64 `json:"use_token"`
Disable bool `json:"disable"`
}
type ModelItem struct {
Id string `json:"id"`
Logo string `json:"logo"`
+2
View File
@@ -112,6 +112,7 @@ func (i *imlAPIModule) Create(ctx context.Context, serviceId string, input *ai_a
Name: input.Name,
Service: serviceId,
Path: input.Path,
Disable: input.Disable,
Description: input.Description,
Timeout: input.Timeout,
Retry: input.Retry,
@@ -171,6 +172,7 @@ func (i *imlAPIModule) Edit(ctx context.Context, serviceId string, apiId string,
Model: modelId,
Provider: providerId,
AdditionalConfig: &apiInfo.AdditionalConfig,
Disable: input.Disable,
})
})
}
+2 -1
View File
@@ -2,9 +2,10 @@ package ai_api
import (
"context"
"reflect"
ai_api_dto "github.com/APIParkLab/APIPark/module/ai-api/dto"
"github.com/eolinker/go-common/autowire"
"reflect"
)
type IAPIModule interface {
+4
View File
@@ -10,3 +10,7 @@ type UpdateConfig struct {
Priority *int `json:"priority"`
Enable *bool `json:"enable"`
}
type Sort struct {
Providers []string `json:"providers"`
}
+28 -5
View File
@@ -1,5 +1,9 @@
package ai_dto
import (
"github.com/eolinker/go-common/auto"
)
type Provider struct {
Id string `json:"id"`
Name string `json:"name"`
@@ -19,10 +23,16 @@ type ConfiguredProviderItem struct {
Status ProviderStatus `json:"status"`
APICount int64 `json:"api_count"`
KeyCount int `json:"key_count"`
KeyStatus []string `json:"key_status"`
KeyStatus []*KeyStatus `json:"key_status"`
Priority int `json:"priority"`
}
type KeyStatus struct {
Id string `json:"id"`
Name string `json:"name"`
Status string `json:"status"`
}
type ProviderItem struct {
Id string `json:"id"`
Name string `json:"name"`
@@ -32,10 +42,11 @@ type ProviderItem struct {
}
type SimpleProviderItem struct {
Id string `json:"id"`
Name string `json:"name"`
Logo string `json:"logo"`
Configured bool `json:"configured"`
Id string `json:"id"`
Name string `json:"name"`
Logo string `json:"logo"`
Configured bool `json:"configured"`
Status ProviderStatus `json:"status"`
}
type LLMItem struct {
@@ -44,3 +55,15 @@ type LLMItem struct {
Config string `json:"config"`
Scopes []string `json:"scopes"`
}
type APIItem struct {
Id string `json:"id"`
Name string `json:"name"`
Service auto.Label `json:"service" aolabel:"service"`
Method string `json:"method"`
RequestPath string `json:"request_path"`
Model auto.Label `json:"model"`
UpdateTime auto.TimeLabel `json:"update_time"`
UseToken int `json:"use_token"`
Disable bool `json:"disable"`
}
+148 -10
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
"time"
@@ -59,6 +60,32 @@ type imlProviderModule struct {
transaction store.ITransaction `autowired:""`
}
func (i *imlProviderModule) Sort(ctx context.Context, input *ai_dto.Sort) error {
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
list, err := i.providerService.List(ctx)
if err != nil {
return err
}
providerMap := utils.SliceToMap(list, func(e *ai.Provider) string {
return e.Id
})
for index, id := range input.Providers {
_, has := providerMap[id]
if !has {
continue
}
priority := index + 1
err = i.providerService.Save(txCtx, id, &ai.SetProvider{
Priority: &priority,
})
if err != nil {
return err
}
}
return nil
})
}
func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.ConfiguredProviderItem, *auto.Label, error) {
// 获取已配置的AI服务商
list, err := i.providerService.List(ctx)
@@ -71,6 +98,28 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
}
providers := make([]*ai_dto.ConfiguredProviderItem, 0, len(list))
for _, l := range list {
// 检查是否有默认Key
_, err = i.aiKeyService.DefaultKey(ctx, l.Id)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil, err
}
err = i.aiKeyService.Create(ctx, &ai_key.Create{
ID: l.Id,
Name: l.Name,
Config: l.Config,
Provider: l.Id,
Priority: 1,
Status: 1,
ExpireTime: 0,
UseToken: 0,
Default: true,
})
if err != nil {
return nil, nil, fmt.Errorf("create default key error:%v", err)
}
}
p, has := model_runtime.GetProvider(l.Id)
if !has {
continue
@@ -80,7 +129,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
return nil, nil, fmt.Errorf("get provider keys error:%v", err)
}
keysStatus := make([]string, 0, len(keys))
keysStatus := make([]*ai_dto.KeyStatus, 0, len(keys))
for _, k := range keys {
status := ai_key_dto.ToKeyStatus(k.Status)
switch status {
@@ -88,11 +137,13 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context) ([]*ai_dto.
default:
status = ai_key_dto.KeyError
}
keysStatus = append(keysStatus, status.String())
}
if len(keysStatus) == 0 {
keysStatus = []string{ai_key_dto.KeyNormal.String()}
keysStatus = append(keysStatus, &ai_dto.KeyStatus{
Id: k.ID,
Name: k.Name,
Status: status.String(),
})
}
providers = append(providers, &ai_dto.ConfiguredProviderItem{
Id: l.Id,
Name: l.Name,
@@ -143,12 +194,14 @@ func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.Simp
items := make([]*ai_dto.SimpleProviderItem, 0, len(providers))
for _, v := range providers {
item := &ai_dto.SimpleProviderItem{
Id: v.ID(),
Name: v.Name(),
Logo: v.Logo(),
Id: v.ID(),
Name: v.Name(),
Logo: v.Logo(),
Status: ai_dto.ProviderDisabled,
}
if _, has := providerMap[v.ID()]; has {
if info, has := providerMap[v.ID()]; has {
item.Configured = true
item.Status = ai_dto.ToProviderStatus(info.Status)
}
items = append(items, item)
}
@@ -403,7 +456,7 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
}
pInfo := &ai.SetProvider{
Name: &info.Name,
DefaultLLM: &info.DefaultLLM,
DefaultLLM: &input.DefaultLLM,
Config: &input.Config,
Priority: input.Priority,
}
@@ -568,3 +621,88 @@ func (i *imlProviderModule) syncGateway(ctx context.Context, clusterId string, r
return nil
}
var _ IAIAPIModule = (*imlAIApiModule)(nil)
type imlAIApiModule struct {
aiAPIService ai_api.IAPIService `autowired:""`
aiAPIUseService ai_api.IAPIUseService `autowired:""`
}
func (i *imlAIApiModule) APIs(ctx context.Context, keyword string, providerId string, start int64, end int64, page int, pageSize int, sortCondition string, asc bool) ([]*ai_dto.APIItem, int64, error) {
sortRule := "desc"
if asc {
sortRule = "asc"
}
switch sortCondition {
default:
apis, err := i.aiAPIService.Search(ctx, keyword, map[string]interface{}{
"provider": providerId,
}, "update_at desc")
if err != nil {
return nil, 0, err
}
if len(apis) <= 0 {
return nil, 0, nil
}
apiMap := make(map[string]*ai_api.API)
apiIds := make([]string, 0, len(apis))
for _, a := range apis {
apiMap[a.ID] = a
apiIds = append(apiIds, a.ID)
}
offset := (page - 1) * pageSize
results, _, err := i.aiAPIUseService.SumByApisPage(ctx, providerId, start, end, offset, pageSize, fmt.Sprintf("total_token %s", sortRule), apiIds...)
if err != nil {
return nil, 0, err
}
apiItems := utils.SliceToSlice(results, func(e *ai_api.APIUse) *ai_dto.APIItem {
info := apiMap[e.API]
delete(apiMap, e.API)
return &ai_dto.APIItem{
Id: e.API,
Name: info.Name,
Service: auto.UUID(info.Service),
Method: http.MethodPost,
RequestPath: info.Path,
Model: auto.Label{
Id: info.Model,
Name: info.Model,
},
UpdateTime: auto.TimeLabel(info.UpdateAt),
UseToken: e.TotalToken,
Disable: info.Disable,
}
})
sortApis := make([]*ai_dto.APIItem, 0, len(apiMap))
for _, a := range apiMap {
sortApis = append(sortApis, &ai_dto.APIItem{
Id: a.ID,
Name: a.Name,
Service: auto.UUID(a.Service),
Method: http.MethodPost,
RequestPath: a.Path,
Model: auto.Label{
Id: a.Model,
Name: a.Model,
},
UpdateTime: auto.TimeLabel(a.UpdateAt),
UseToken: 0,
Disable: a.Disable,
})
}
// 排序
sort.Slice(sortApis, func(i, j int) bool {
return time.Time(sortApis[i].UpdateTime).After(time.Time(sortApis[j].UpdateTime))
})
size := pageSize - len(apiItems)
for i := offset; i < offset+size && i < len(sortApis); i++ {
apiItems = append(apiItems, sortApis[i])
}
total := int64(len(apis))
return apiItems, total, nil
}
}
+9
View File
@@ -20,6 +20,11 @@ type IProviderModule interface {
UpdateProviderStatus(ctx context.Context, id string, enable bool) error
UpdateProviderConfig(ctx context.Context, id string, input *ai_dto.UpdateConfig) error
UpdateProviderDefaultLLM(ctx context.Context, id string, input *ai_dto.UpdateLLM) error
Sort(ctx context.Context, input *ai_dto.Sort) error
}
type IAIAPIModule interface {
APIs(ctx context.Context, keyword string, providerId string, start int64, end int64, page int, pageSize int, sortCondition string, asc bool) ([]*ai_dto.APIItem, int64, error)
}
func init() {
@@ -28,4 +33,8 @@ func init() {
gateway.RegisterInitHandleFunc(module.initGateway)
return reflect.ValueOf(module)
})
autowire.Auto[IAIAPIModule](func() reflect.Value {
return reflect.ValueOf(new(imlAIApiModule))
})
}