Synchronize to gateway when custom model updates

This commit is contained in:
Liujian
2025-03-12 18:33:10 +08:00
parent 358459f37a
commit 9c530ec470
4 changed files with 164 additions and 42 deletions
@@ -87,6 +87,9 @@ provider_credential_schema:
placeholder:
en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1)
model_config:
access_configuration_status: true
access_configuration_demo: "{}"
address: https://bedrock-runtime.amazonaws.com
sort: 4
recommend: true
@@ -32,7 +32,7 @@ provider_credential_schema:
en_US: Enter your API Key
- variable: base_url
label:
en_US: https://router.huggingface.co/hf-inference/v1
en_US: https://api.groq.com/openai/v1
type: text-input
required: false
placeholder:
+4
View File
@@ -70,6 +70,10 @@ var dynamicResourceMap = map[string]Worker{
Profession: ProfessionAIResource,
Driver: "ai-key",
},
"ai-model": {
Profession: ProfessionAIResource,
Driver: "ai-model",
},
}
type Worker struct {
+156 -41
View File
@@ -1,7 +1,18 @@
package ai_model
import (
"context"
"errors"
"fmt"
"slices"
"time"
"gorm.io/gorm"
"github.com/APIParkLab/APIPark/service/cluster"
"github.com/APIParkLab/APIPark/gateway"
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto"
"github.com/APIParkLab/APIPark/service/ai"
@@ -9,8 +20,8 @@ import (
ai_model "github.com/APIParkLab/APIPark/service/ai-model"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"slices"
"github.com/eolinker/eosc/log"
"github.com/eolinker/go-common/store"
)
@@ -22,6 +33,7 @@ type imlProviderModelModule struct {
providerService ai.IProviderService `autowired:""`
aiApiService ai_api.IAPIService `autowired:""`
providerModelService ai_model.IProviderModelService `autowired:""`
clusterService cluster.IClusterService `autowired:""`
transaction store.ITransaction `autowired:""`
}
@@ -50,55 +62,89 @@ func (i *imlProviderModelModule) UpdateProviderModel(ctx *gin.Context, provider
return fmt.Errorf("ai provider not found")
}
// check provider exist
providerInfo, err := i.providerService.Get(ctx, provider)
_, err := i.providerService.Get(ctx, provider)
if err != nil {
return err
}
if providerInfo == nil {
return fmt.Errorf("provider not found")
}
modelInfo, _ := i.providerModelService.Get(ctx, input.Id)
if modelInfo == nil || modelInfo.Provider != provider {
return fmt.Errorf("model not found")
}
if err := i.providerModelService.Save(ctx, input.Id, &ai_model.Model{
AccessConfiguration: &input.AccessConfiguration,
ModelParameters: &input.ModelParameters,
}); err != nil {
return err
}
// update provider model
iModel, _ := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
p.SetModel(input.Id, iModel)
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
if err = i.providerModelService.Save(ctx, input.Id, &ai_model.Model{
AccessConfiguration: &input.AccessConfiguration,
ModelParameters: &input.ModelParameters,
}); err != nil {
return err
}
// update provider model
iModel, err := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
if err != nil {
return err
}
// 判断是否需要发布model
if p.GetModelConfig().AccessConfigurationStatus {
if err := i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
newModel(provider, input.Name, input.AccessConfiguration),
}, true); err != nil {
return err
}
}
p.SetModel(input.Id, iModel)
return nil
})
return nil
}
func (i *imlProviderModelModule) DeleteProviderModel(ctx *gin.Context, provider string, id string) error {
p, has := model_runtime.GetProvider(provider)
// check provider exist
providerInfo, err := i.providerService.Get(ctx, provider)
if err != nil {
return err
if !has {
return fmt.Errorf("ai provider not found")
}
if providerInfo == nil || !has {
return fmt.Errorf("provider not found")
// check provider exist
_, err := i.providerService.Get(ctx, provider)
if err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("provider not found")
}
return err
}
modelInfo, _ := i.providerModelService.Get(ctx, id)
if modelInfo == nil || modelInfo.Provider != provider {
return fmt.Errorf("model not found")
}
// check model in use
countMapByModel, _ := i.aiApiService.CountMapByModel(ctx, "", map[string]interface{}{"model": id})
if countValue, has := countMapByModel[id]; has && countValue > 0 {
return fmt.Errorf("model in use")
}
if err := i.providerModelService.Delete(ctx, id); err != nil {
return err
}
p.RemoveModel(id)
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
// check model in use
count, err := i.aiApiService.CountByModel(ctx, id)
if err != nil {
return err
}
if count > 0 {
return fmt.Errorf("model in use")
}
if err := i.providerModelService.Delete(ctx, id); err != nil {
return err
}
err = i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
{
BasicItem: &gateway.BasicItem{
ID: fmt.Sprintf("%s#%s", provider, modelInfo.Name),
Resource: "ai-model",
},
Attr: nil,
},
}, false)
if err != nil {
return err
}
p.RemoveModel(id)
return nil
})
return nil
}
func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) {
@@ -115,21 +161,90 @@ func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider str
return nil, fmt.Errorf("provider model already exist")
}
id := uuid.New().String()
typeValue := "chat"
if err := i.providerModelService.Save(ctx, id, &ai_model.Model{
Name: &input.Name,
Type: &typeValue,
Provider: &provider,
AccessConfiguration: &input.AccessConfiguration,
ModelParameters: &input.ModelParameters,
}); err != nil {
err := i.transaction.Transaction(ctx, func(ctx context.Context) error {
typeValue := "chat"
err := i.providerModelService.Save(ctx, id, &ai_model.Model{
Name: &input.Name,
Type: &typeValue,
Provider: &provider,
AccessConfiguration: &input.AccessConfiguration,
ModelParameters: &input.ModelParameters,
})
if err != nil {
return err
}
// update provider model
iModel, err := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
if err != nil {
return err
}
// 判断是否需要发布model
if p.GetModelConfig().AccessConfigurationStatus {
if err := i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
newModel(provider, input.Name, input.AccessConfiguration),
}, true); err != nil {
return err
}
}
p.SetModel(id, iModel)
return nil
})
if err != nil {
return nil, err
}
// update provider model
iModel, _ := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
p.SetModel(id, iModel)
return &model_dto.SimpleModel{
Id: id,
Name: input.Name,
}, nil
}
// newModel builds the gateway dynamic release describing one provider model.
//
// The release ID is "<provider>$<model>", the resource is "ai-model", and the
// version is the current time formatted as yyyyMMddHHmmss. The provider name,
// model name and raw access configuration string are passed through as
// attributes.
//
// NOTE(review): DeleteProviderModel in this file builds the ID it takes
// offline with "%s#%s" rather than "%s$%s" — confirm which separator the
// gateway expects so online and offline address the same item.
func newModel(provider string, model string, config string) *gateway.DynamicRelease {
	const resource = "ai-model"

	item := &gateway.BasicItem{
		ID:          fmt.Sprintf("%s$%s", provider, model),
		Description: fmt.Sprintf("auto generated model: %s, provider: %s", model, provider),
		Resource:    resource,
		Version:     time.Now().Format("20060102150405"),
		MatchLabels: map[string]string{"module": resource},
	}

	attrs := map[string]any{
		"provider":      provider,
		"model":         model,
		"access_config": config,
	}

	return &gateway.DynamicRelease{
		BasicItem: item,
		Attr:      attrs,
	}
}
// syncGateway pushes the given releases to the gateway of cluster clusterId.
//
// When online is true each release is brought online; otherwise it is taken
// offline. Releases are processed in order and the first failure from the
// gateway is returned immediately.
//
// If the gateway client itself cannot be obtained, the error is only logged
// and nil is returned, so the caller proceeds as if the sync had succeeded.
func (i *imlProviderModelModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
	gatewayClient, clientErr := i.clusterService.GatewayClient(ctx, clusterId)
	if clientErr != nil {
		// Best effort: an unreachable gateway does not fail the caller.
		log.Errorf("get apinto client error: %v", clientErr)
		return nil
	}
	defer func() {
		if closeErr := gatewayClient.Close(ctx); closeErr != nil {
			log.Warn("close apinto client:", closeErr)
		}
	}()

	for idx := range releases {
		release := releases[idx]

		// Each release is routed to the dynamic client of its own resource.
		resourceClient, dynErr := gatewayClient.Dynamic(release.Resource)
		if dynErr != nil {
			return dynErr
		}

		var pushErr error
		switch online {
		case true:
			pushErr = resourceClient.Online(ctx, release)
		default:
			pushErr = resourceClient.Offline(ctx, release)
		}
		if pushErr != nil {
			return pushErr
		}
	}
	return nil
}