mirror of
https://github.com/APIParkLab/APIPark.git
synced 2026-06-14 20:41:15 +08:00
444 lines
11 KiB
Go
444 lines
11 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"time"
|
|
|
|
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
|
|
"github.com/APIParkLab/APIPark/gateway"
|
|
ai_dto "github.com/APIParkLab/APIPark/module/ai/dto"
|
|
"github.com/APIParkLab/APIPark/service/ai"
|
|
"github.com/APIParkLab/APIPark/service/cluster"
|
|
"github.com/eolinker/eosc/log"
|
|
"github.com/eolinker/go-common/store"
|
|
"github.com/eolinker/go-common/utils"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
func newAIUpstream(provider string, uri model_runtime.IProviderURI) *gateway.DynamicRelease {
|
|
return &gateway.DynamicRelease{
|
|
BasicItem: &gateway.BasicItem{
|
|
ID: provider,
|
|
Description: fmt.Sprintf("auto create by ai provider %s", provider),
|
|
Resource: "service",
|
|
Version: time.Now().Format("20060102150405"),
|
|
MatchLabels: map[string]string{
|
|
"module": "service",
|
|
},
|
|
},
|
|
Attr: map[string]interface{}{
|
|
"driver": "http",
|
|
"balance": "round-robin",
|
|
"nodes": []string{fmt.Sprintf("%s weight=100", uri.Host())},
|
|
"pass_node": "node",
|
|
"scheme": uri.Scheme(),
|
|
"timeout": 300000,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Compile-time assertion that *imlProviderModule implements IProviderModule.
var _ IProviderModule = (*imlProviderModule)(nil)

// imlProviderModule implements IProviderModule for AI providers. Its
// dependencies are injected through the `autowired` struct tags.
type imlProviderModule struct {
	// providerService reads (Get/List) and writes (Save) stored provider records.
	providerService ai.IProviderService `autowired:""`
	// clusterService supplies the gateway client used by syncGateway.
	clusterService cluster.IClusterService `autowired:""`
	// transaction runs multi-step updates inside a single store transaction.
	transaction store.ITransaction `autowired:""`
}
|
|
|
|
func (i *imlProviderModule) SimpleProviders(ctx context.Context) ([]*ai_dto.SimpleProviderItem, error) {
|
|
list, err := i.providerService.List(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
providers := model_runtime.Providers()
|
|
|
|
providerMap := utils.SliceToMap(list, func(e *ai.Provider) string {
|
|
return e.Id
|
|
})
|
|
items := make([]*ai_dto.SimpleProviderItem, 0, len(providers))
|
|
for _, v := range providers {
|
|
item := &ai_dto.SimpleProviderItem{
|
|
Id: v.ID(),
|
|
Name: v.Name(),
|
|
Logo: v.Logo(),
|
|
}
|
|
if _, has := providerMap[v.ID()]; has {
|
|
item.Configured = true
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (i *imlProviderModule) Providers(ctx context.Context) ([]*ai_dto.ProviderItem, error) {
|
|
list, err := i.providerService.List(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
providers := model_runtime.Providers()
|
|
providerMap := utils.SliceToMap(list, func(e *ai.Provider) string {
|
|
return e.Id
|
|
})
|
|
items := make([]*ai_dto.ProviderItem, 0, len(providers))
|
|
for _, v := range providers {
|
|
defaultLLM, has := v.DefaultModel(model_runtime.ModelTypeLLM)
|
|
if !has {
|
|
continue
|
|
}
|
|
item := &ai_dto.ProviderItem{
|
|
Id: v.ID(),
|
|
Name: v.Name(),
|
|
Logo: v.Logo(),
|
|
Recommend: v.Recommend(),
|
|
Sort: v.Sort(),
|
|
}
|
|
if info, has := providerMap[v.ID()]; has {
|
|
item.Configured = true
|
|
item.DefaultLLM = defaultLLM.ID()
|
|
item.DefaultLLMLogo = defaultLLM.Logo()
|
|
item.UpdateTime = info.UpdateAt
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
sort.Slice(items, func(i, j int) bool {
|
|
if items[i].Configured == items[j].Configured && items[i].Configured {
|
|
return items[i].Name < items[j].Name
|
|
}
|
|
if items[i].Sort != items[j].Sort {
|
|
if items[i].Sort == 0 {
|
|
return false
|
|
}
|
|
if items[j].Sort == 0 {
|
|
return true
|
|
}
|
|
return items[i].Sort < items[j].Sort
|
|
}
|
|
return items[i].Name < items[j].Name
|
|
})
|
|
return items, nil
|
|
}
|
|
|
|
func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Provider, error) {
|
|
p, has := model_runtime.GetProvider(id)
|
|
if !has {
|
|
return nil, fmt.Errorf("ai provider not found")
|
|
}
|
|
info, err := i.providerService.Get(ctx, id)
|
|
if err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, err
|
|
}
|
|
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
|
|
if !has {
|
|
return nil, fmt.Errorf("ai provider llm not found")
|
|
}
|
|
return &ai_dto.Provider{
|
|
Id: p.ID(),
|
|
Name: p.Name(),
|
|
Config: p.DefaultConfig(),
|
|
GetAPIKeyUrl: p.HelpUrl(),
|
|
DefaultLLM: defaultLLM.ID(),
|
|
DefaultLLMConfig: defaultLLM.Logo(),
|
|
}, nil
|
|
}
|
|
defaultLLM, has := p.GetModel(info.DefaultLLM)
|
|
if !has {
|
|
model, has := p.DefaultModel(model_runtime.ModelTypeLLM)
|
|
if !has {
|
|
return nil, fmt.Errorf("ai provider llm not found")
|
|
}
|
|
defaultLLM = model
|
|
}
|
|
return &ai_dto.Provider{
|
|
Id: info.Id,
|
|
Name: info.Name,
|
|
Config: p.MaskConfig(info.Config),
|
|
GetAPIKeyUrl: p.HelpUrl(),
|
|
DefaultLLM: defaultLLM.ID(),
|
|
DefaultLLMConfig: defaultLLM.DefaultConfig(),
|
|
}, nil
|
|
}
|
|
|
|
func (i *imlProviderModule) LLMs(ctx context.Context, driver string) ([]*ai_dto.LLMItem, *ai_dto.ProviderItem, error) {
|
|
p, has := model_runtime.GetProvider(driver)
|
|
if !has {
|
|
return nil, nil, fmt.Errorf("ai provider not found")
|
|
}
|
|
|
|
llms, has := p.ModelsByType(model_runtime.ModelTypeLLM)
|
|
if !has {
|
|
return nil, nil, fmt.Errorf("ai provider not found")
|
|
}
|
|
|
|
items := make([]*ai_dto.LLMItem, 0, len(llms))
|
|
for _, v := range llms {
|
|
items = append(items, &ai_dto.LLMItem{
|
|
Id: v.ID(),
|
|
Logo: v.Logo(),
|
|
Config: v.DefaultConfig(),
|
|
Scopes: []string{
|
|
"chat",
|
|
},
|
|
})
|
|
}
|
|
info, err := i.providerService.Get(ctx, driver)
|
|
if err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, nil, err
|
|
}
|
|
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
|
|
if !has {
|
|
return nil, nil, fmt.Errorf("ai provider default llm not found")
|
|
}
|
|
return items, &ai_dto.ProviderItem{
|
|
Id: p.ID(),
|
|
Name: p.Name(),
|
|
DefaultLLM: defaultLLM.ID(),
|
|
DefaultLLMLogo: defaultLLM.Logo(),
|
|
Logo: p.Logo(),
|
|
Configured: false,
|
|
}, nil
|
|
}
|
|
|
|
return items, &ai_dto.ProviderItem{Id: info.Id, Name: info.Name, DefaultLLM: info.DefaultLLM, Logo: p.Logo(), Configured: true}, nil
|
|
}
|
|
|
|
func (i *imlProviderModule) UpdateProviderStatus(ctx context.Context, id string, enable bool) error {
|
|
driver, has := model_runtime.GetProvider(id)
|
|
if !has {
|
|
return fmt.Errorf("ai provider not found")
|
|
}
|
|
info, err := i.providerService.Get(ctx, id)
|
|
if err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
return fmt.Errorf("ai provider not found")
|
|
}
|
|
|
|
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
|
err = i.providerService.Save(txCtx, id, &ai.SetProvider{
|
|
Status: &enable,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if enable {
|
|
cfg := make(map[string]interface{})
|
|
err = json.Unmarshal([]byte(info.Config), &cfg)
|
|
if err != nil {
|
|
log.Errorf("unmarshal ai provider config error,id is %s,err is %v", info.Id, err)
|
|
return err
|
|
}
|
|
cfg["driver"] = info.Id
|
|
|
|
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{{
|
|
BasicItem: &gateway.BasicItem{
|
|
ID: info.Id,
|
|
Description: info.Name,
|
|
Version: info.UpdateAt.Format("20060102150405"),
|
|
MatchLabels: map[string]string{
|
|
"module": "ai-provider",
|
|
},
|
|
},
|
|
Attr: cfg,
|
|
}, newAIUpstream(info.Id, driver.URI()),
|
|
}, enable)
|
|
} else {
|
|
return i.syncGateway(txCtx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
|
{
|
|
BasicItem: &gateway.BasicItem{
|
|
ID: info.Id,
|
|
Resource: info.Id,
|
|
},
|
|
},
|
|
{
|
|
BasicItem: &gateway.BasicItem{
|
|
ID: info.Id,
|
|
Resource: "service",
|
|
},
|
|
},
|
|
}, enable)
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, input *ai_dto.UpdateConfig) error {
|
|
p, has := model_runtime.GetProvider(id)
|
|
if !has {
|
|
return fmt.Errorf("ai provider not found")
|
|
}
|
|
info, err := i.providerService.Get(ctx, id)
|
|
if err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
//defaultLLm, ok := p.DefaultModel(model_runtime.ModelTypeLLM)
|
|
//if !ok {
|
|
// return fmt.Errorf("ai provider default llm not found")
|
|
//}
|
|
info = &ai.Provider{
|
|
Id: id,
|
|
Name: p.Name(),
|
|
DefaultLLM: input.DefaultLLM,
|
|
Config: input.Config,
|
|
}
|
|
}
|
|
err = p.Check(input.Config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
input.Config, err = p.GenConfig(input.Config, info.Config)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return i.transaction.Transaction(ctx, func(txCtx context.Context) error {
|
|
err = i.providerService.Save(ctx, id, &ai.SetProvider{
|
|
Name: &info.Name,
|
|
DefaultLLM: &info.DefaultLLM,
|
|
Config: &input.Config,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cfg := make(map[string]interface{})
|
|
err = json.Unmarshal([]byte(input.Config), &cfg)
|
|
if err != nil {
|
|
log.Errorf("unmarshal ai provider config error,id is %s,err is %v", id, err)
|
|
return err
|
|
}
|
|
|
|
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
|
{
|
|
BasicItem: &gateway.BasicItem{
|
|
ID: id,
|
|
Description: info.Name,
|
|
Resource: id,
|
|
Version: info.UpdateAt.Format("20060102150405"),
|
|
MatchLabels: map[string]string{
|
|
"module": "ai-provider",
|
|
},
|
|
},
|
|
Attr: cfg,
|
|
}, newAIUpstream(id, p.URI()),
|
|
}, true)
|
|
})
|
|
}
|
|
|
|
func (i *imlProviderModule) UpdateProviderDefaultLLM(ctx context.Context, id string, input *ai_dto.UpdateLLM) error {
|
|
_, err := i.providerService.Get(ctx, id)
|
|
if err != nil {
|
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
return fmt.Errorf("ai provider not found")
|
|
}
|
|
return i.providerService.Save(ctx, id, &ai.SetProvider{
|
|
DefaultLLM: &input.LLM,
|
|
})
|
|
}
|
|
|
|
func (i *imlProviderModule) getAiProviders(ctx context.Context, clusterId string) ([]*gateway.DynamicRelease, error) {
|
|
list, err := i.providerService.List(ctx, clusterId)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
providers := make([]*gateway.DynamicRelease, 0, len(list))
|
|
for _, p := range list {
|
|
if !p.Status {
|
|
// 关闭
|
|
continue
|
|
}
|
|
cfg := make(map[string]interface{})
|
|
err = json.Unmarshal([]byte(p.Config), &cfg)
|
|
if err != nil {
|
|
log.Errorf("unmarshal ai provider config error,id is %s,err is %v", p.Id, err)
|
|
continue
|
|
}
|
|
providers = append(providers, &gateway.DynamicRelease{
|
|
BasicItem: &gateway.BasicItem{
|
|
ID: p.Id,
|
|
Description: p.Name,
|
|
Resource: p.Id,
|
|
Version: p.UpdateAt.Format("20060102150405"),
|
|
MatchLabels: map[string]string{
|
|
"module": "ai-provider",
|
|
},
|
|
},
|
|
Attr: cfg,
|
|
})
|
|
}
|
|
return providers, nil
|
|
}
|
|
func (i *imlProviderModule) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
|
|
providers, err := i.getAiProviders(ctx, clusterId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
serviceClient, err := clientDriver.Dynamic("service")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, p := range providers {
|
|
driver, has := model_runtime.GetProvider(p.ID)
|
|
if !has {
|
|
continue
|
|
}
|
|
client, err := clientDriver.Dynamic(p.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = client.Online(ctx, providers...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = serviceClient.Online(ctx, newAIUpstream(p.ID, driver.URI()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *imlProviderModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
|
|
client, err := i.clusterService.GatewayClient(ctx, clusterId)
|
|
if err != nil {
|
|
log.Errorf("get apinto client error: %v", err)
|
|
return nil
|
|
}
|
|
defer func() {
|
|
err := client.Close(ctx)
|
|
if err != nil {
|
|
log.Warn("close apinto client:", err)
|
|
}
|
|
}()
|
|
for _, releaseInfo := range releases {
|
|
dynamicClient, err := client.Dynamic(releaseInfo.Resource)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if online {
|
|
err = dynamicClient.Online(ctx, releaseInfo)
|
|
} else {
|
|
err = dynamicClient.Offline(ctx, releaseInfo)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|