mirror of
https://github.com/APIParkLab/APIPark.git
synced 2026-06-04 10:13:53 +08:00
Synchronize to gateway when custom model updates
This commit is contained in:
@@ -87,6 +87,9 @@ provider_credential_schema:
|
||||
placeholder:
|
||||
en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
|
||||
zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1)
|
||||
model_config:
|
||||
access_configuration_status: true
|
||||
access_configuration_demo: "{}"
|
||||
address: https://bedrock-runtime.amazonaws.com
|
||||
sort: 4
|
||||
recommend: true
|
||||
@@ -32,7 +32,7 @@ provider_credential_schema:
|
||||
en_US: Enter your API Key
|
||||
- variable: base_url
|
||||
label:
|
||||
en_US: https://router.huggingface.co/hf-inference/v1
|
||||
en_US: https://api.groq.com/openai/v1
|
||||
type: text-input
|
||||
required: false
|
||||
placeholder:
|
||||
|
||||
@@ -70,6 +70,10 @@ var dynamicResourceMap = map[string]Worker{
|
||||
Profession: ProfessionAIResource,
|
||||
Driver: "ai-key",
|
||||
},
|
||||
"ai-model": {
|
||||
Profession: ProfessionAIResource,
|
||||
Driver: "ai-model",
|
||||
},
|
||||
}
|
||||
|
||||
type Worker struct {
|
||||
|
||||
+156
-41
@@ -1,7 +1,18 @@
|
||||
package ai_model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/APIParkLab/APIPark/service/cluster"
|
||||
|
||||
"github.com/APIParkLab/APIPark/gateway"
|
||||
|
||||
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
|
||||
model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto"
|
||||
"github.com/APIParkLab/APIPark/service/ai"
|
||||
@@ -9,8 +20,8 @@ import (
|
||||
ai_model "github.com/APIParkLab/APIPark/service/ai-model"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"slices"
|
||||
|
||||
"github.com/eolinker/eosc/log"
|
||||
"github.com/eolinker/go-common/store"
|
||||
)
|
||||
|
||||
@@ -22,6 +33,7 @@ type imlProviderModelModule struct {
|
||||
providerService ai.IProviderService `autowired:""`
|
||||
aiApiService ai_api.IAPIService `autowired:""`
|
||||
providerModelService ai_model.IProviderModelService `autowired:""`
|
||||
clusterService cluster.IClusterService `autowired:""`
|
||||
transaction store.ITransaction `autowired:""`
|
||||
}
|
||||
|
||||
@@ -50,55 +62,89 @@ func (i *imlProviderModelModule) UpdateProviderModel(ctx *gin.Context, provider
|
||||
return fmt.Errorf("ai provider not found")
|
||||
}
|
||||
// check provider exist
|
||||
providerInfo, err := i.providerService.Get(ctx, provider)
|
||||
_, err := i.providerService.Get(ctx, provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if providerInfo == nil {
|
||||
return fmt.Errorf("provider not found")
|
||||
}
|
||||
modelInfo, _ := i.providerModelService.Get(ctx, input.Id)
|
||||
if modelInfo == nil || modelInfo.Provider != provider {
|
||||
return fmt.Errorf("model not found")
|
||||
}
|
||||
if err := i.providerModelService.Save(ctx, input.Id, &ai_model.Model{
|
||||
AccessConfiguration: &input.AccessConfiguration,
|
||||
ModelParameters: &input.ModelParameters,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
// update provider model
|
||||
iModel, _ := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
|
||||
p.SetModel(input.Id, iModel)
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
if err = i.providerModelService.Save(ctx, input.Id, &ai_model.Model{
|
||||
AccessConfiguration: &input.AccessConfiguration,
|
||||
ModelParameters: &input.ModelParameters,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update provider model
|
||||
iModel, err := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 判断是否需要发布model
|
||||
if p.GetModelConfig().AccessConfigurationStatus {
|
||||
if err := i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
newModel(provider, input.Name, input.AccessConfiguration),
|
||||
}, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p.SetModel(input.Id, iModel)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imlProviderModelModule) DeleteProviderModel(ctx *gin.Context, provider string, id string) error {
|
||||
p, has := model_runtime.GetProvider(provider)
|
||||
// check provider exist
|
||||
providerInfo, err := i.providerService.Get(ctx, provider)
|
||||
if err != nil {
|
||||
return err
|
||||
if !has {
|
||||
return fmt.Errorf("ai provider not found")
|
||||
}
|
||||
if providerInfo == nil || !has {
|
||||
return fmt.Errorf("provider not found")
|
||||
// check provider exist
|
||||
_, err := i.providerService.Get(ctx, provider)
|
||||
if err != nil {
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return fmt.Errorf("provider not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
modelInfo, _ := i.providerModelService.Get(ctx, id)
|
||||
if modelInfo == nil || modelInfo.Provider != provider {
|
||||
return fmt.Errorf("model not found")
|
||||
}
|
||||
// check model in use
|
||||
countMapByModel, _ := i.aiApiService.CountMapByModel(ctx, "", map[string]interface{}{"model": id})
|
||||
if countValue, has := countMapByModel[id]; has && countValue > 0 {
|
||||
return fmt.Errorf("model in use")
|
||||
}
|
||||
if err := i.providerModelService.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
p.RemoveModel(id)
|
||||
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
// check model in use
|
||||
count, err := i.aiApiService.CountByModel(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if count > 0 {
|
||||
return fmt.Errorf("model in use")
|
||||
}
|
||||
if err := i.providerModelService.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
err = i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: fmt.Sprintf("%s#%s", provider, modelInfo.Name),
|
||||
Resource: "ai-model",
|
||||
},
|
||||
Attr: nil,
|
||||
},
|
||||
}, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.RemoveModel(id)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) {
|
||||
@@ -115,21 +161,90 @@ func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider str
|
||||
return nil, fmt.Errorf("provider model already exist")
|
||||
}
|
||||
id := uuid.New().String()
|
||||
typeValue := "chat"
|
||||
if err := i.providerModelService.Save(ctx, id, &ai_model.Model{
|
||||
Name: &input.Name,
|
||||
Type: &typeValue,
|
||||
Provider: &provider,
|
||||
AccessConfiguration: &input.AccessConfiguration,
|
||||
ModelParameters: &input.ModelParameters,
|
||||
}); err != nil {
|
||||
err := i.transaction.Transaction(ctx, func(ctx context.Context) error {
|
||||
typeValue := "chat"
|
||||
err := i.providerModelService.Save(ctx, id, &ai_model.Model{
|
||||
Name: &input.Name,
|
||||
Type: &typeValue,
|
||||
Provider: &provider,
|
||||
AccessConfiguration: &input.AccessConfiguration,
|
||||
ModelParameters: &input.ModelParameters,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// update provider model
|
||||
iModel, err := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 判断是否需要发布model
|
||||
if p.GetModelConfig().AccessConfigurationStatus {
|
||||
if err := i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
|
||||
newModel(provider, input.Name, input.AccessConfiguration),
|
||||
}, true); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p.SetModel(id, iModel)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// update provider model
|
||||
iModel, _ := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
|
||||
p.SetModel(id, iModel)
|
||||
return &model_dto.SimpleModel{
|
||||
Id: id,
|
||||
Name: input.Name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newModel(provider string, model string, config string) *gateway.DynamicRelease {
|
||||
|
||||
return &gateway.DynamicRelease{
|
||||
BasicItem: &gateway.BasicItem{
|
||||
ID: fmt.Sprintf("%s$%s", provider, model),
|
||||
Description: fmt.Sprintf("auto generated model: %s, provider: %s", model, provider),
|
||||
Resource: "ai-model",
|
||||
Version: time.Now().Format("20060102150405"),
|
||||
MatchLabels: map[string]string{
|
||||
"module": "ai-model",
|
||||
},
|
||||
},
|
||||
Attr: map[string]interface{}{
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"access_config": config,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (i *imlProviderModelModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
|
||||
client, err := i.clusterService.GatewayClient(ctx, clusterId)
|
||||
if err != nil {
|
||||
log.Errorf("get apinto client error: %v", err)
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
err := client.Close(ctx)
|
||||
if err != nil {
|
||||
log.Warn("close apinto client:", err)
|
||||
}
|
||||
}()
|
||||
for _, releaseInfo := range releases {
|
||||
dynamicClient, err := client.Dynamic(releaseInfo.Resource)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if online {
|
||||
err = dynamicClient.Online(ctx, releaseInfo)
|
||||
} else {
|
||||
err = dynamicClient.Offline(ctx, releaseInfo)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user