Files
APIPark/module/ai-key/iml.go
T
2025-01-06 10:57:23 +08:00

440 lines
12 KiB
Go

package ai_key
import (
"context"
"errors"
"fmt"
"time"
"github.com/APIParkLab/APIPark/service/cluster"
"github.com/eolinker/eosc/log"
"github.com/APIParkLab/APIPark/gateway"
"github.com/eolinker/go-common/utils"
"gorm.io/gorm"
"github.com/eolinker/go-common/auto"
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
"github.com/google/uuid"
"github.com/eolinker/go-common/store"
"github.com/APIParkLab/APIPark/service/ai"
ai_key_dto "github.com/APIParkLab/APIPark/module/ai-key/dto"
ai_key "github.com/APIParkLab/APIPark/service/ai-key"
)
// Compile-time check that *imlKeyModule satisfies IKeyModule.
var _ IKeyModule = (*imlKeyModule)(nil)

// imlKeyModule is the IKeyModule implementation that manages AI provider keys:
// it persists them through aiKeyService and mirrors changes to the gateway.
//
// NOTE(review): the `autowired` tags suggest the fields are populated by a
// dependency-injection framework rather than set in this file — confirm.
type imlKeyModule struct {
	// providerService looks up (and, for default keys, updates) the provider a key belongs to.
	providerService ai.IProviderService `autowired:""`
	// aiKeyService stores, queries and re-orders the keys themselves.
	aiKeyService ai_key.IKeyService `autowired:""`
	// clusterService hands out the gateway client used by syncGateway.
	clusterService cluster.IClusterService `autowired:""`
	// transaction wraps the write paths of this module in a single transaction.
	transaction store.ITransaction `autowired:""`
}
// newKey converts a stored AI key into the dynamic release that is sent to the
// gateway.
//
// The release is identified by the key ID under the "ai-key" resource, uses the
// key name as its description and is versioned with the current local time
// formatted as yyyyMMddHHmmss. key must not be nil.
func newKey(key *ai_key.Key) *gateway.DynamicRelease {
	item := &gateway.BasicItem{
		ID:          key.ID,
		Description: key.Name,
		Resource:    "ai-key",
		Version:     time.Now().Format("20060102150405"),
		MatchLabels: map[string]string{"module": "ai-key"},
	}

	attrs := map[string]interface{}{
		"expired":  key.ExpireTime,
		"config":   key.Config,
		"provider": key.Provider,
		"priority": key.Priority,
		// NOTE(review): 1 is a bare status code here; confirm it really is the
		// "disabled" status and consider using the named constant instead.
		"disabled": key.Status == 1,
	}

	return &gateway.DynamicRelease{
		BasicItem: item,
		Attr:      attrs,
	}
}
// Create registers a new AI key for providerId and publishes it to the gateway
// of the default cluster.
//
// The provider must exist both in the provider service and in the model
// runtime, and input.Config must pass the runtime provider's Check. The new key
// is placed after the provider's current maximum priority. When input.Id is
// empty a UUID is generated and written back into input. A key whose
// ExpireTime is already in the past is stored with the expired status.
//
// It returns an error if any lookup, validation, persistence or gateway step
// fails; persistence and gateway sync run inside one transaction.
func (i *imlKeyModule) Create(ctx context.Context, providerId string, input *ai_key_dto.Create) error {
	if _, err := i.providerService.Get(ctx, providerId); err != nil {
		return fmt.Errorf("provider not found: %w", err)
	}
	p, has := model_runtime.GetProvider(providerId)
	if !has {
		// There is no underlying error on this path, so report the id instead
		// of wrapping a nil error.
		return fmt.Errorf("provider not found: %s", providerId)
	}
	if err := p.Check(input.Config); err != nil {
		return fmt.Errorf("config check failed: %w", err)
	}
	priority, err := i.aiKeyService.MaxPriority(ctx, providerId)
	if err != nil {
		return fmt.Errorf("get key error: %w", err)
	}

	return i.transaction.Transaction(ctx, func(ctx context.Context) error {
		if input.Id == "" {
			input.Id = uuid.NewString()
		}

		status := ai_key_dto.KeyNormal.Int()
		if input.ExpireTime > 0 && time.Unix(int64(input.ExpireTime), 0).Before(time.Now()) {
			status = ai_key_dto.KeyExpired.Int()
		}

		if err := i.aiKeyService.Create(ctx, &ai_key.Create{
			ID:         input.Id,
			Name:       input.Name,
			Config:     input.Config,
			Provider:   providerId,
			Status:     status,
			ExpireTime: input.ExpireTime,
			Priority:   priority + 1,
		}); err != nil {
			return fmt.Errorf("create key failed: %w", err)
		}

		// Re-read the stored record so the release reflects what was persisted;
		// a failed read must abort instead of handing a nil key to newKey.
		info, err := i.aiKeyService.Get(ctx, input.Id)
		if err != nil {
			return fmt.Errorf("key not found: %w", err)
		}

		releases := []*gateway.DynamicRelease{newKey(info)}
		return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
	})
}
// syncGateway pushes the given releases to the gateway of clusterId, bringing
// each one online when online is true and taking it offline otherwise.
//
// Failing to obtain the gateway client is treated as best-effort: the error is
// logged and nil is returned. Errors from resolving the dynamic client or from
// the Online/Offline call stop the loop and are returned as-is. The gateway
// client is closed before returning; a close failure is only logged.
func (i *imlKeyModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
	gw, err := i.clusterService.GatewayClient(ctx, clusterId)
	if err != nil {
		log.Errorf("get apinto client error: %v", err)
		return nil
	}
	defer func() {
		if closeErr := gw.Close(ctx); closeErr != nil {
			log.Warn("close apinto client:", closeErr)
		}
	}()

	for _, rel := range releases {
		dyn, err := gw.Dynamic(rel.Resource)
		if err != nil {
			return err
		}
		if online {
			if err := dyn.Online(ctx, rel); err != nil {
				return err
			}
			continue
		}
		if err := dyn.Offline(ctx, rel); err != nil {
			return err
		}
	}
	return nil
}
// Edit updates the name, config and/or expire time of key id under providerId.
//
// The provider must be known to the model runtime and to the provider service.
// When input.Config is set it is validated with the runtime provider's Check,
// then passed through GenConfig together with the stored config; the result
// replaces input.Config and, for the provider's default key, is also saved on
// the provider itself.
//
// The stored status is recomputed: keys that are currently normal, errored or
// expired become expired when the effective expire time (the new one if
// supplied, otherwise the stored one) lies in the past, and normal otherwise.
// Any other status is kept unchanged. This method does not push anything to the
// gateway.
func (i *imlKeyModule) Edit(ctx context.Context, providerId string, id string, input *ai_key_dto.Edit) error {
	p, has := model_runtime.GetProvider(providerId)
	if !has {
		return fmt.Errorf("provider not found: %s", providerId)
	}
	if _, err := i.providerService.Get(ctx, providerId); err != nil {
		return fmt.Errorf("provider not found: %w", err)
	}

	return i.transaction.Transaction(ctx, func(ctx context.Context) error {
		info, err := i.aiKeyService.Get(ctx, id)
		if err != nil {
			return fmt.Errorf("key not found: %w", err)
		}

		if input.Config != nil {
			if err := p.Check(*input.Config); err != nil {
				return fmt.Errorf("config check failed: %w", err)
			}
			generated, err := p.GenConfig(info.Config, *input.Config)
			if err != nil {
				return fmt.Errorf("config gen failed: %w", err)
			}
			input.Config = &generated
			if info.Default {
				// The default key mirrors the provider's own config.
				if err := i.providerService.Save(ctx, info.Provider, &ai.SetProvider{
					Config: input.Config,
				}); err != nil {
					return err
				}
			}
		}

		status := ai_key_dto.KeyNormal.Int()
		switch current := ai_key_dto.ToKeyStatus(info.Status); current {
		case ai_key_dto.KeyNormal, ai_key_dto.KeyError, ai_key_dto.KeyExpired:
			pastDue := false
			if input.ExpireTime != nil {
				if ts := *input.ExpireTime; ts > 0 {
					pastDue = time.Unix(int64(ts), 0).Before(time.Now())
				}
			} else if info.ExpireTime > 0 {
				// Expire time was not changed; fall back to the stored value.
				pastDue = time.Unix(int64(info.ExpireTime), 0).Before(time.Now())
			}
			if pastDue {
				status = ai_key_dto.KeyExpired.Int()
			}
		default:
			// Disabled / over-quota keys must be re-enabled explicitly, so the
			// original status is kept.
			status = current.Int()
		}

		// TODO: publish the key to the gateway when the resulting status is normal.

		return i.aiKeyService.Save(ctx, id, &ai_key.Edit{
			Name:       input.Name,
			Config:     input.Config,
			ExpireTime: input.ExpireTime,
			Status:     &status,
		})
	})
}
// Delete removes key id of providerId and takes it offline on the gateway of
// the default cluster.
//
// A key that does not exist (gorm.ErrRecordNotFound) is treated as already
// deleted and yields nil. The provider's default key cannot be deleted. Before
// the record is removed, every key returned by KeysAfterPriority for the
// deleted key's priority has its priority decremented by one. All of this runs
// inside one transaction.
func (i *imlKeyModule) Delete(ctx context.Context, providerId string, id string) error {
	if _, err := i.providerService.Get(ctx, providerId); err != nil {
		return fmt.Errorf("provider not found: %w", err)
	}

	return i.transaction.Transaction(ctx, func(ctx context.Context) error {
		info, err := i.aiKeyService.Get(ctx, id)
		switch {
		case err == nil:
			// found; continue below
		case errors.Is(err, gorm.ErrRecordNotFound):
			return nil
		default:
			return err
		}
		if info.Default {
			return fmt.Errorf("default key can't be deleted: %s", id)
		}

		following, err := i.aiKeyService.KeysAfterPriority(ctx, providerId, info.Priority)
		if err != nil {
			return err
		}
		for _, key := range following {
			key.Priority--
			if err := i.aiKeyService.Save(ctx, key.ID, &ai_key.Edit{Priority: &key.Priority}); err != nil {
				return err
			}
		}

		if err := i.aiKeyService.Delete(ctx, id); err != nil {
			return err
		}

		// Only the identity is needed to take a release offline.
		offline := &gateway.DynamicRelease{
			BasicItem: &gateway.BasicItem{
				ID:       id,
				Resource: "ai-key",
			},
		}
		return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{offline}, false)
	})
}
// Get returns the details of key id for providerId.
//
// The provider must be known to the model runtime and to the provider service.
// The returned config is the stored config passed through the runtime
// provider's MaskConfig. An error is returned when the provider or the key
// cannot be found.
func (i *imlKeyModule) Get(ctx context.Context, providerId string, id string) (*ai_key_dto.Key, error) {
	p, has := model_runtime.GetProvider(providerId)
	if !has {
		return nil, fmt.Errorf("provider not found: %s", providerId)
	}
	if _, err := i.providerService.Get(ctx, providerId); err != nil {
		return nil, fmt.Errorf("provider not found: %w", err)
	}

	info, err := i.aiKeyService.Get(ctx, id)
	if err != nil {
		return nil, fmt.Errorf("key not found: %w", err)
	}

	masked := p.MaskConfig(info.Config)
	out := &ai_key_dto.Key{
		Id:         info.ID,
		Name:       info.Name,
		Config:     masked,
		ExpireTime: info.ExpireTime,
	}
	return out, nil
}
// List returns one page of keys for providerId together with the total count.
//
// If the provider has no default key yet (DefaultKey reports
// gorm.ErrRecordNotFound), one is created first from the provider record: same
// ID, name and config, priority 1, normal status, flagged as default.
//
// Filtering:
//   - statuses empty, or containing the expired status: SearchByPage is used
//     with the keyword and the provider/status filter.
//   - otherwise: SearchUnExpiredByPage is used. A non-empty keyword is resolved
//     first via Search and the page query is narrowed to the matching IDs; when
//     nothing matches, (nil, 0, nil) is returned.
//
// In the returned items, a key whose expire time is in the past is reported as
// expired regardless of its stored status.
func (i *imlKeyModule) List(ctx context.Context, providerId string, keyword string, page, pageSize int, statuses []string) ([]*ai_key_dto.Item, int64, error) {
	if _, err := i.aiKeyService.DefaultKey(ctx, providerId); err != nil {
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, 0, fmt.Errorf("get default key failed: %w", err)
		}
		provider, err := i.providerService.Get(ctx, providerId)
		if err != nil {
			return nil, 0, fmt.Errorf("provider is unconfigued,id is %s", providerId)
		}
		err = i.aiKeyService.Create(ctx, &ai_key.Create{
			ID:         provider.Id,
			Name:       provider.Name,
			Config:     provider.Config,
			Provider:   provider.Id,
			Status:     ai_key_dto.KeyNormal.Int(),
			Priority:   1,
			ExpireTime: 0,
			UseToken:   0,
			Default:    true,
		})
		if err != nil {
			return nil, 0, fmt.Errorf("create default key failed: %w", err)
		}
	}

	filter := map[string]interface{}{
		"provider": providerId,
	}
	// With no status filter every key, expired ones included, is listed.
	includeExpired := len(statuses) == 0
	if !includeExpired {
		codes := make([]int, 0, len(statuses))
		for _, raw := range statuses {
			st := ai_key_dto.KeyStatus(raw)
			if st == ai_key_dto.KeyExpired {
				includeExpired = true
			}
			codes = append(codes, st.Int())
		}
		filter["status"] = codes
	}

	var (
		list  []*ai_key.Key
		total int64
		err   error
	)
	if includeExpired {
		list, total, err = i.aiKeyService.SearchByPage(ctx, keyword, filter, page, pageSize, "sort asc")
	} else {
		if keyword != "" {
			list, err = i.aiKeyService.Search(ctx, keyword, filter, "sort asc")
			if err != nil {
				return nil, 0, err
			}
			if len(list) == 0 {
				return nil, 0, nil
			}
			filter["uuid"] = utils.SliceToSlice(list, func(key *ai_key.Key) string {
				return key.ID
			})
		}
		list, total, err = i.aiKeyService.SearchUnExpiredByPage(ctx, filter, page, pageSize, "sort asc")
	}
	if err != nil {
		return nil, 0, err
	}

	var result []*ai_key_dto.Item
	for _, key := range list {
		st := ai_key_dto.ToKeyStatus(key.Status)
		if key.ExpireTime > 0 && time.Unix(int64(key.ExpireTime), 0).Before(time.Now()) {
			st = ai_key_dto.KeyExpired
		}
		result = append(result, &ai_key_dto.Item{
			Id:         key.ID,
			Name:       key.Name,
			Status:     st,
			UseToken:   key.UseToken,
			UpdateTime: auto.TimeLabel(key.UpdateAt),
			ExpireTime: key.ExpireTime,
			CanDelete:  !key.Default,
			Priority:   key.Priority,
		})
	}
	return result, total, nil
}
// UpdateKeyStatus enables or disables key id of providerId and mirrors the
// change to the gateway of the default cluster.
//
// Disabling (enable == false) stores the disabled status and takes the key
// offline on the gateway.
//
// Enabling only acts on keys that are currently disabled or over quota; for
// any other status it is a no-op. If the key's expire time is already in the
// past it is stored as expired and nothing is published; otherwise it is stored
// as normal, re-read, and brought online on the gateway.
//
// It returns an error when the provider or key cannot be found, or when a
// persistence or gateway step fails.
func (i *imlKeyModule) UpdateKeyStatus(ctx context.Context, providerId string, id string, enable bool) error {
	if _, err := i.providerService.Get(ctx, providerId); err != nil {
		return fmt.Errorf("provider not found: %w", err)
	}
	info, err := i.aiKeyService.Get(ctx, id)
	if err != nil {
		return fmt.Errorf("key not found: %w", err)
	}

	return i.transaction.Transaction(ctx, func(ctx context.Context) error {
		if !enable {
			status := ai_key_dto.KeyDisable.Int()
			if err := i.aiKeyService.Save(ctx, id, &ai_key.Edit{Status: &status}); err != nil {
				return err
			}
			releases := []*gateway.DynamicRelease{{
				BasicItem: &gateway.BasicItem{
					ID:       id,
					Resource: "ai-key",
				},
			}}
			// syncGateway expects a cluster id, not a provider id; use the
			// default cluster like every other gateway sync in this module.
			return i.syncGateway(ctx, cluster.DefaultClusterID, releases, false)
		}

		// Only disabled or over-quota keys can be re-enabled.
		if info.Status != ai_key_dto.KeyDisable.Int() && info.Status != ai_key_dto.KeyExceed.Int() {
			return nil
		}

		if info.ExpireTime > 0 && time.Unix(int64(info.ExpireTime), 0).Before(time.Now()) {
			// The key has already expired, so it becomes expired rather than
			// normal and is not published.
			status := ai_key_dto.KeyExpired.Int()
			return i.aiKeyService.Save(ctx, id, &ai_key.Edit{Status: &status})
		}

		status := ai_key_dto.KeyNormal.Int()
		if err := i.aiKeyService.Save(ctx, id, &ai_key.Edit{Status: &status}); err != nil {
			return err
		}
		// Re-read so the published release carries the updated status.
		latest, err := i.aiKeyService.Get(ctx, id)
		if err != nil {
			return err
		}
		releases := []*gateway.DynamicRelease{newKey(latest)}
		return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
	})
}
// Sort moves key input.Origin before or after key input.Target within
// providerId, then republishes every key of the provider to the gateway of the
// default cluster so the new priorities take effect.
//
// input.Sort must be "before" or "after"; any other value is rejected. The
// re-ordering and the gateway sync run inside one transaction.
func (i *imlKeyModule) Sort(ctx context.Context, providerId string, input *ai_key_dto.Sort) error {
	if _, err := i.providerService.Get(ctx, providerId); err != nil {
		return fmt.Errorf("provider not found: %w", err)
	}

	return i.transaction.Transaction(ctx, func(ctx context.Context) error {
		var err error
		switch input.Sort {
		case "before":
			_, err = i.aiKeyService.SortBefore(ctx, providerId, input.Origin, input.Target)
		case "after":
			_, err = i.aiKeyService.SortAfter(ctx, providerId, input.Origin, input.Target)
		default:
			return fmt.Errorf("invalid sort type: %s", input.Sort)
		}
		if err != nil {
			return err
		}

		keys, err := i.aiKeyService.KeysByProvider(ctx, providerId)
		if err != nil {
			return err
		}

		releases := make([]*gateway.DynamicRelease, len(keys))
		for idx, key := range keys {
			releases[idx] = newKey(key)
		}
		return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
	})
}