diff --git a/.gitignore b/.gitignore
index cb2bf660..9708fbd2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,6 @@
/apipark
.gitlab-ci.yml
/.vscode/
+.air.toml
+/tmp/
+/work
\ No newline at end of file
diff --git a/ai-provider/model-runtime/entity/provider.go b/ai-provider/model-runtime/entity/provider.go
index d9014560..f3ee1d6e 100644
--- a/ai-provider/model-runtime/entity/provider.go
+++ b/ai-provider/model-runtime/entity/provider.go
@@ -25,6 +25,11 @@ type Provider struct {
Address string `json:"address" yaml:"address"`
Recommend bool `json:"recommend" yaml:"recommend"`
Sort int `json:"sort" yaml:"sort"`
+ ModelConfig ModelConfig `json:"model_config" yaml:"model_config"`
+}
+type ModelConfig struct {
+ AccessConfigurationStatus bool `json:"access_configuration_status" yaml:"access_configuration_status"`
+ AccessConfigurationDemo string `json:"access_configuration_demo" yaml:"access_configuration_demo"`
}
type ProviderCredentialSchema struct {
diff --git a/ai-provider/model-runtime/loader.go b/ai-provider/model-runtime/loader.go
index 5320ff0e..a6eefb25 100644
--- a/ai-provider/model-runtime/loader.go
+++ b/ai-provider/model-runtime/loader.go
@@ -6,8 +6,6 @@ import (
"fmt"
"strings"
- "github.com/APIParkLab/APIPark/gateway"
-
"github.com/eolinker/eosc"
)
@@ -36,7 +34,10 @@ func (c *Config) Check(cfg string) error {
if err != nil {
return err
}
- return c.validator.Valid(data)
+ if c.validator != nil {
+ return c.validator.Valid(data)
+ }
+ return nil
}
func (c *Config) GenConfig(target string, origin string) (string, error) {
@@ -83,6 +84,9 @@ func Load() error {
continue
}
name := fmt.Sprintf("model-providers/%s", file.Name())
+ if file.Name() == "customize" {
+ continue
+ }
err = LoadProvider(name)
if err != nil {
return err
@@ -119,10 +123,10 @@ func LoadProvider(name string) error {
if err != nil {
return err
}
- gateway.RegisterDynamicResourceDriver(provider.ID(), gateway.Worker{
- Profession: gateway.ProfessionAIProvider,
- Driver: provider.ID(),
- })
+ //gateway.RegisterDynamicResourceDriver(provider.ID(), gateway.Worker{
+ // Profession: gateway.ProfessionAIProvider,
+ // Driver: provider.ID(),
+ //})
Register(provider.ID(), provider)
return nil
}
diff --git a/ai-provider/model-runtime/manager.go b/ai-provider/model-runtime/manager.go
index dbefea3e..169ff0fc 100644
--- a/ai-provider/model-runtime/manager.go
+++ b/ai-provider/model-runtime/manager.go
@@ -43,6 +43,10 @@ func Register(name string, driver IProvider) {
defaultManager.Set(name, driver)
}
+func Remove(name string) {
+ defaultManager.Del(name)
+}
+
func GetProvider(name string) (IProvider, bool) {
return defaultManager.Get(name)
}
diff --git a/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_en.svg
new file mode 100644
index 00000000..2a7a4f4f
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_en.svg
@@ -0,0 +1,2 @@
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_zh.svg b/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_zh.svg
new file mode 100644
index 00000000..9f650f2b
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_zh.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/bailian/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/bailian/assets/icon_s_en.svg
new file mode 100644
index 00000000..851ba565
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/bailian/assets/icon_s_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/bailian/bailian.yaml b/ai-provider/model-runtime/model-providers/bailian/bailian.yaml
new file mode 100644
index 00000000..7456a7de
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/bailian/bailian.yaml
@@ -0,0 +1,32 @@
+provider: bailian
+label:
+ zh_Hans: 阿里云百炼
+ en_US: bailian
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ zh_Hans: icon_l_zh.svg
+ en_US: icon_l_en.svg
+background: "#EFF1FE"
+help:
+ title:
+ en_US: Get your API key from AliCloud
+ zh_Hans: 从阿里云百炼获取 API Key
+ url:
+ en_US: https://bailian.console.aliyun.com/?apiKey=1#/api-key
+supported_model_types:
+ - llm
+configurate_methods:
+ - predefined-model
+ - customizable-model
+provider_credential_schema:
+ credential_form_schemas:
+ - variable: dashscope_api_key
+ label:
+ en_US: API Key
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的 API Key
+ en_US: Enter your API Key
+address: https://dashscope.aliyuncs.com
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/customize/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/customize/assets/icon_l_en.svg
new file mode 100644
index 00000000..42d5eebe
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/customize/assets/icon_l_en.svg
@@ -0,0 +1,3 @@
+
diff --git a/ai-provider/model-runtime/model-providers/customize/assets/icon_l_zh.svg b/ai-provider/model-runtime/model-providers/customize/assets/icon_l_zh.svg
new file mode 100644
index 00000000..42d5eebe
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/customize/assets/icon_l_zh.svg
@@ -0,0 +1,3 @@
+
diff --git a/ai-provider/model-runtime/model-providers/customize/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/customize/assets/icon_s_en.svg
new file mode 100644
index 00000000..42d5eebe
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/customize/assets/icon_s_en.svg
@@ -0,0 +1,3 @@
+
diff --git a/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_l_en.svg
new file mode 100644
index 00000000..d4a19a32
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_l_en.svg
@@ -0,0 +1,42 @@
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_s_en.svg
new file mode 100644
index 00000000..b05a4a1b
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_s_en.svg
@@ -0,0 +1,42 @@
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/huggingface_hub/huggingface_hub.yaml b/ai-provider/model-runtime/model-providers/huggingface_hub/huggingface_hub.yaml
new file mode 100644
index 00000000..e51e6081
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/huggingface_hub/huggingface_hub.yaml
@@ -0,0 +1,103 @@
+provider: huggingface_hub
+label:
+ en_US: Hugging Face Model
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+background: "#FFF8DC"
+help:
+ title:
+ en_US: Get your API key from Hugging Face Hub
+ zh_Hans: 从 Hugging Face Hub 获取 API Key
+ url:
+ en_US: https://huggingface.co/settings/tokens
+supported_model_types:
+ - llm
+ - text-embedding
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ credential_form_schemas:
+ - variable: huggingfacehub_api_type
+ label:
+ en_US: Endpoint Type
+ zh_Hans: 端点类型
+ type: radio
+ required: true
+ default: hosted_inference_api
+ options:
+ - value: hosted_inference_api
+ label:
+ en_US: Hosted Inference API
+ - value: inference_endpoints
+ label:
+ en_US: Inference Endpoints
+ - variable: huggingfacehub_api_token
+ label:
+ en_US: API Token
+ zh_Hans: API Token
+ type: secret-input
+ required: true
+ placeholder:
+ en_US: Enter your Hugging Face Hub API Token here
+ zh_Hans: 在此输入您的 Hugging Face Hub API Token
+ - variable: huggingface_namespace
+ label:
+ en_US: 'User Name / Organization Name'
+ zh_Hans: '用户名 / 组织名称'
+ type: text-input
+ required: true
+ placeholder:
+ en_US: 'Enter your User Name / Organization Name here'
+ zh_Hans: '在此输入您的用户名 / 组织名称'
+ show_on:
+ - variable: __model_type
+ value: text-embedding
+ - variable: huggingfacehub_api_type
+ value: inference_endpoints
+ - variable: huggingfacehub_endpoint_url
+ label:
+ en_US: Endpoint URL
+ zh_Hans: 端点 URL
+ type: text-input
+ required: true
+ placeholder:
+ en_US: Enter your Endpoint URL here
+ zh_Hans: 在此输入您的端点 URL
+ show_on:
+ - variable: huggingfacehub_api_type
+ value: inference_endpoints
+ - variable: task_type
+ label:
+ en_US: Task
+ zh_Hans: Task
+ type: select
+ options:
+ - value: text2text-generation
+ label:
+ en_US: Text-to-Text Generation
+ show_on:
+ - variable: __model_type
+ value: llm
+ - value: text-generation
+ label:
+ en_US: Text Generation
+ zh_Hans: 文本生成
+ show_on:
+ - variable: __model_type
+ value: llm
+ - value: feature-extraction
+ label:
+ en_US: Feature Extraction
+ show_on:
+ - variable: __model_type
+ value: text-embedding
+ show_on:
+ - variable: huggingfacehub_api_type
+ value: inference_endpoints
+address: https://api-inference.huggingface.co
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_l_en.svg
new file mode 100644
index 00000000..f1ef8d4b
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_l_en.svg
@@ -0,0 +1,11 @@
+
diff --git a/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_s_en.svg
new file mode 100644
index 00000000..86f2c419
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_s_en.svg
@@ -0,0 +1,4 @@
+
diff --git a/ai-provider/model-runtime/model-providers/lm_studio/lm_studio.yaml b/ai-provider/model-runtime/model-providers/lm_studio/lm_studio.yaml
new file mode 100644
index 00000000..5e62cd11
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/lm_studio/lm_studio.yaml
@@ -0,0 +1,99 @@
+provider: lm_studio
+label:
+ en_US: LM Studio
+icon_large:
+ en_US: icon_l_en.svg
+icon_small:
+ en_US: icon_s_en.svg
+background: "#F9FAFB"
+help:
+ title:
+ en_US: How to integrate with LM Studio
+ zh_Hans: 如何集成 LM Studio
+ url:
+ en_US: https://lmstudio.ai/docs/app
+supported_model_types:
+ - llm
+ - text-embedding
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your model name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: base_url
+ label:
+ zh_Hans: 基础 URL
+ en_US: Base URL
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: LM Studio server 的基础 URL,例如 http://localhost:1234
+ en_US: Base url of LM Studio server, e.g. http://localhost:1234
+ - variable: mode
+ show_on:
+ - variable: __model_type
+ value: llm
+ label:
+ zh_Hans: 模型类型
+ en_US: Completion mode
+ type: select
+ required: true
+ default: chat
+ placeholder:
+ zh_Hans: 选择对话类型
+ en_US: Select completion mode
+ options:
+ - value: completion
+ label:
+ en_US: Completion
+ zh_Hans: 补全
+ - value: chat
+ label:
+ en_US: Chat
+ zh_Hans: 对话
+ - variable: context_size
+ label:
+ zh_Hans: 模型上下文长度
+ en_US: Model context size
+ required: true
+ type: text-input
+ default: '4096'
+ placeholder:
+ zh_Hans: 在此输入您的模型上下文长度
+ en_US: Enter your Model context size
+ - variable: max_tokens
+ label:
+ zh_Hans: 最大 token 上限
+ en_US: Upper bound for max tokens
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: '4096'
+ type: text-input
+ required: true
+ - variable: function_call_support
+ label:
+ zh_Hans: 是否支持函数调用
+ en_US: Function call support
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: 'false'
+ type: radio
+ required: false
+ options:
+ - value: 'true'
+ label:
+ en_US: 'Yes'
+ zh_Hans: 是
+ - value: 'false'
+ label:
+ en_US: 'No'
+ zh_Hans: 否
+address: https://lmstudio.ai
diff --git a/ai-provider/model-runtime/model-providers/ollama/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/ollama/assets/icon_l_en.svg
new file mode 100644
index 00000000..5f08476c
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/ollama/assets/icon_l_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/ollama/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/ollama/assets/icon_s_en.svg
new file mode 100644
index 00000000..5f08476c
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/ollama/assets/icon_s_en.svg
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/ollama/ollama.yaml b/ai-provider/model-runtime/model-providers/ollama/ollama.yaml
new file mode 100644
index 00000000..b9144630
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/ollama/ollama.yaml
@@ -0,0 +1,118 @@
+provider: ollama
+label:
+ en_US: Ollama
+icon_large:
+ en_US: icon_l_en.svg
+icon_small:
+ en_US: icon_s_en.svg
+background: "#F9FAFB"
+help:
+ title:
+ en_US: How to integrate with Ollama
+ zh_Hans: 如何集成 Ollama
+ url:
+ en_US: https://docs.dify.ai/tutorials/model-configuration/ollama
+supported_model_types:
+ - llm
+ - text-embedding
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your model name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: base_url
+ label:
+ zh_Hans: 基础 URL
+ en_US: Base URL
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434
+ en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434
+ - variable: mode
+ show_on:
+ - variable: __model_type
+ value: llm
+ label:
+ zh_Hans: 模型类型
+ en_US: Completion mode
+ type: select
+ required: true
+ default: chat
+ placeholder:
+ zh_Hans: 选择对话类型
+ en_US: Select completion mode
+ options:
+ - value: completion
+ label:
+ en_US: Completion
+ zh_Hans: 补全
+ - value: chat
+ label:
+ en_US: Chat
+ zh_Hans: 对话
+ - variable: context_size
+ label:
+ zh_Hans: 模型上下文长度
+ en_US: Model context size
+ required: true
+ type: text-input
+ default: '4096'
+ placeholder:
+ zh_Hans: 在此输入您的模型上下文长度
+ en_US: Enter your Model context size
+ - variable: max_tokens
+ label:
+ zh_Hans: 最大 token 上限
+ en_US: Upper bound for max tokens
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: '4096'
+ type: text-input
+ required: true
+ - variable: vision_support
+ label:
+ zh_Hans: 是否支持 Vision
+ en_US: Vision support
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: 'false'
+ type: radio
+ required: false
+ options:
+ - value: 'true'
+ label:
+ en_US: 'Yes'
+ zh_Hans: 是
+ - value: 'false'
+ label:
+ en_US: 'No'
+ zh_Hans: 否
+ - variable: function_call_support
+ label:
+ zh_Hans: 是否支持函数调用
+ en_US: Function call support
+ show_on:
+ - variable: __model_type
+ value: llm
+ default: 'false'
+ type: radio
+ required: false
+ options:
+ - value: 'true'
+ label:
+ en_US: 'Yes'
+ zh_Hans: 是
+ - value: 'false'
+ label:
+ en_US: 'No'
+ zh_Hans: 否
+address: https://ollama.ai
diff --git a/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_en.svg
new file mode 100644
index 00000000..b34a9914
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_en.svg
@@ -0,0 +1,23 @@
+
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_zh.svg b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_zh.svg
new file mode 100644
index 00000000..65808aa3
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_zh.svg
@@ -0,0 +1,39 @@
+
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_s_en.svg
new file mode 100644
index 00000000..5de409c0
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_s_en.svg
@@ -0,0 +1,8 @@
+
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/volcengine_maas/volcengine_maas.yaml b/ai-provider/model-runtime/model-providers/volcengine_maas/volcengine_maas.yaml
new file mode 100644
index 00000000..1a970c63
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/volcengine_maas/volcengine_maas.yaml
@@ -0,0 +1,342 @@
+provider: volcengine_maas
+label:
+ en_US: Volcengine
+description:
+ en_US: Volcengine Ark models.
+ zh_Hans: 火山方舟提供的模型,例如 Doubao-pro-4k、Doubao-pro-32k 和 Doubao-pro-128k。
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+ zh_Hans: icon_l_zh.svg
+background: "#F9FAFB"
+help:
+ title:
+ en_US: Get your Access Key and Secret Access Key from Volcengine Console
+ zh_Hans: 从火山引擎控制台获取您的 Access Key 和 Secret Access Key
+ url:
+ en_US: https://console.volcengine.com/iam/keymanage/
+supported_model_types:
+ - llm
+ - text-embedding
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your Model Name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: auth_method
+ required: true
+ label:
+ en_US: Authentication Method
+ zh_Hans: 鉴权方式
+ type: select
+ default: aksk
+ options:
+ - label:
+ en_US: API Key
+ value: api_key
+ - label:
+ en_US: Access Key / Secret Access Key
+ value: aksk
+ placeholder:
+ en_US: Enter your Authentication Method
+ zh_Hans: 选择鉴权方式
+ - variable: volc_access_key_id
+ required: true
+ show_on:
+ - variable: auth_method
+ value: aksk
+ label:
+ en_US: Access Key
+ zh_Hans: Access Key
+ type: secret-input
+ placeholder:
+ en_US: Enter your Access Key
+ zh_Hans: 输入您的 Access Key
+ - variable: volc_secret_access_key
+ required: true
+ show_on:
+ - variable: auth_method
+ value: aksk
+ label:
+ en_US: Secret Access Key
+ zh_Hans: Secret Access Key
+ type: secret-input
+ placeholder:
+ en_US: Enter your Secret Access Key
+ zh_Hans: 输入您的 Secret Access Key
+ - variable: volc_api_key
+ required: true
+ show_on:
+ - variable: auth_method
+ value: api_key
+ label:
+ en_US: API Key
+ type: secret-input
+ placeholder:
+ en_US: Enter your API Key
+ zh_Hans: 输入您的 API Key
+ - variable: volc_region
+ required: true
+ label:
+ en_US: Volcengine Region
+ zh_Hans: 火山引擎地域
+ type: text-input
+ default: cn-beijing
+ placeholder:
+ en_US: Enter Volcengine Region
+ zh_Hans: 输入火山引擎地域
+ - variable: api_endpoint_host
+ required: true
+ label:
+ en_US: API Endpoint Host
+ zh_Hans: API Endpoint Host
+ type: text-input
+ default: https://ark.cn-beijing.volces.com/api/v3
+ placeholder:
+ en_US: Enter your API Endpoint Host
+ zh_Hans: 输入 API Endpoint Host
+ - variable: endpoint_id
+ required: true
+ label:
+ en_US: Endpoint ID
+ zh_Hans: Endpoint ID
+ type: text-input
+ placeholder:
+ en_US: Enter your Endpoint ID
+ zh_Hans: 输入您的 Endpoint ID
+ - variable: base_model_name
+ label:
+ en_US: Base Model
+ zh_Hans: 基础模型
+ type: select
+ required: true
+ options:
+ - label:
+ en_US: DeepSeek-R1-Distill-Qwen-32B
+ value: DeepSeek-R1-Distill-Qwen-32B
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: DeepSeek-R1-Distill-Qwen-7B
+ value: DeepSeek-R1-Distill-Qwen-7B
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: DeepSeek-R1
+ value: DeepSeek-R1
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: DeepSeek-V3
+ value: DeepSeek-V3
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-1.5-vision-pro-32k
+ value: Doubao-1.5-vision-pro-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-1.5-pro-32k
+ value: Doubao-1.5-pro-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-1.5-lite-32k
+ value: Doubao-1.5-lite-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-1.5-pro-256k
+ value: Doubao-1.5-pro-256k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-vision-pro-32k
+ value: Doubao-vision-pro-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-vision-lite-32k
+ value: Doubao-vision-lite-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-pro-4k
+ value: Doubao-pro-4k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-lite-4k
+ value: Doubao-lite-4k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-pro-32k
+ value: Doubao-pro-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-lite-32k
+ value: Doubao-lite-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-pro-128k
+ value: Doubao-pro-128k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-lite-128k
+ value: Doubao-lite-128k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-pro-256k
+ value: Doubao-pro-256k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Llama3-8B
+ value: Llama3-8B
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Llama3-70B
+ value: Llama3-70B
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Moonshot-v1-8k
+ value: Moonshot-v1-8k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Moonshot-v1-32k
+ value: Moonshot-v1-32k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Moonshot-v1-128k
+ value: Moonshot-v1-128k
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: GLM3-130B
+ value: GLM3-130B
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: GLM3-130B-Fin
+ value: GLM3-130B-Fin
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Mistral-7B
+ value: Mistral-7B
+ show_on:
+ - variable: __model_type
+ value: llm
+ - label:
+ en_US: Doubao-embedding
+ value: Doubao-embedding
+ show_on:
+ - variable: __model_type
+ value: text-embedding
+ - label:
+ en_US: Doubao-embedding-large
+ value: Doubao-embedding-large
+ show_on:
+ - variable: __model_type
+ value: text-embedding
+ - label:
+ en_US: Custom
+ zh_Hans: 自定义
+ value: Custom
+ - variable: mode
+ required: true
+ show_on:
+ - variable: __model_type
+ value: llm
+ - variable: base_model_name
+ value: Custom
+ label:
+ zh_Hans: 模型类型
+ en_US: Completion Mode
+ type: select
+ default: chat
+ placeholder:
+ zh_Hans: 选择对话类型
+ en_US: Select Completion Mode
+ options:
+ - value: completion
+ label:
+ en_US: Completion
+ zh_Hans: 补全
+ - value: chat
+ label:
+ en_US: Chat
+ zh_Hans: 对话
+ - variable: context_size
+ required: true
+ show_on:
+ - variable: base_model_name
+ value: Custom
+ label:
+ zh_Hans: 模型上下文长度
+ en_US: Model Context Size
+ type: text-input
+ default: "4096"
+ placeholder:
+ zh_Hans: 输入您的模型上下文长度
+ en_US: Enter your Model Context Size
+ - variable: max_tokens
+ required: true
+ show_on:
+ - variable: __model_type
+ value: llm
+ - variable: base_model_name
+ value: Custom
+ label:
+ zh_Hans: 最大 token 上限
+ en_US: Upper Bound for Max Tokens
+ default: "4096"
+ type: text-input
+ placeholder:
+ zh_Hans: 输入您的模型最大 token 上限
+ en_US: Enter your model Upper Bound for Max Tokens
+address: https://open.volcengine.com
+model_config:
+ access_configuration_status: true
+ access_configuration_demo: "{\"endpoint\": \"https://196.1.1.2:3824\"}"
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/xinference/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/xinference/assets/icon_l_en.svg
new file mode 100644
index 00000000..5c601c8a
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/xinference/assets/icon_l_en.svg
@@ -0,0 +1,42 @@
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/xinference/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/xinference/assets/icon_s_en.svg
new file mode 100644
index 00000000..efe03479
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/xinference/assets/icon_s_en.svg
@@ -0,0 +1,24 @@
+
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model-providers/xinference/xinference.yaml b/ai-provider/model-runtime/model-providers/xinference/xinference.yaml
new file mode 100644
index 00000000..6d284644
--- /dev/null
+++ b/ai-provider/model-runtime/model-providers/xinference/xinference.yaml
@@ -0,0 +1,79 @@
+provider: xinference
+label:
+ en_US: Xorbits Inference
+icon_small:
+ en_US: icon_s_en.svg
+icon_large:
+ en_US: icon_l_en.svg
+background: "#FAF5FF"
+help:
+ title:
+ en_US: How to deploy Xinference
+ zh_Hans: 如何部署 Xinference
+ url:
+ en_US: https://github.com/xorbitsai/inference
+supported_model_types:
+ - llm
+ - text-embedding
+ - rerank
+ - speech2text
+ - tts
+configurate_methods:
+ - customizable-model
+model_credential_schema:
+ model:
+ label:
+ en_US: Model Name
+ zh_Hans: 模型名称
+ placeholder:
+ en_US: Enter your model name
+ zh_Hans: 输入模型名称
+ credential_form_schemas:
+ - variable: server_url
+ label:
+ zh_Hans: 服务器URL
+ en_US: Server url
+ type: secret-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997
+ en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997
+ - variable: model_uid
+ label:
+ zh_Hans: 模型UID
+ en_US: Model uid
+ type: text-input
+ required: true
+ placeholder:
+ zh_Hans: 在此输入您的Model UID
+ en_US: Enter the model uid
+ - variable: api_key
+ label:
+ zh_Hans: API密钥
+ en_US: API key
+ type: secret-input
+ required: false
+ placeholder:
+ zh_Hans: 在此输入您的API密钥
+ en_US: Enter the api key
+ - variable: invoke_timeout
+ label:
+ zh_Hans: 调用超时时间 (单位:秒)
+ en_US: invoke timeout (unit:second)
+ type: text-input
+ required: true
+ default: '60'
+ placeholder:
+ zh_Hans: 在此输入调用超时时间
+ en_US: Enter invoke timeout value
+ - variable: max_retries
+ label:
+ zh_Hans: 调用重试次数
+ en_US: max retries
+ type: text-input
+ required: true
+ default: '3'
+ placeholder:
+ zh_Hans: 在此输入调用重试次数
+ en_US: Enter max retries
+address: https://xinference.ai
\ No newline at end of file
diff --git a/ai-provider/model-runtime/model.go b/ai-provider/model-runtime/model.go
index b1d0e235..ecdfcd81 100644
--- a/ai-provider/model-runtime/model.go
+++ b/ai-provider/model-runtime/model.go
@@ -3,32 +3,84 @@ package model_runtime
import (
"encoding/json"
"github.com/APIParkLab/APIPark/ai-provider/model-runtime/entity"
+ "github.com/APIParkLab/APIPark/common"
"gopkg.in/yaml.v3"
"strconv"
)
type IModel interface {
ID() string
+ Name() string
Logo() string
+ Source() string
+ SetLogo(logo string)
+ AccessConfiguration() string
+ ModelParameters() string
IConfig
}
type Model struct {
- id string
- logo string
+ id string
+ logo string
+ name string
+ accessConfiguration string
+ modelParameters string
+ // default: ""/"system", "customize"
+ source string
//defaultConfig string
IConfig
//validator IParamValidator
}
+func (m *Model) SetLogo(logo string) {
+ m.logo = logo
+}
+
+func (m *Model) Name() string {
+ return m.name
+}
+
+type CustomizeProviderConfig struct {
+ ApiEndpointUrl string `json:"api_endpoint_url"`
+ ApiKey string `json:"api_key"`
+}
+
func (m *Model) ID() string {
return m.id
}
+func (m *Model) Source() string {
+ return m.source
+}
+
func (m *Model) Logo() string {
return m.logo
}
+func (m *Model) AccessConfiguration() string {
+ return m.accessConfiguration
+}
+
+func (m *Model) ModelParameters() string {
+ return m.modelParameters
+}
+
+func NewCustomizeModel(id string, name string, logo string, accessConfiguration string, modelParameters string) (IModel, error) {
+ if logo == "" {
+ logo = GetCustomizeLogo()
+ }
+ // handle access_config & model_config
+ config := common.MergeJSON(accessConfiguration, modelParameters)
+ return &Model{
+ id: id,
+ name: name,
+ logo: logo,
+ source: "customize",
+ accessConfiguration: accessConfiguration,
+ IConfig: NewConfig(config, nil),
+ }, nil
+}
+
func NewModel(data string, logo string) (IModel, error) {
var cfg entity.AIModel
err := yaml.Unmarshal([]byte(data), &cfg)
@@ -100,8 +152,10 @@ func NewModel(data string, logo string) (IModel, error) {
return nil, err
}
return &Model{
- id: cfg.Model,
- logo: logo,
- IConfig: NewConfig(string(dCfg), params),
+ id: cfg.Model,
+ name: cfg.Model,
+ logo: logo,
+ accessConfiguration: "",
+ IConfig: NewConfig(string(dCfg), params),
}, nil
}
diff --git a/ai-provider/model-runtime/provider.go b/ai-provider/model-runtime/provider.go
index 666fc11f..dde582ea 100644
--- a/ai-provider/model-runtime/provider.go
+++ b/ai-provider/model-runtime/provider.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/url"
+ "strings"
yaml "gopkg.in/yaml.v3"
@@ -17,6 +18,10 @@ const (
type IProvider interface {
IProviderInfo
+ GetModelConfig() ModelConfig
+ SetModelsByType(modelType string, models []IModel)
+ SetModel(id string, model IModel)
+ SetDefaultModel(modelType string, model IModel)
GetModel(name string) (IModel, bool)
Models() []IModel
ModelsByType(modelType string) ([]IModel, bool)
@@ -41,6 +46,58 @@ type IProviderInfo interface {
URI() IProviderURI
}
+func GetCustomizeLogo() string {
+ logo, _ := providerDir.ReadFile("customize/assets/icon_s_en.svg")
+
+ return string(logo)
+}
+
+func NewCustomizeProvider(id string, name string, models []IModel, defaultModel string, config string) (IProvider, error) {
+ var providerCfg CustomizeProviderConfig
+ if strings.TrimSpace(config) != "" {
+ err := json.Unmarshal([]byte(config), &providerCfg)
+ if err != nil {
+ return nil, err
+ }
+ }
+ uri, err := newProviderUri(providerCfg.ApiEndpointUrl)
+ if err != nil {
+ return nil, err
+ }
+
+ provider := &Provider{
+ id: id,
+ name: name,
+ logo: GetCustomizeLogo(),
+ helpUrl: "",
+ models: eosc.BuildUntyped[string, IModel](),
+ defaultModels: eosc.BuildUntyped[string, IModel](),
+ modelsByType: eosc.BuildUntyped[string, []IModel](),
+ maskKeys: make([]string, 0),
+ recommend: false,
+ sort: 0,
+ uri: uri,
+ modelConfig: ModelConfig{
+ AccessConfigurationStatus: false,
+ AccessConfigurationDemo: "",
+ },
+ }
+ provider.IConfig = NewConfig("", nil)
+
+ for _, model := range models {
+ provider.SetModel(model.ID(), model)
+ if defaultModel == "" {
+ defaultModel = model.ID()
+ }
+ if model.ID() == defaultModel {
+ provider.SetDefaultModel(name, model)
+ }
+ }
+ provider.SetModelsByType(ModelTypeLLM, models)
+
+ return provider, nil
+}
+
func NewProvider(providerData string, modelContents map[string]eosc.Untyped[string, string]) (IProvider, error) {
var providerCfg entity.Provider
err := yaml.Unmarshal([]byte(providerData), &providerCfg)
@@ -77,6 +134,10 @@ func NewProvider(providerData string, modelContents map[string]eosc.Untyped[stri
recommend: providerCfg.Recommend,
sort: providerCfg.Sort,
uri: uri,
+ modelConfig: ModelConfig{
+ AccessConfigurationStatus: providerCfg.ModelConfig.AccessConfigurationStatus,
+ AccessConfigurationDemo: providerCfg.ModelConfig.AccessConfigurationDemo,
+ },
}
defaultCfg := make(map[string]string)
params := make(ParamValidator, 0, len(providerCfg.ProviderCredentialSchema.CredentialFormSchemas))
@@ -132,9 +193,19 @@ type Provider struct {
uri IProviderURI
sort int
recommend bool
+ modelConfig ModelConfig
IConfig
}
+type ModelConfig struct {
+ AccessConfigurationStatus bool
+ AccessConfigurationDemo string
+}
+
+func (p *Provider) GetModelConfig() ModelConfig {
+ return p.modelConfig
+}
+
func (p *Provider) Sort() int {
return p.sort
}
@@ -202,6 +273,10 @@ func (p *Provider) SetModel(id string, model IModel) {
p.models.Set(id, model)
}
+func (p *Provider) RemoveModel(id string) {
+ p.models.Del(id)
+}
+
func (p *Provider) SetModelsByType(modelType string, models []IModel) {
p.modelsByType.Set(modelType, models)
}
diff --git a/common/common.go b/common/common.go
new file mode 100644
index 00000000..dfc001d8
--- /dev/null
+++ b/common/common.go
@@ -0,0 +1,36 @@
+package common
+
+import (
+ "encoding/json"
+ "strings"
+)
+
+func MergeJSON(json1, json2 string) string {
+ var data1, data2 map[string]interface{}
+ if strings.TrimSpace(json1) != "" {
+ if err := json.Unmarshal([]byte(json1), &data1); err != nil {
+ return ""
+ }
+ }
+ if strings.TrimSpace(json2) != "" {
+ if err := json.Unmarshal([]byte(json2), &data2); err != nil {
+ return ""
+ }
+ }
+
+ merged := make(map[string]interface{})
+ // copy data1
+ for k, v := range data1 {
+ merged[k] = v
+ }
+ // merge data2 & cover same key
+ for k, v := range data2 {
+ merged[k] = v
+ }
+ // transfer to json string
+ result, err := json.Marshal(merged)
+ if err != nil {
+ return ""
+ }
+ return string(result)
+}
diff --git a/controller/ai-model/controller.go b/controller/ai-model/controller.go
new file mode 100644
index 00000000..84cefe5e
--- /dev/null
+++ b/controller/ai-model/controller.go
@@ -0,0 +1,20 @@
+package ai_model
+
+import (
+ model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto"
+ "github.com/eolinker/go-common/autowire"
+ "github.com/gin-gonic/gin"
+ "reflect"
+)
+
+type IProviderModelController interface {
+ AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error)
+ UpdateProviderModel(ctx *gin.Context, provider string, input *model_dto.EditModel) error
+ DeleteProviderModel(ctx *gin.Context, provider string, id string) error
+}
+
+func init() {
+ autowire.Auto[IProviderModelController](func() reflect.Value {
+ return reflect.ValueOf(&imlProviderModelController{})
+ })
+}
diff --git a/controller/ai-model/iml.go b/controller/ai-model/iml.go
new file mode 100644
index 00000000..04d7341a
--- /dev/null
+++ b/controller/ai-model/iml.go
@@ -0,0 +1,67 @@
+package ai_model
+
+import (
+ "encoding/json"
+ "fmt"
+ ai_model "github.com/APIParkLab/APIPark/module/ai-model"
+ model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto"
+ "github.com/gin-gonic/gin"
+ "strings"
+)
+
+var (
+ _ IProviderModelController = (*imlProviderModelController)(nil)
+)
+
+type imlProviderModelController struct {
+ module ai_model.IProviderModelModule `autowired:""`
+}
+
+func (i *imlProviderModelController) UpdateProviderModel(ctx *gin.Context, provider string, input *model_dto.EditModel) error {
+ if strings.TrimSpace(input.Name) == "" {
+ return fmt.Errorf("name is empty")
+ }
+ if strings.TrimSpace(input.Id) == "" {
+ return fmt.Errorf("id is empty")
+ }
+ if strings.TrimSpace(provider) == "" {
+ return fmt.Errorf("provider is empty")
+ }
+ // check access config & model parameter is json format
+ if strings.TrimSpace(input.AccessConfiguration) != "" && !json.Valid([]byte(input.AccessConfiguration)) {
+ return fmt.Errorf("access configuration is not json format")
+ }
+ if strings.TrimSpace(input.ModelParameters) != "" && !json.Valid([]byte(input.ModelParameters)) {
+ return fmt.Errorf("model parameters is not json format")
+ }
+
+ return i.module.UpdateProviderModel(ctx, provider, input)
+}
+
+func (i *imlProviderModelController) DeleteProviderModel(ctx *gin.Context, provider string, id string) error {
+ if strings.TrimSpace(id) == "" {
+ return fmt.Errorf("id is empty")
+ }
+ if strings.TrimSpace(provider) == "" {
+ return fmt.Errorf("provider is empty")
+ }
+
+ return i.module.DeleteProviderModel(ctx, provider, id)
+}
+
+func (i *imlProviderModelController) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) {
+ if strings.TrimSpace(input.Name) == "" {
+ return nil, fmt.Errorf("name is empty")
+ }
+ if strings.TrimSpace(provider) == "" {
+ return nil, fmt.Errorf("provider is empty")
+ }
+ // check access config & model parameter is json format
+ if strings.TrimSpace(input.AccessConfiguration) != "" && !json.Valid([]byte(input.AccessConfiguration)) {
+ return nil, fmt.Errorf("access configuration is not json format")
+ }
+ if strings.TrimSpace(input.ModelParameters) != "" && !json.Valid([]byte(input.ModelParameters)) {
+ return nil, fmt.Errorf("model parameters is not json format")
+ }
+ return i.module.AddProviderModel(ctx, provider, input)
+}
diff --git a/controller/ai/controller.go b/controller/ai/controller.go
index e4b2a5a5..f6eb5689 100644
--- a/controller/ai/controller.go
+++ b/controller/ai/controller.go
@@ -22,6 +22,7 @@ type IProviderController interface {
UpdateProviderDefaultLLM(ctx *gin.Context, id string, input *ai_dto.UpdateLLM) error
Delete(ctx *gin.Context, id string) error
//Sort(ctx *gin.Context, input *ai_dto.Sort) error
+ AddProvider(ctx *gin.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error)
}
type IStatisticController interface {
diff --git a/controller/ai/iml.go b/controller/ai/iml.go
index c7eee3b9..ae37239b 100644
--- a/controller/ai/iml.go
+++ b/controller/ai/iml.go
@@ -2,7 +2,9 @@ package ai
import (
"encoding/json"
+ "fmt"
"strconv"
+ "strings"
"github.com/APIParkLab/APIPark/module/ai"
ai_dto "github.com/APIParkLab/APIPark/module/ai/dto"
@@ -21,6 +23,13 @@ func (i *imlProviderController) Delete(ctx *gin.Context, id string) error {
return i.module.Delete(ctx, id)
}
+func (i *imlProviderController) AddProvider(ctx *gin.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error) {
+ if strings.TrimSpace(input.Name) == "" {
+ return nil, fmt.Errorf("name is empty")
+ }
+ return i.module.AddProvider(ctx, input)
+}
+
//func (i *imlProviderController) Sort(ctx *gin.Context, input *ai_dto.Sort) error {
// return i.module.Sort(ctx, input)
//}
@@ -67,6 +76,9 @@ func (i *imlProviderController) Disable(ctx *gin.Context, id string) error {
}
func (i *imlProviderController) UpdateProviderConfig(ctx *gin.Context, id string, input *ai_dto.UpdateConfig) error {
+ if strings.TrimSpace(id) == "" {
+ return fmt.Errorf("id is empty")
+ }
return i.module.UpdateProviderConfig(ctx, id, input)
}
diff --git a/go.sum b/go.sum
index 0322d23f..8e7dcdc0 100644
--- a/go.sum
+++ b/go.sum
@@ -40,7 +40,6 @@ github.com/getkin/kin-openapi v0.127.0 h1:Mghqi3Dhryf3F8vR370nN67pAERW+3a95vomb3
github.com/getkin/kin-openapi v0.127.0/go.mod h1:OZrfXzUfGrNbsKj+xmFBx6E5c6yH3At/tAKSc2UszXM=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
-github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE=
github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
diff --git a/module/ai-model/dto/input.go b/module/ai-model/dto/input.go
new file mode 100644
index 00000000..97f554e8
--- /dev/null
+++ b/module/ai-model/dto/input.go
@@ -0,0 +1,12 @@
+package model_dto
+
+type Model struct {
+ Name string `json:"name"`
+ AccessConfiguration string `json:"access_configuration"`
+ ModelParameters string `json:"model_parameters"`
+}
+
+type EditModel struct {
+ Id string `json:"id"`
+ Model
+}
diff --git a/module/ai-model/dto/output.go b/module/ai-model/dto/output.go
new file mode 100644
index 00000000..936845fc
--- /dev/null
+++ b/module/ai-model/dto/output.go
@@ -0,0 +1,6 @@
+package model_dto
+
+type SimpleModel struct {
+ Id string `json:"id"`
+ Name string `json:"name"`
+}
diff --git a/module/ai-model/iml.go b/module/ai-model/iml.go
new file mode 100644
index 00000000..2196b36e
--- /dev/null
+++ b/module/ai-model/iml.go
@@ -0,0 +1,111 @@
+package ai_model
+
+import (
+ "fmt"
+ model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
+ model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto"
+ "github.com/APIParkLab/APIPark/service/ai"
+ ai_model "github.com/APIParkLab/APIPark/service/ai-model"
+ "github.com/gin-gonic/gin"
+ "github.com/google/uuid"
+
+ "github.com/eolinker/go-common/store"
+)
+
+var (
+ _ IProviderModelModule = (*imlProviderModelModule)(nil)
+)
+
+type imlProviderModelModule struct {
+ providerService ai.IProviderService `autowired:""`
+ providerModelService ai_model.IProviderModelService `autowired:""`
+ transaction store.ITransaction `autowired:""`
+}
+
+func (i imlProviderModelModule) UpdateProviderModel(ctx *gin.Context, provider string, input *model_dto.EditModel) error {
+ p, has := model_runtime.GetProvider(provider)
+ if !has {
+ return fmt.Errorf("ai provider not found")
+ }
+ // check provider exist
+ providerInfo, err := i.providerService.Get(ctx, provider)
+ if err != nil {
+ return err
+ }
+ if providerInfo == nil {
+ return fmt.Errorf("provider not found")
+ }
+ modelInfo, _ := i.providerModelService.Get(ctx, input.Id)
+ if modelInfo == nil || modelInfo.Provider != provider {
+ return fmt.Errorf("model not found")
+ }
+ // check model name duplicate
+ if has := i.providerModelService.CheckNameDuplicate(ctx, provider, input.Name, input.Id); has {
+ return fmt.Errorf("model name: `%s` duplicate", input.Name)
+ }
+ if err := i.providerModelService.Save(ctx, input.Id, &ai_model.Model{
+ Name: &input.Name,
+ AccessConfiguration: &input.AccessConfiguration,
+ ModelParameters: &input.ModelParameters,
+ }); err != nil {
+ return err
+ }
+ // update provider model
+ iModel, _ := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
+ p.SetModel(input.Id, iModel)
+
+ return nil
+}
+
+func (i imlProviderModelModule) DeleteProviderModel(ctx *gin.Context, provider string, id string) error {
+ // check provider exist
+ providerInfo, err := i.providerService.Get(ctx, provider)
+ if err != nil {
+ return err
+ }
+ if providerInfo == nil {
+ return fmt.Errorf("provider not found")
+ }
+ modelInfo, _ := i.providerModelService.Get(ctx, id)
+ if modelInfo == nil || modelInfo.Provider != provider {
+ return fmt.Errorf("model not found")
+ }
+ return i.providerModelService.Delete(ctx, id)
+}
+
+func (i imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) {
+ p, has := model_runtime.GetProvider(provider)
+ if !has {
+ return nil, fmt.Errorf("ai provider not found")
+ }
+ // check provider exist
+ providerInfo, err := i.providerService.Get(ctx, provider)
+ if err != nil {
+ return nil, err
+ }
+ if providerInfo == nil {
+ return nil, fmt.Errorf("provider not found")
+ }
+ // check model name duplicate
+ if has := i.providerModelService.CheckNameDuplicate(ctx, provider, input.Name, ""); has {
+ return nil, fmt.Errorf("model name: `%s` duplicate", input.Name)
+ }
+ id := uuid.New().String()
+ typeValue := "chat"
+ if err := i.providerModelService.Save(ctx, id, &ai_model.Model{
+ Name: &input.Name,
+ Type: &typeValue,
+ Provider: &provider,
+ AccessConfiguration: &input.AccessConfiguration,
+ ModelParameters: &input.ModelParameters,
+ }); err != nil {
+ return nil, err
+ }
+ // update provider model
+ iModel, _ := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
+ p.SetModel(id, iModel)
+ return &model_dto.SimpleModel{
+ Id: id,
+ Name: input.Name,
+ }, nil
+}
diff --git a/module/ai-model/module.go b/module/ai-model/module.go
new file mode 100644
index 00000000..6f9f22d7
--- /dev/null
+++ b/module/ai-model/module.go
@@ -0,0 +1,22 @@
+package ai_model
+
+import (
+ model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto"
+ "github.com/gin-gonic/gin"
+ "reflect"
+
+ "github.com/eolinker/go-common/autowire"
+)
+
+type IProviderModelModule interface {
+ AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error)
+ UpdateProviderModel(ctx *gin.Context, provider string, input *model_dto.EditModel) error
+ DeleteProviderModel(ctx *gin.Context, provider string, id string) error
+}
+
+func init() {
+ autowire.Auto[IProviderModelModule](func() reflect.Value {
+ module := new(imlProviderModelModule)
+ return reflect.ValueOf(module)
+ })
+}
diff --git a/module/ai/dto/input.go b/module/ai/dto/input.go
index 93a112f8..adeb7286 100644
--- a/module/ai/dto/input.go
+++ b/module/ai/dto/input.go
@@ -14,3 +14,7 @@ type UpdateConfig struct {
type Sort struct {
Providers []string `json:"providers"`
}
+
+type NewProvider struct {
+ Name string `json:"name"`
+}
diff --git a/module/ai/dto/output.go b/module/ai/dto/output.go
index 27d07eae..263b09ad 100644
--- a/module/ai/dto/output.go
+++ b/module/ai/dto/output.go
@@ -4,6 +4,11 @@ import (
"github.com/eolinker/go-common/auto"
)
+type SimpleModel struct {
+ Id string `json:"id"`
+ Name string `json:"name"`
+}
+
type SimpleProvider struct {
Id string `json:"id"`
Name string `json:"name"`
@@ -20,8 +25,14 @@ type Provider struct {
DefaultLLM string `json:"default_llm"`
DefaultLLMConfig string `json:"-"`
//Priority int `json:"priority"`
- Status ProviderStatus `json:"status"`
- Configured bool `json:"configured"`
+ Status ProviderStatus `json:"status"`
+ Configured bool `json:"configured"`
+ ModelConfig ModelConfig `json:"model_config"`
+}
+
+type ModelConfig struct {
+ AccessConfigurationStatus bool `json:"access_configuration_status"`
+ AccessConfigurationDemo string `json:"access_configuration_demo"`
}
type ConfiguredProviderItem struct {
@@ -32,6 +43,7 @@ type ConfiguredProviderItem struct {
Status ProviderStatus `json:"status"`
APICount int64 `json:"api_count"`
KeyCount int64 `json:"key_count"`
+ ModelCount int64 `json:"model_count"`
CanDelete bool `json:"can_delete"`
}
@@ -48,6 +60,7 @@ type ProviderItem struct {
Logo string `json:"logo"`
DefaultLLM string `json:"default_llm"`
Sort int `json:"-"`
+ Type int `json:"type"` // 0:default 1:customize
}
type SimpleProviderItem struct {
@@ -69,10 +82,15 @@ type BackupProvider struct {
}
type LLMItem struct {
- Id string `json:"id"`
- Logo string `json:"logo"`
- Config string `json:"config"`
- Scopes []string `json:"scopes"`
+ Id string `json:"id"`
+ Logo string `json:"logo"`
+ Config string `json:"config"`
+ AccessConfiguration string `json:"access_configuration"`
+ ModelParameters string `json:"model_parameters"`
+ Scopes []string `json:"scopes"`
+ Type string `json:"type"`
+ IsSystem bool `json:"is_system"`
+ ApiCount int64 `json:"api_count"`
}
type APIItem struct {
diff --git a/module/ai/iml.go b/module/ai/iml.go
index d4c06d89..658cf0f0 100644
--- a/module/ai/iml.go
+++ b/module/ai/iml.go
@@ -4,12 +4,12 @@ import (
"context"
"errors"
"fmt"
+ ai_model "github.com/APIParkLab/APIPark/service/ai-model"
+ "github.com/google/uuid"
"net/http"
"sort"
"time"
- "github.com/google/uuid"
-
ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local"
"github.com/eolinker/go-common/register"
@@ -63,12 +63,13 @@ func newKey(key *ai_key.Key) *gateway.DynamicRelease {
var _ IProviderModule = (*imlProviderModule)(nil)
type imlProviderModule struct {
- providerService ai.IProviderService `autowired:""`
- clusterService cluster.IClusterService `autowired:""`
- aiAPIService ai_api.IAPIService `autowired:""`
- aiKeyService ai_key.IKeyService `autowired:""`
- aiBalanceService ai_balance.IBalanceService `autowired:""`
- transaction store.ITransaction `autowired:""`
+ providerService ai.IProviderService `autowired:""`
+ providerModelService ai_model.IProviderModelService `autowired:""`
+ clusterService cluster.IClusterService `autowired:""`
+ aiAPIService ai_api.IAPIService `autowired:""`
+ aiKeyService ai_key.IKeyService `autowired:""`
+ aiBalanceService ai_balance.IBalanceService `autowired:""`
+ transaction store.ITransaction `autowired:""`
}
func (i *imlProviderModule) OnInit() {
@@ -79,6 +80,33 @@ func (i *imlProviderModule) OnInit() {
if err != nil {
return
}
+ // register provider
+ for _, p := range list {
+ // get customize models
+ models, _ := i.providerModelService.Search(ctx, "", map[string]interface{}{"provider": p.Id}, "update_at desc")
+ iModels := make([]model_runtime.IModel, 0, len(models))
+ if models != nil {
+ for _, model := range models {
+ // parse access_config & model_parameters
+ iModel, _ := model_runtime.NewCustomizeModel(model.Id, model.Name, model_runtime.GetCustomizeLogo(), model.AccessConfiguration, model.ModelParameters)
+ iModels = append(iModels, iModel)
+ }
+ }
+ // default provider
+ if p.Type == 0 {
+ runtimeProvider, _ := model_runtime.GetProvider(p.Id)
+ for _, tmpIModel := range iModels {
+ tmpIModel.SetLogo(runtimeProvider.Logo())
+ if p.DefaultLLM == tmpIModel.ID() {
+ runtimeProvider.SetDefaultModel(model_runtime.ModelTypeLLM, tmpIModel)
+ }
+ runtimeProvider.SetModel(tmpIModel.ID(), tmpIModel)
+ }
+ } else {
+ provider, _ := model_runtime.NewCustomizeProvider(p.Id, p.Name, iModels, p.DefaultLLM, p.Config)
+ model_runtime.Register(p.Id, provider)
+ }
+ }
i.transaction.Transaction(ctx, func(ctx context.Context) error {
for _, l := range list {
if l.Priority < 1 {
@@ -145,6 +173,8 @@ func (i *imlProviderModule) Delete(ctx context.Context, id string) error {
if err != nil {
return err
}
+ // delete register provider
+ model_runtime.Remove(id)
releases := make([]*gateway.DynamicRelease, 0, len(keys))
for _, key := range keys {
releases = append(releases, newKey(key))
@@ -164,6 +194,30 @@ func (i *imlProviderModule) Delete(ctx context.Context, id string) error {
})
}
+func (i *imlProviderModule) AddProvider(ctx context.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error) {
+ if has := i.providerService.CheckNameDuplicate(ctx, input.Name); has {
+ return nil, fmt.Errorf("provider `%s` duplicate", input.Name)
+ }
+ id := uuid.New().String()
+ config, defaultLLM := "{\"api_endpoint_url\": \"http://127.0.0.1\", \"api_key\": \"\"}", ""
+ typeValue := 1
+ if err := i.providerService.Save(ctx, id, &ai.SetProvider{
+ Name: &input.Name,
+ DefaultLLM: &defaultLLM,
+ Config: &config,
+ Type: &typeValue,
+ }); err != nil {
+ return nil, err
+ }
+ // register provider
+ iProvider, _ := model_runtime.NewCustomizeProvider(id, input.Name, []model_runtime.IModel{}, "", "")
+ model_runtime.Register(id, iProvider)
+ return &ai_dto.SimpleProvider{
+ Id: id,
+ Name: input.Name,
+ }, nil
+}
+
func (i *imlProviderModule) SimpleProvider(ctx context.Context, id string) (*ai_dto.SimpleProvider, error) {
p, has := model_runtime.GetProvider(id)
if !has {
@@ -231,6 +285,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context, keyword str
APICount: apiCount,
KeyCount: keyMap[l.Id],
CanDelete: apiCount < 1,
+ ModelCount: int64(len(p.Models())),
})
}
@@ -394,8 +449,9 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
}
defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM)
if !has {
- return nil, fmt.Errorf("ai provider llm not found")
+ defaultLLM, _ = model_runtime.NewCustomizeModel("", "", "", "", "")
}
+ providerModelConfig := p.GetModelConfig()
return &ai_dto.Provider{
Id: p.ID(),
Name: p.Name(),
@@ -405,13 +461,17 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
DefaultLLMConfig: defaultLLM.Logo(),
Status: ai_dto.ProviderDisabled,
//Priority: maxPriority,
+ ModelConfig: ai_dto.ModelConfig{
+ AccessConfigurationStatus: providerModelConfig.AccessConfigurationStatus,
+ AccessConfigurationDemo: providerModelConfig.AccessConfigurationDemo,
+ },
}, nil
}
defaultLLM, has := p.GetModel(info.DefaultLLM)
if !has {
model, has := p.DefaultModel(model_runtime.ModelTypeLLM)
if !has {
- return nil, fmt.Errorf("ai provider llm not found")
+ defaultLLM, _ = model_runtime.NewCustomizeModel("", "", "", "", "")
}
defaultLLM = model
}
@@ -426,6 +486,10 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr
//Priority: info.Priority,
Status: ai_dto.ToProviderStatus(info.Status),
Configured: true,
+ ModelConfig: ai_dto.ModelConfig{
+ AccessConfigurationStatus: false,
+ AccessConfigurationDemo: "",
+ },
}, nil
}
@@ -439,16 +503,21 @@ func (i *imlProviderModule) LLMs(ctx context.Context, driver string) ([]*ai_dto.
if !has {
return nil, nil, fmt.Errorf("ai provider not found")
}
-
+ modelApiCountMap, _ := i.aiAPIService.CountMapByModel(ctx, "", map[string]interface{}{"provider": driver})
items := make([]*ai_dto.LLMItem, 0, len(llms))
for _, v := range llms {
items = append(items, &ai_dto.LLMItem{
- Id: v.ID(),
- Logo: v.Logo(),
- Config: v.DefaultConfig(),
+ Id: v.ID(),
+ Logo: v.Logo(),
+ Config: v.DefaultConfig(),
+ AccessConfiguration: v.AccessConfiguration(),
+ ModelParameters: v.ModelParameters(),
Scopes: []string{
"chat",
},
+ Type: "chat",
+ IsSystem: v.Source() != "customize",
+ ApiCount: modelApiCountMap[v.ID()],
})
}
info, err := i.providerService.Get(ctx, driver)
@@ -523,6 +592,11 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string,
if err != nil {
return err
}
+ if input.DefaultLLM != "" {
+ if defaultLLM, has := p.GetModel(input.DefaultLLM); has {
+ p.SetDefaultModel(model_runtime.ModelTypeLLM, defaultLLM)
+ }
+ }
status := 0
if input.Enable != nil && *input.Enable {
status = 1
diff --git a/module/ai/module.go b/module/ai/module.go
index 650de0ff..f036b048 100644
--- a/module/ai/module.go
+++ b/module/ai/module.go
@@ -19,6 +19,7 @@ type IProviderModule interface {
LLMs(ctx context.Context, driver string) ([]*ai_dto.LLMItem, *ai_dto.ProviderItem, error)
UpdateProviderConfig(ctx context.Context, id string, input *ai_dto.UpdateConfig) error
Delete(ctx context.Context, id string) error
+ AddProvider(ctx context.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error)
}
type IAIAPIModule interface {
diff --git a/plugins/core/ai.go b/plugins/core/ai.go
index 6bf7af69..2cc76219 100644
--- a/plugins/core/ai.go
+++ b/plugins/core/ai.go
@@ -19,6 +19,10 @@ func (p *plugin) aiAPIs() []pm3.Api {
pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/ai/provider/llms", []string{"context", "query:provider"}, []string{"llms", "provider"}, p.aiProviderController.LLMs),
pm3.CreateApiWidthDoc(http.MethodDelete, "/api/v1/ai/provider", []string{"context", "query:provider"}, nil, p.aiProviderController.Delete),
pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/ai/provider/config", []string{"context", "query:provider", "body"}, nil, p.aiProviderController.UpdateProviderConfig, access.SystemSettingsAiProviderManager),
+ pm3.CreateApiWidthDoc(http.MethodPost, "/api/v1/ai/provider", []string{"context", "body"}, []string{"provider"}, p.aiProviderController.AddProvider, access.SystemSettingsAiProviderManager),
+ pm3.CreateApiWidthDoc(http.MethodPost, "/api/v1/ai/provider/model", []string{"context", "query:provider", "body"}, []string{"model"}, p.aiProviderModelController.AddProviderModel, access.SystemSettingsAiProviderManager),
+ pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/ai/provider/model", []string{"context", "query:provider", "body"}, nil, p.aiProviderModelController.UpdateProviderModel, access.SystemSettingsAiProviderManager),
+ pm3.CreateApiWidthDoc(http.MethodDelete, "/api/v1/ai/provider/model", []string{"context", "query:provider", "query:id"}, nil, p.aiProviderModelController.DeleteProviderModel, access.SystemSettingsAiProviderManager),
pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/ai/apis", []string{"context", "query:keyword", "query:provider", "query:start", "query:end", "query:page", "query:page_size", "query:sort", "query:asc", "query:models", "query:services"}, []string{"apis", "condition", "total"}, p.aiStatisticController.APIs),
}
diff --git a/plugins/core/core.go b/plugins/core/core.go
index be4e3f00..b3afb4b1 100644
--- a/plugins/core/core.go
+++ b/plugins/core/core.go
@@ -1,6 +1,7 @@
package core
import (
+ ai_model "github.com/APIParkLab/APIPark/controller/ai-model"
"net/http"
ai_balance "github.com/APIParkLab/APIPark/controller/ai-balance"
@@ -97,6 +98,7 @@ type plugin struct {
exportConfigController system.IExportConfigController `autowired:""`
importConfigController system.IImportConfigController `autowired:""`
aiProviderController ai.IProviderController `autowired:""`
+ aiProviderModelController ai_model.IProviderModelController `autowired:""`
settingController system.ISettingController `autowired:""`
initController system.IInitController `autowired:""`
logController log.ILogController `autowired:""`
diff --git a/service/ai-model/iml.go b/service/ai-model/iml.go
new file mode 100644
index 00000000..7a594556
--- /dev/null
+++ b/service/ai-model/iml.go
@@ -0,0 +1,84 @@
+package ai_model
+
+import (
+ "context"
+ "errors"
+ "github.com/APIParkLab/APIPark/service/universally"
+ "github.com/APIParkLab/APIPark/stores/ai"
+ "github.com/eolinker/go-common/utils"
+ "gorm.io/gorm"
+ "time"
+)
+
+var _ IProviderModelService = (*imlProviderModelService)(nil)
+
+type imlProviderModelService struct {
+ universally.IServiceGet[ProviderModel]
+ universally.IServiceCreate[ProviderModel]
+ universally.IServiceDelete
+ store ai.IProviderModelStore `autowired:""`
+}
+
+func (i *imlProviderModelService) CountMapByProvider(ctx context.Context, conditions map[string]interface{}) (map[string]int64, error) {
+ return i.store.CountByGroup(ctx, "", conditions, "provider")
+}
+
+func (i *imlProviderModelService) Save(ctx context.Context, id string, model *Model) error {
+ userId := utils.UserId(ctx)
+ now := time.Now()
+ info, err := i.store.First(ctx, map[string]interface{}{"uuid": id})
+ if err != nil {
+ if !errors.Is(err, gorm.ErrRecordNotFound) {
+ return err
+ }
+ if model.Name == nil || model.Provider == nil {
+ return errors.New("invalid params")
+ }
+ info = &ai.ProviderModel{
+ UUID: id,
+ Name: *model.Name,
+ Type: *model.Type,
+ AccessConfiguration: *model.AccessConfiguration,
+ ModelParameters: *model.ModelParameters,
+ Provider: *model.Provider,
+ Creator: userId,
+ Updater: userId,
+ CreateAt: now,
+ UpdateAt: now,
+ }
+ } else {
+ if model.Name != nil {
+ info.Name = *model.Name
+ }
+ if model.Type != nil {
+ info.Type = *model.Type
+ }
+ if model.Provider != nil {
+ info.Provider = *model.Provider
+ }
+ if model.AccessConfiguration != nil {
+ info.AccessConfiguration = *model.AccessConfiguration
+ }
+ if model.ModelParameters != nil {
+ info.AccessConfiguration = *model.ModelParameters
+ }
+ info.Updater = userId
+ info.UpdateAt = now
+ }
+ return i.store.Save(ctx, info)
+}
+
+func (i *imlProviderModelService) CheckNameDuplicate(ctx context.Context, provider string, name string, excludeId string) bool {
+ v, _ := i.store.First(ctx, map[string]interface{}{"provider": provider, "name": name})
+ if v != nil {
+ return true
+ } else if excludeId != "" && v.UUID != excludeId {
+ return true
+ }
+ return false
+}
+
+func (i *imlProviderModelService) OnComplete() {
+ i.IServiceGet = universally.NewGet[ProviderModel, ai.ProviderModel](i.store, FromEntity)
+ i.IServiceDelete = universally.NewDelete[ai.ProviderModel](i.store)
+}
diff --git a/service/ai-model/model.go b/service/ai-model/model.go
new file mode 100644
index 00000000..96a5df41
--- /dev/null
+++ b/service/ai-model/model.go
@@ -0,0 +1,51 @@
+package ai_model
+
+import (
+ "encoding/base64"
+ "github.com/APIParkLab/APIPark/stores/ai"
+ "time"
+)
+
+type ProviderModel struct {
+ Id string // provider model:uuid
+ Name string
+ Type string
+ AccessConfiguration string
+ ModelParameters string
+ Provider string
+ Creator string
+ Updater string
+ CreateAt time.Time
+ UpdateAt time.Time
+}
+
+func FromEntity(e *ai.ProviderModel) *ProviderModel {
+ accessConfiguration, err := base64.RawStdEncoding.DecodeString(e.AccessConfiguration)
+ modelParameters, err := base64.RawStdEncoding.DecodeString(e.ModelParameters)
+ if err != nil {
+ accessConfiguration = []byte(e.AccessConfiguration)
+ }
+ if err != nil {
+ modelParameters = []byte(e.ModelParameters)
+ }
+ return &ProviderModel{
+ Id: e.UUID,
+ Name: e.Name,
+ Type: e.Type,
+ AccessConfiguration: string(accessConfiguration),
+ ModelParameters: string(modelParameters),
+ Provider: e.Provider,
+ Creator: e.Creator,
+ Updater: e.Updater,
+ CreateAt: e.CreateAt,
+ UpdateAt: e.UpdateAt,
+ }
+}
+
+type Model struct {
+ Name *string
+ Provider *string
+ Type *string
+ AccessConfiguration *string
+ ModelParameters *string
+}
diff --git a/service/ai-model/service.go b/service/ai-model/service.go
new file mode 100644
index 00000000..78c58cf4
--- /dev/null
+++ b/service/ai-model/service.go
@@ -0,0 +1,22 @@
+package ai_model
+
+import (
+ "context"
+ "github.com/APIParkLab/APIPark/service/universally"
+ "github.com/eolinker/go-common/autowire"
+ "reflect"
+)
+
+type IProviderModelService interface {
+ universally.IServiceGet[ProviderModel]
+ universally.IServiceDelete
+ CountMapByProvider(ctx context.Context, conditions map[string]interface{}) (map[string]int64, error)
+ Save(ctx context.Context, id string, cfg *Model) error
+ CheckNameDuplicate(ctx context.Context, provider string, name string, excludeId string) bool
+}
+
+func init() {
+ autowire.Auto[IProviderModelService](func() reflect.Value {
+ return reflect.ValueOf(new(imlProviderModelService))
+ })
+}
diff --git a/service/ai/iml.go b/service/ai/iml.go
index 6eda08a9..7e62c2d3 100644
--- a/service/ai/iml.go
+++ b/service/ai/iml.go
@@ -70,6 +70,14 @@ type imlProviderService struct {
// return i.store.Save(ctx, info)
//}
+func (i *imlProviderService) CheckNameDuplicate(ctx context.Context, name string) bool {
+ v, _ := i.store.First(ctx, map[string]interface{}{"name": name})
+ if v != nil {
+ return true
+ }
+ return false
+}
+
func (i *imlProviderService) GetLabels(ctx context.Context, ids ...string) map[string]string {
if len(ids) == 0 {
return nil
diff --git a/service/ai/model.go b/service/ai/model.go
index 263a79fd..a10fe7a5 100644
--- a/service/ai/model.go
+++ b/service/ai/model.go
@@ -15,6 +15,7 @@ type Provider struct {
Creator string
Updater string
Status int
+ Type int
Priority int
CreateAt time.Time
UpdateAt time.Time
@@ -34,6 +35,7 @@ type SetProvider struct {
Config *string
Priority *int
Status *int
+ Type *int
}
func FromEntity(e *ai.Provider) *Provider {
@@ -52,5 +54,6 @@ func FromEntity(e *ai.Provider) *Provider {
UpdateAt: e.UpdateAt,
Status: e.Status,
Priority: e.Priority,
+ Type: e.Type,
}
}
diff --git a/service/ai/service.go b/service/ai/service.go
index c4201db6..c7eb0071 100644
--- a/service/ai/service.go
+++ b/service/ai/service.go
@@ -1,6 +1,7 @@
package ai
import (
+ "context"
"reflect"
"github.com/APIParkLab/APIPark/service/universally"
@@ -14,6 +15,7 @@ type IProviderService interface {
universally.IServiceDelete
//Save(ctx context.Context, id string, cfg *SetProvider) error
//MaxPriority(ctx context.Context) (int, error)
+ CheckNameDuplicate(ctx context.Context, name string) bool
}
func init() {
diff --git a/stores/ai/model.go b/stores/ai/model.go
index 6d607d16..0eb2788b 100644
--- a/stores/ai/model.go
+++ b/stores/ai/model.go
@@ -10,6 +10,7 @@ type Provider struct {
Config string `gorm:"type:text;not null;column:config;comment:配置信息"`
Status int `gorm:"type:tinyint(1);not null;column:status;comment:状态,0:停用;1:启用,2:异常;default:1"`
Priority int `gorm:"type:int;not null;column:priority;comment:优先级,值越小优先级越高"`
+ Type int `gorm:"type:tinyint(1);not null;column:type;comment:type 0:default 1:customize"`
Creator string `gorm:"size:36;not null;column:creator;comment:创建人;index:creator" aovalue:"creator"` // 创建人
Updater string `gorm:"size:36;not null;column:updater;comment:更新人;index:updater" aovalue:"updater"` // 更新人
CreateAt time.Time `gorm:"type:timestamp;NOT NULL;DEFAULT:CURRENT_TIMESTAMP;column:create_at;comment:创建时间"`
@@ -164,3 +165,26 @@ func (i *LocalModelCache) TableName() string {
func (i *LocalModelCache) IdValue() int64 {
return i.Id
}
+
+type ProviderModel struct {
+ Id int64 `gorm:"column:id;type:BIGINT(20);AUTO_INCREMENT;NOT NULL;comment:id;primary_key;comment:PRIMARY ID;"`
+ UUID string `gorm:"type:varchar(36);not null;column:uuid;uniqueIndex:uuid;comment:UUID;"`
+ Name string `gorm:"type:varchar(100);not null;column:name;comment:name;index:name"`
+ Type string `gorm:"type:varchar(100);not null;column:type;comment:type:chat"`
+ AccessConfiguration string `gorm:"type:text;not null;column:access_configuration;comment:access_configuration json"`
+ ModelParameters string `gorm:"type:text;not null;column:model_parameters;comment:model_parameters json"`
+ Provider string `gorm:"type:varchar(36);not null;column:provider;comment:ai_provider:uuid;index:provider"`
+ Creator string `gorm:"size:36;not null;column:creator;comment:creator;index:creator" aovalue:"creator"`
+ Updater string `gorm:"size:36;not null;column:updater;comment:updater;index:updater" aovalue:"updater"`
+ CreateAt time.Time `gorm:"type:timestamp;NOT NULL;DEFAULT:CURRENT_TIMESTAMP;column:create_at;comment:create_at"`
+ UpdateAt time.Time `gorm:"type:timestamp;NOT NULL;DEFAULT:CURRENT_TIMESTAMP;column:update_at;comment:update_at"`
+}
+
+func (i *ProviderModel) TableName() string {
+ return "ai_provider_model"
+}
+
+func (i *ProviderModel) IdValue() int64 {
+ return i.Id
+}
+
diff --git a/stores/ai/store.go b/stores/ai/store.go
index 8c462f42..3ab6a36e 100644
--- a/stores/ai/store.go
+++ b/stores/ai/store.go
@@ -71,6 +71,14 @@ type imlLocalModelCacheStore struct {
store.Store[LocalModelCache]
}
+type IProviderModelStore interface {
+ store.ISearchStore[ProviderModel]
+}
+
+type imlProviderModelStore struct {
+ store.SearchStore[ProviderModel]
+}
+
func init() {
autowire.Auto[IProviderStore](func() reflect.Value {
return reflect.ValueOf(new(imlProviderStore))
@@ -103,4 +111,8 @@ func init() {
autowire.Auto[ILocalModelCacheStore](func() reflect.Value {
return reflect.ValueOf(new(imlLocalModelCacheStore))
})
+
+ autowire.Auto[IProviderModelStore](func() reflect.Value {
+ return reflect.ValueOf(new(imlProviderModelStore))
+ })
}