Add support for creating online models and integrating custom model providers

This commit is contained in:
sunanzhi
2025-03-06 17:42:17 +08:00
parent 7a84c5aec3
commit b9f6abc9b3
53 changed files with 1796 additions and 33 deletions
+4
View File
@@ -14,3 +14,7 @@ type UpdateConfig struct {
type Sort struct {
Providers []string `json:"providers"`
}
type NewProvider struct {
Name string `json:"name"`
}
+24 -6
View File
@@ -4,6 +4,11 @@ import (
"github.com/eolinker/go-common/auto"
)
type SimpleModel struct {
Id string `json:"id"`
Name string `json:"name"`
}
type SimpleProvider struct {
Id string `json:"id"`
Name string `json:"name"`
@@ -20,8 +25,14 @@ type Provider struct {
DefaultLLM string `json:"default_llm"`
DefaultLLMConfig string `json:"-"`
//Priority int `json:"priority"`
Status ProviderStatus `json:"status"`
Configured bool `json:"configured"`
Status ProviderStatus `json:"status"`
Configured bool `json:"configured"`
ModelConfig ModelConfig `json:"model_config"`
}
type ModelConfig struct {
AccessConfigurationStatus bool `json:"access_configuration_status"`
AccessConfigurationDemo string `json:"access_configuration_demo"`
}
type ConfiguredProviderItem struct {
@@ -32,6 +43,7 @@ type ConfiguredProviderItem struct {
Status ProviderStatus `json:"status"`
APICount int64 `json:"api_count"`
KeyCount int64 `json:"key_count"`
ModelCount int64 `json:"model_count"`
CanDelete bool `json:"can_delete"`
}
@@ -48,6 +60,7 @@ type ProviderItem struct {
Logo string `json:"logo"`
DefaultLLM string `json:"default_llm"`
Sort int `json:"-"`
Type int `json:"type"` // 0:default 1:customize
}
type SimpleProviderItem struct {
@@ -69,10 +82,15 @@ type BackupProvider struct {
}
type LLMItem struct {
Id string `json:"id"`
Logo string `json:"logo"`
Config string `json:"config"`
Scopes []string `json:"scopes"`
Id string `json:"id"`
Logo string `json:"logo"`
Config string `json:"config"`
AccessConfiguration string `json:"access_configuration"`
ModelParameters string `json:"model_parameters"`
Scopes []string `json:"scopes"`
Type string `json:"type"`
IsSystem bool `json:"is_system"`
ApiCount int64 `json:"api_count"`
}
type APIItem struct {
+88 -14
View File
@@ -4,12 +4,12 @@ import (
"context"
"errors"
"fmt"
ai_model "github.com/APIParkLab/APIPark/service/ai-model"
"github.com/google/uuid"
"net/http"
"sort"
"time"
"github.com/google/uuid"
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
"github.com/eolinker/go-common/register"
@@ -63,12 +63,13 @@ func newKey(key *ai_key.Key) *gateway.DynamicRelease {
var _ IProviderModule = (*imlProviderModule)(nil)
type imlProviderModule struct {
providerService ai.IProviderService `autowired:""`
clusterService cluster.IClusterService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
aiKeyService ai_key.IKeyService `autowired:""`
aiBalanceService ai_balance.IBalanceService `autowired:""`
transaction store.ITransaction `autowired:""`
providerService ai.IProviderService `autowired:""`
providerModelService ai_model.IProviderModelService `autowired:""`
clusterService cluster.IClusterService `autowired:""`
aiAPIService ai_api.IAPIService `autowired:""`
aiKeyService ai_key.IKeyService `autowired:""`
aiBalanceService ai_balance.IBalanceService `autowired:""`
transaction store.ITransaction `autowired:""`
}
func (i *imlProviderModule) OnInit() {
@@ -79,6 +80,33 @@ func (i *imlProviderModule) OnInit() {
if err != nil {
return
}
// register provider
for _, p := range list {
// get customize models
models, _ := i.providerModelService.Search(ctx, "", map[string]interface{}{"provider": p.Id}, "update_at desc")
iModels := make([]model_runtime.IModel, 0, len(models))
if models != nil {
for _, model := range models {
// parse access_config & model_parameters
iModel, _ := model_runtime.NewCustomizeModel(model.Id, model.Name, model_runtime.GetCustomizeLogo(), model.AccessConfiguration, model.ModelParameters)
iModels = append(iModels, iModel)
}
}
// default provider
if p.Type == 0 {
runtimeProvider, _ := model_runtime.GetProvider(p.Id)
for _, tmpIModel := range iModels {
tmpIModel.SetLogo(runtimeProvider.Logo())
if p.DefaultLLM == tmpIModel.ID() {
runtimeProvider.SetDefaultModel(model_runtime.ModelTypeLLM, tmpIModel)
}
runtimeProvider.SetModel(tmpIModel.ID(), tmpIModel)
}
} else {
provider, _ := model_runtime.NewCustomizeProvider(p.Id, p.Name, iModels, p.DefaultLLM, p.Config)
model_runtime.Register(p.Id, provider)
}
}
i.transaction.Transaction(ctx, func(ctx context.Context) error {
for _, l := range list {
if l.Priority < 1 {
@@ -145,6 +173,8 @@ func (i *imlProviderModule) Delete(ctx context.Context, id string) error {
if err != nil {
return err
}
// delete register provider
model_runtime.Remove(id)
releases := make([]*gateway.DynamicRelease, 0, len(keys))
for _, key := range keys {
releases = append(releases, newKey(key))
@@ -164,6 +194,30 @@ func (i *imlProviderModule) Delete(ctx context.Context, id string) error {
})
}
func (i *imlProviderModule) AddProvider(ctx context.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error) {
if has := i.providerService.CheckNameDuplicate(ctx, input.Name); has {
return nil, fmt.Errorf("provider `%s` duplicate", input.Name)
}
id := uuid.New().String()
config, defaultLLM := "{\"api_endpoint_url\": \"http://127.0.0.1\", \"api_key\": \"\"}", ""
typeValue := 1
if err := i.providerService.Save(ctx, id, &ai.SetProvider{
Name: &input.Name,
DefaultLLM: &defaultLLM,
Config: &config,
Type: &typeValue,
}); err != nil {
return nil, err
}
// register provider
iProvider, _ := model_runtime.NewCustomizeProvider(id, input.Name, []model_runtime.IModel{}, "", "")
model_runtime.Register(id, iProvider)
return &ai_dto.SimpleProvider{
Id: id,
Name: input.Name,
}, nil
}
func (i *imlProviderModule) SimpleProvider(ctx context.Context, id string) (*ai_dto.SimpleProvider, error) {
p, has := model_runtime.GetProvider(id)
if !has {
@@ -231,6 +285,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context, keyword str
APICount: apiCount,
KeyCount: keyMap[l.Id],
CanDelete: apiCount < 1,
ModelCount: int64(len(p.Models())),
})
}
@@ -394,8 +449,9 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
}
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
if !has {
return nil, fmt.Errorf("ai provider llm not found")
defaultLLM, _ = model_runtime.NewCustomizeModel("", "", "", "", "")
}
providerModelConfig := p.GetModelConfig()
return &ai_dto.Provider{
Id: p.ID(),
Name: p.Name(),
@@ -405,13 +461,17 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
DefaultLLMConfig: defaultLLM.Logo(),
Status: ai_dto.ProviderDisabled,
//Priority: maxPriority,
ModelConfig: ai_dto.ModelConfig{
AccessConfigurationStatus: providerModelConfig.AccessConfigurationStatus,
AccessConfigurationDemo: providerModelConfig.AccessConfigurationDemo,
},
}, nil
}
defaultLLM, has := p.GetModel(info.DefaultLLM)
if !has {
model, has := p.DefaultModel(model_runtime.ModelTypeLLM)
if !has {
return nil, fmt.Errorf("ai provider llm not found")
defaultLLM, _ = model_runtime.NewCustomizeModel("", "", "", "", "")
}
defaultLLM = model
}
@@ -426,6 +486,10 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
//Priority: info.Priority,
Status: ai_dto.ToProviderStatus(info.Status),
Configured: true,
ModelConfig: ai_dto.ModelConfig{
AccessConfigurationStatus: false,
AccessConfigurationDemo: "",
},
}, nil
}
@@ -439,16 +503,21 @@ func (i *imlProviderModule) LLMs(ctx context.Context, driver string) ([]*ai_dto.
if !has {
return nil, nil, fmt.Errorf("ai provider not found")
}
modelApiCountMap, _ := i.aiAPIService.CountMapByModel(ctx, "", map[string]interface{}{"provider": driver})
items := make([]*ai_dto.LLMItem, 0, len(llms))
for _, v := range llms {
items = append(items, &ai_dto.LLMItem{
Id: v.ID(),
Logo: v.Logo(),
Config: v.DefaultConfig(),
Id: v.ID(),
Logo: v.Logo(),
Config: v.DefaultConfig(),
AccessConfiguration: v.AccessConfiguration(),
ModelParameters: v.ModelParameters(),
Scopes: []string{
"chat",
},
Type: "chat",
IsSystem: v.Source() != "customize",
ApiCount: modelApiCountMap[v.ID()],
})
}
info, err := i.providerService.Get(ctx, driver)
@@ -523,6 +592,11 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
if err != nil {
return err
}
if input.DefaultLLM != "" {
if defaultLLM, has := p.GetModel(input.DefaultLLM); has {
p.SetDefaultModel(model_runtime.ModelTypeLLM, defaultLLM)
}
}
status := 0
if input.Enable != nil && *input.Enable {
status = 1
+1
View File
@@ -19,6 +19,7 @@ type IProviderModule interface {
LLMs(ctx context.Context, driver string) ([]*ai_dto.LLMItem, *ai_dto.ProviderItem, error)
UpdateProviderConfig(ctx context.Context, id string, input *ai_dto.UpdateConfig) error
Delete(ctx context.Context, id string) error
AddProvider(ctx context.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error)
}
type IAIAPIModule interface {