From 3e353f4eff55f549a9a8ba3187cca033629c1e49 Mon Sep 17 00:00:00 2001 From: Liujian <824010343@qq.com> Date: Wed, 16 Oct 2024 16:17:21 +0800 Subject: [PATCH] finish nvidia --- .../model-runtime/model-providers/nvidia/nvidia.yaml | 2 +- controller/service/iml.go | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ai-provider/model-runtime/model-providers/nvidia/nvidia.yaml b/ai-provider/model-runtime/model-providers/nvidia/nvidia.yaml index 2cfdf42a..e2fd93c7 100644 --- a/ai-provider/model-runtime/model-providers/nvidia/nvidia.yaml +++ b/ai-provider/model-runtime/model-providers/nvidia/nvidia.yaml @@ -31,4 +31,4 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key -address: https://api.openai.com \ No newline at end of file +address: https://integrate.api.nvidia.com \ No newline at end of file diff --git a/controller/service/iml.go b/controller/service/iml.go index 793808e7..01d207ba 100644 --- a/controller/service/iml.go +++ b/controller/service/iml.go @@ -106,17 +106,21 @@ func (i *imlServiceController) CreateAIService(ctx *gin.Context, teamID string, input.Prefix = input.Id[:8] } } + pv, err := i.providerModule.Provider(ctx, *input.Provider) + if err != nil { + return nil, err + } p, has := model_runtime.GetProvider(*input.Provider) if !has { return nil, fmt.Errorf("provider not found") } - m, has := p.DefaultModel(model_runtime.ModelTypeLLM) + m, has := p.GetModel(pv.DefaultLLM) if !has { - return nil, fmt.Errorf("provider default llm not found") + return nil, fmt.Errorf("model %s not found", pv.DefaultLLM) } var info *service_dto.Service - err := i.transaction.Transaction(ctx, func(txCtx context.Context) error { + err = i.transaction.Transaction(ctx, func(txCtx context.Context) error { var err error info, err = i.module.Create(ctx, teamID, input) if err != nil {