diff --git a/ai-provider/model-runtime/model.go b/ai-provider/model-runtime/model.go index d96008a4..44d29e2f 100644 --- a/ai-provider/model-runtime/model.go +++ b/ai-provider/model-runtime/model.go @@ -43,8 +43,8 @@ func (m *Model) Name() string { } type CustomizeProviderConfig struct { - ApiEndpointUrl string `json:"api_endpoint_url"` - ApiKey string `json:"api_key"` + BaseUrl string `json:"base_url"` + ApiKey string `json:"api_key"` } func (m *Model) ID() string { diff --git a/ai-provider/model-runtime/provider.go b/ai-provider/model-runtime/provider.go index 9a36d562..38023a71 100644 --- a/ai-provider/model-runtime/provider.go +++ b/ai-provider/model-runtime/provider.go @@ -45,6 +45,7 @@ type IProviderInfo interface { DefaultModel(modelType string) (IModel, bool) HelpUrl() string Logo() string + SetURI(IProviderURI) URI() IProviderURI } @@ -54,7 +55,7 @@ func GetCustomizeLogo() string { return string(logo) } -func NewCustomizeProvider(id string, name string, models []IModel, defaultModel string, config string) (IProvider, error) { +func GetCustomizeProviderURI(config string, emptyURI bool) (IProviderURI, error) { var providerCfg CustomizeProviderConfig if strings.TrimSpace(config) != "" { err := json.Unmarshal([]byte(config), &providerCfg) @@ -62,7 +63,22 @@ func NewCustomizeProvider(id string, name string, models []IModel, defaultModel return nil, err } } - uri, err := newProviderUri(providerCfg.ApiEndpointUrl) + if providerCfg.BaseUrl == "" && emptyURI { + return &providerUri{ + scheme: "", + host: "", + path: "", + }, nil + } + uri, err := newProviderUri(providerCfg.BaseUrl) + if err != nil { + return nil, err + } + return uri, nil +} + +func NewCustomizeProvider(id string, name string, models []IModel, defaultModel string, config string) (IProvider, error) { + uri, err := GetCustomizeProviderURI(config, true) if err != nil { return nil, err } @@ -241,6 +257,10 @@ func (p *Provider) URI() IProviderURI { return p.uri } +func (p *Provider) SetURI(uri IProviderURI) { + p.uri = uri +} + func (p *Provider) ID() string { return p.id } diff --git a/module/ai/iml.go b/module/ai/iml.go index 4bc359fb..add079c5 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -199,7 +199,7 @@ func (i *imlProviderModule) AddProvider(ctx context.Context, input *ai_dto.NewPr return nil, fmt.Errorf("provider `%s` duplicate", input.Name) } id := uuid.New().String() - config, defaultLLM := "{\"api_endpoint_url\": \"http://127.0.0.1\", \"api_key\": \"\"}", "" + config, defaultLLM := "{\"base_url\": \"\", \"api_key\": \"\"}", "" if err := i.providerService.Create(ctx, &ai.CreateProvider{ Id: id, Name: input.Name, @@ -639,6 +639,12 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, if err != nil { return err } + // customize provider + if info.Type == 1 { + if uri, uriErr := model_runtime.GetCustomizeProviderURI(input.Config, false); uriErr != nil { + p.SetURI(uri) + } + } /** if *pInfo.Status == 0 { return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{ diff --git a/service/ai/iml.go b/service/ai/iml.go index d94f7ba0..ad99617d 100644 --- a/service/ai/iml.go +++ b/service/ai/iml.go @@ -114,6 +114,7 @@ func createEntityHandler(i *CreateProvider) *ai.Provider { DefaultLLM: i.DefaultLLM, Config: i.Config, Status: i.Status, + Type: i.Type, CreateAt: now, UpdateAt: now, }