diff --git a/ai-provider/model-runtime/model-providers/bedrock/bedrock.yaml b/ai-provider/model-runtime/model-providers/bedrock/bedrock.yaml
index b9f165c2..04f76d30 100644
--- a/ai-provider/model-runtime/model-providers/bedrock/bedrock.yaml
+++ b/ai-provider/model-runtime/model-providers/bedrock/bedrock.yaml
@@ -87,6 +87,9 @@ provider_credential_schema:
       placeholder:
         en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
         zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1)
+model_config:
+  access_configuration_status: true
+  access_configuration_demo: "{}"
 address: https://bedrock-runtime.amazonaws.com
 sort: 4
 recommend: true
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/groq/groq.yaml b/ai-provider/model-runtime/model-providers/groq/groq.yaml
index 7fc5d275..697a4e10 100644
--- a/ai-provider/model-runtime/model-providers/groq/groq.yaml
+++ b/ai-provider/model-runtime/model-providers/groq/groq.yaml
@@ -32,7 +32,7 @@ provider_credential_schema:
         en_US: Enter your API Key
     - variable: base_url
       label:
-        en_US: https://router.huggingface.co/hf-inference/v1
+        en_US: https://api.groq.com/openai/v1
       type: text-input
       required: false
       placeholder:
diff --git a/gateway/profession.go b/gateway/profession.go
index 16956421..0be26331 100644
--- a/gateway/profession.go
+++ b/gateway/profession.go
@@ -70,6 +70,10 @@ var dynamicResourceMap = map[string]Worker{
 		Profession: ProfessionAIResource,
 		Driver:     "ai-key",
 	},
+	"ai-model": {
+		Profession: ProfessionAIResource,
+		Driver:     "ai-model",
+	},
 }
 
 type Worker struct {
diff --git a/module/ai-model/iml.go b/module/ai-model/iml.go
index b819f1e9..26e1410e 100644
--- a/module/ai-model/iml.go
+++ b/module/ai-model/iml.go
@@ -1,7 +1,18 @@
 package ai_model
 
 import (
+	"context"
+	"errors"
 	"fmt"
+	"slices"
+	"time"
+
+	"gorm.io/gorm"
+
+	"github.com/APIParkLab/APIPark/service/cluster"
+
+	"github.com/APIParkLab/APIPark/gateway"
+
 	model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
 	model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto"
 	"github.com/APIParkLab/APIPark/service/ai"
@@ -9,8 +20,8 @@ import (
 	ai_model "github.com/APIParkLab/APIPark/service/ai-model"
 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
-	"slices"
 
+	"github.com/eolinker/eosc/log"
 	"github.com/eolinker/go-common/store"
 )
 
@@ -22,6 +33,7 @@ type imlProviderModelModule struct {
 	providerService      ai.IProviderService            `autowired:""`
 	aiApiService         ai_api.IAPIService             `autowired:""`
 	providerModelService ai_model.IProviderModelService `autowired:""`
+	clusterService       cluster.IClusterService        `autowired:""`
 	transaction          store.ITransaction             `autowired:""`
 }
 
@@ -50,55 +62,87 @@ func (i *imlProviderModelModule) UpdateProviderModel(ctx *gin.Context, provider
 		return fmt.Errorf("ai provider not found")
 	}
 	// check provider exist
-	providerInfo, err := i.providerService.Get(ctx, provider)
+	_, err := i.providerService.Get(ctx, provider)
 	if err != nil {
 		return err
 	}
-	if providerInfo == nil {
-		return fmt.Errorf("provider not found")
-	}
 	modelInfo, _ := i.providerModelService.Get(ctx, input.Id)
 	if modelInfo == nil || modelInfo.Provider != provider {
 		return fmt.Errorf("model not found")
 	}
-	if err := i.providerModelService.Save(ctx, input.Id, &ai_model.Model{
-		AccessConfiguration: &input.AccessConfiguration,
-		ModelParameters:     &input.ModelParameters,
-	}); err != nil {
-		return err
-	}
-	// update provider model
-	iModel, _ := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
-	p.SetModel(input.Id, iModel)
-
-	return nil
+	return i.transaction.Transaction(ctx, func(ctx context.Context) error {
+		if err = i.providerModelService.Save(ctx, input.Id, &ai_model.Model{
+			AccessConfiguration: &input.AccessConfiguration,
+			ModelParameters:     &input.ModelParameters,
+		}); err != nil {
+			return err
+		}
+
+		// update provider model
+		iModel, err := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
+		if err != nil {
+			return err
+		}
+		// publish the model to the gateway only when the provider enables access configuration
+		if p.GetModelConfig().AccessConfigurationStatus {
+			if err := i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
+				newModel(provider, input.Name, input.AccessConfiguration),
+			}, true); err != nil {
+				return err
+			}
+		}
+
+		p.SetModel(input.Id, iModel)
+		return nil
+	})
 }
 
 func (i *imlProviderModelModule) DeleteProviderModel(ctx *gin.Context, provider string, id string) error {
 	p, has := model_runtime.GetProvider(provider)
-	// check provider exist
-	providerInfo, err := i.providerService.Get(ctx, provider)
-	if err != nil {
-		return err
+	if !has {
+		return fmt.Errorf("ai provider not found")
 	}
-	if providerInfo == nil || !has {
-		return fmt.Errorf("provider not found")
+	// check provider exist
+	_, err := i.providerService.Get(ctx, provider)
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return fmt.Errorf("provider not found")
+		}
+		return err
 	}
 	modelInfo, _ := i.providerModelService.Get(ctx, id)
 	if modelInfo == nil || modelInfo.Provider != provider {
 		return fmt.Errorf("model not found")
 	}
-	// check model in use
-	countMapByModel, _ := i.aiApiService.CountMapByModel(ctx, "", map[string]interface{}{"model": id})
-	if countValue, has := countMapByModel[id]; has && countValue > 0 {
-		return fmt.Errorf("model in use")
-	}
-	if err := i.providerModelService.Delete(ctx, id); err != nil {
-		return err
-	}
-	p.RemoveModel(id)
+	return i.transaction.Transaction(ctx, func(ctx context.Context) error {
+		// check model in use
+		count, err := i.aiApiService.CountByModel(ctx, id)
+		if err != nil {
+			return err
+		}
+		if count > 0 {
+			return fmt.Errorf("model in use")
+		}
+		if err := i.providerModelService.Delete(ctx, id); err != nil {
+			return err
+		}
+		err = i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
+			{
+				BasicItem: &gateway.BasicItem{
+					ID:       fmt.Sprintf("%s$%s", provider, modelInfo.Name), // same ID format as newModel
+					Resource: "ai-model",
+				},
+				Attr: nil,
+			},
+		}, false)
+		if err != nil {
+			return err
+		}
+
+		p.RemoveModel(id)
+		return nil
+	})
 
-	return nil
 }
 
 func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) {
@@ -115,21 +159,98 @@ func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider str
 		return nil, fmt.Errorf("provider model already exist")
 	}
 	id := uuid.New().String()
-	typeValue := "chat"
-	if err := i.providerModelService.Save(ctx, id, &ai_model.Model{
-		Name:                &input.Name,
-		Type:                &typeValue,
-		Provider:            &provider,
-		AccessConfiguration: &input.AccessConfiguration,
-		ModelParameters:     &input.ModelParameters,
-	}); err != nil {
+	err := i.transaction.Transaction(ctx, func(ctx context.Context) error {
+		typeValue := "chat"
+		err := i.providerModelService.Save(ctx, id, &ai_model.Model{
+			Name:                &input.Name,
+			Type:                &typeValue,
+			Provider:            &provider,
+			AccessConfiguration: &input.AccessConfiguration,
+			ModelParameters:     &input.ModelParameters,
+		})
+		if err != nil {
+			return err
+		}
+		// update provider model
+		iModel, err := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
+		if err != nil {
+			return err
+		}
+		// publish the model to the gateway only when the provider enables access configuration
+		if p.GetModelConfig().AccessConfigurationStatus {
+			if err := i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
+				newModel(provider, input.Name, input.AccessConfiguration),
+			}, true); err != nil {
+				return err
+			}
+		}
+
+		p.SetModel(id, iModel)
+		return nil
+	})
+	if err != nil {
 		return nil, err
 	}
-	// update provider model
-	iModel, _ := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
-	p.SetModel(id, iModel)
 	return &model_dto.SimpleModel{
 		Id:   id,
 		Name: input.Name,
 	}, nil
 }
+
+// newModel builds the dynamic release for a provider model that is pushed to
+// the gateway. The release ID has the form "<provider>$<model>"; the same ID
+// must be used when taking the model offline.
+func newModel(provider string, model string, config string) *gateway.DynamicRelease {
+	return &gateway.DynamicRelease{
+		BasicItem: &gateway.BasicItem{
+			ID:          fmt.Sprintf("%s$%s", provider, model),
+			Description: fmt.Sprintf("auto generated model: %s, provider: %s", model, provider),
+			Resource:    "ai-model",
+			Version:     time.Now().Format("20060102150405"),
+			MatchLabels: map[string]string{
+				"module": "ai-model",
+			},
+		},
+		Attr: map[string]interface{}{
+			"provider":      provider,
+			"model":         model,
+			"access_config": config,
+		},
+	}
+}
+
+// syncGateway brings the given releases online (online == true) or takes them
+// offline (online == false) on the gateway of the cluster clusterId. It stops
+// at and returns the first error reported by the gateway's dynamic client.
+func (i *imlProviderModelModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
+	client, err := i.clusterService.GatewayClient(ctx, clusterId)
+	if err != nil {
+		// NOTE(review): a failure to obtain the gateway client is only logged and
+		// reported as success, so callers proceed without syncing - presumably to
+		// tolerate deployments without a reachable gateway; confirm.
+		log.Errorf("get apinto client error: %v", err)
+		return nil
+	}
+	defer func() {
+		err := client.Close(ctx)
+		if err != nil {
+			log.Warn("close apinto client:", err)
+		}
+	}()
+	for _, releaseInfo := range releases {
+		dynamicClient, err := client.Dynamic(releaseInfo.Resource)
+		if err != nil {
+			return err
+		}
+		if online {
+			err = dynamicClient.Online(ctx, releaseInfo)
+		} else {
+			err = dynamicClient.Offline(ctx, releaseInfo)
+		}
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}