diff --git a/.gitignore b/.gitignore index cb2bf660..9708fbd2 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ /apipark .gitlab-ci.yml /.vscode/ +.air.toml +/tmp/ +/work \ No newline at end of file diff --git a/ai-provider/model-runtime/entity/provider.go b/ai-provider/model-runtime/entity/provider.go index d9014560..f3ee1d6e 100644 --- a/ai-provider/model-runtime/entity/provider.go +++ b/ai-provider/model-runtime/entity/provider.go @@ -25,6 +25,11 @@ type Provider struct { Address string `json:"address" yaml:"address"` Recommend bool `json:"recommend" yaml:"recommend"` Sort int `json:"sort" yaml:"sort"` + ModelConfig ModelConfig `json:"model_config" yaml:"model_config"` +} +type ModelConfig struct { + AccessConfigurationStatus bool `json:"access_configuration_status" yaml:"access_configuration_status"` + AccessConfigurationDemo string `json:"access_configuration_demo" yaml:"access_configuration_demo"` } type ProviderCredentialSchema struct { diff --git a/ai-provider/model-runtime/loader.go b/ai-provider/model-runtime/loader.go index 5320ff0e..a6eefb25 100644 --- a/ai-provider/model-runtime/loader.go +++ b/ai-provider/model-runtime/loader.go @@ -6,8 +6,6 @@ import ( "fmt" "strings" - "github.com/APIParkLab/APIPark/gateway" - "github.com/eolinker/eosc" ) @@ -36,7 +34,10 @@ func (c *Config) Check(cfg string) error { if err != nil { return err } - return c.validator.Valid(data) + if c.validator != nil { + return c.validator.Valid(data) + } + return nil } func (c *Config) GenConfig(target string, origin string) (string, error) { @@ -83,6 +84,9 @@ func Load() error { continue } name := fmt.Sprintf("model-providers/%s", file.Name()) + if file.Name() == "customize" { + continue + } err = LoadProvider(name) if err != nil { return err @@ -119,10 +123,10 @@ func LoadProvider(name string) error { if err != nil { return err } - gateway.RegisterDynamicResourceDriver(provider.ID(), gateway.Worker{ - Profession: gateway.ProfessionAIProvider, - Driver: provider.ID(), - }) + //gateway.RegisterDynamicResourceDriver(provider.ID(), gateway.Worker{ + // Profession: gateway.ProfessionAIProvider, + // Driver: provider.ID(), + //}) Register(provider.ID(), provider) return nil } diff --git a/ai-provider/model-runtime/manager.go b/ai-provider/model-runtime/manager.go index dbefea3e..169ff0fc 100644 --- a/ai-provider/model-runtime/manager.go +++ b/ai-provider/model-runtime/manager.go @@ -43,6 +43,10 @@ func Register(name string, driver IProvider) { defaultManager.Set(name, driver) } +func Remove(name string) { + defaultManager.Del(name) +} + func GetProvider(name string) (IProvider, bool) { return defaultManager.Get(name) } diff --git a/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_en.svg new file mode 100644 index 00000000..2a7a4f4f --- /dev/null +++ b/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_en.svg @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_zh.svg b/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_zh.svg new file mode 100644 index 00000000..9f650f2b --- /dev/null +++ b/ai-provider/model-runtime/model-providers/bailian/assets/icon_l_zh.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/bailian/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/bailian/assets/icon_s_en.svg new file mode 100644 index 00000000..851ba565 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/bailian/assets/icon_s_en.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/bailian/bailian.yaml b/ai-provider/model-runtime/model-providers/bailian/bailian.yaml new file mode 100644 index 00000000..7456a7de --- /dev/null +++ b/ai-provider/model-runtime/model-providers/bailian/bailian.yaml @@ -0,0 +1,32 @@ +provider: bailian +label: + zh_Hans: 阿里云百炼 + en_US: bailian +icon_small: + en_US: icon_s_en.svg +icon_large: + zh_Hans: icon_l_zh.svg + en_US: icon_l_en.svg +background: "#EFF1FE" +help: + title: + en_US: Get your API key from AliCloud + zh_Hans: 从阿里云百炼获取 API Key + url: + en_US: https://bailian.console.aliyun.com/?apiKey=1#/api-key +supported_model_types: + - llm +configurate_methods: + - predefined-model + - customizable-model +provider_credential_schema: + credential_form_schemas: + - variable: dashscope_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key +address: https://dashscope.aliyuncs.com \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/customize/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/customize/assets/icon_l_en.svg new file mode 100644 index 00000000..42d5eebe --- /dev/null +++ b/ai-provider/model-runtime/model-providers/customize/assets/icon_l_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/ai-provider/model-runtime/model-providers/customize/assets/icon_l_zh.svg b/ai-provider/model-runtime/model-providers/customize/assets/icon_l_zh.svg new file mode 100644 index 00000000..42d5eebe --- /dev/null +++ b/ai-provider/model-runtime/model-providers/customize/assets/icon_l_zh.svg @@ -0,0 +1,3 @@ + + + diff --git a/ai-provider/model-runtime/model-providers/customize/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/customize/assets/icon_s_en.svg new file mode 100644 index 00000000..42d5eebe --- /dev/null +++ b/ai-provider/model-runtime/model-providers/customize/assets/icon_s_en.svg @@ -0,0 +1,3 @@ + + + diff --git a/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_l_en.svg new file mode 100644 index 00000000..d4a19a32 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_l_en.svg @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_s_en.svg new file mode 100644 index 00000000..b05a4a1b --- /dev/null +++ b/ai-provider/model-runtime/model-providers/huggingface_hub/assets/icon_s_en.svg @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/huggingface_hub/huggingface_hub.yaml b/ai-provider/model-runtime/model-providers/huggingface_hub/huggingface_hub.yaml new file mode 100644 index 00000000..e51e6081 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/huggingface_hub/huggingface_hub.yaml @@ -0,0 +1,103 @@ +provider: huggingface_hub +label: + en_US: Hugging Face Model +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FFF8DC" +help: + title: + en_US: Get your API key from Hugging Face Hub + zh_Hans: 从 Hugging Face Hub 获取 API Key + url: + en_US: https://huggingface.co/settings/tokens +supported_model_types: + - llm + - text-embedding +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + credential_form_schemas: + - variable: huggingfacehub_api_type + label: + en_US: Endpoint Type + zh_Hans: 端点类型 + type: radio + required: true + default: hosted_inference_api + options: + - value: hosted_inference_api + label: + en_US: Hosted Inference API + - value: inference_endpoints + label: + en_US: Inference Endpoints + - variable: huggingfacehub_api_token + label: + en_US: API Token + zh_Hans: API Token + type: secret-input + required: true + placeholder: + en_US: Enter your Hugging Face Hub API Token here + zh_Hans: 在此输入您的 Hugging Face Hub API Token + - variable: huggingface_namespace + label: + en_US: 'User Name / Organization Name' + zh_Hans: '用户名 / 组织名称' + type: text-input + required: true + placeholder: + en_US: 'Enter your User Name / Organization Name here' + zh_Hans: '在此输入您的用户名 / 组织名称' + show_on: + - variable: __model_type + value: text-embedding + - variable: huggingfacehub_api_type + value: inference_endpoints + - variable: huggingfacehub_endpoint_url + label: + en_US: Endpoint URL + zh_Hans: 端点 URL + type: text-input + required: true + placeholder: + en_US: Enter your Endpoint URL here + zh_Hans: 在此输入您的端点 URL + show_on: + - variable: huggingfacehub_api_type + value: inference_endpoints + - variable: task_type + label: + en_US: Task + zh_Hans: Task + type: select + options: + - value: text2text-generation + label: + en_US: Text-to-Text Generation + show_on: + - variable: __model_type + value: llm + - value: text-generation + label: + en_US: Text Generation + zh_Hans: 文本生成 + show_on: + - variable: __model_type + value: llm + - value: feature-extraction + label: + en_US: Feature Extraction + show_on: + - variable: __model_type + value: text-embedding + show_on: + - variable: huggingfacehub_api_type + value: inference_endpoints +address: https://api-inference.huggingface.co \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_l_en.svg new file mode 100644 index 00000000..f1ef8d4b --- /dev/null +++ b/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_l_en.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_s_en.svg new file mode 100644 index 00000000..86f2c419 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/lm_studio/assets/icon_s_en.svg @@ -0,0 +1,4 @@ + + + + diff --git a/ai-provider/model-runtime/model-providers/lm_studio/lm_studio.yaml b/ai-provider/model-runtime/model-providers/lm_studio/lm_studio.yaml new file mode 100644 index 00000000..5e62cd11 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/lm_studio/lm_studio.yaml @@ -0,0 +1,99 @@ +provider: lm_studio +label: + en_US: LM Studio +icon_large: + en_US: icon_l_en.svg +icon_small: + en_US: icon_s_en.svg +background: "#F9FAFB" +help: + title: + en_US: How to integrate with LM Studio + zh_Hans: 如何集成 LM Studio + url: + en_US: https://lmstudio.ai/docs/app +supported_model_types: + - llm + - text-embedding +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: base_url + label: + zh_Hans: 基础 URL + en_US: Base URL + type: text-input + required: true + placeholder: + zh_Hans: LM Studio server 的基础 URL,例如 http://localhost:1234 + en_US: Base url of LM Studio server, e.g. http://localhost:1234 + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + zh_Hans: 模型类型 + en_US: Completion mode + type: select + required: true + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + show_on: + - variable: __model_type + value: llm + default: '4096' + type: text-input + required: true + - variable: function_call_support + label: + zh_Hans: 是否支持函数调用 + en_US: Function call support + show_on: + - variable: __model_type + value: llm + default: 'false' + type: radio + required: false + options: + - value: 'true' + label: + en_US: 'Yes' + zh_Hans: 是 + - value: 'false' + label: + en_US: 'No' + zh_Hans: 否 +address: https://lmstudio.ai diff --git a/ai-provider/model-runtime/model-providers/ollama/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/ollama/assets/icon_l_en.svg new file mode 100644 index 00000000..5f08476c --- /dev/null +++ b/ai-provider/model-runtime/model-providers/ollama/assets/icon_l_en.svg @@ -0,0 +1 @@ +LM Studio \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/ollama/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/ollama/assets/icon_s_en.svg new file mode 100644 index 00000000..5f08476c --- /dev/null +++ b/ai-provider/model-runtime/model-providers/ollama/assets/icon_s_en.svg @@ -0,0 +1 @@ +LM Studio \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/ollama/ollama.yaml b/ai-provider/model-runtime/model-providers/ollama/ollama.yaml new file mode 100644 index 00000000..b9144630 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/ollama/ollama.yaml @@ -0,0 +1,118 @@ +provider: ollama +label: + en_US: Ollama +icon_large: + en_US: icon_l_en.svg +icon_small: + en_US: icon_s_en.svg +background: "#F9FAFB" +help: + title: + en_US: How to integrate with Ollama + zh_Hans: 如何集成 Ollama + url: + en_US: https://docs.dify.ai/tutorials/model-configuration/ollama +supported_model_types: + - llm + - text-embedding +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: base_url + label: + zh_Hans: 基础 URL + en_US: Base URL + type: text-input + required: true + placeholder: + zh_Hans: Ollama server 的基础 URL,例如 http://192.168.1.100:11434 + en_US: Base url of Ollama server, e.g. http://192.168.1.100:11434 + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + zh_Hans: 模型类型 + en_US: Completion mode + type: select + required: true + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + show_on: + - variable: __model_type + value: llm + default: '4096' + type: text-input + required: true + - variable: vision_support + label: + zh_Hans: 是否支持 Vision + en_US: Vision support + show_on: + - variable: __model_type + value: llm + default: 'false' + type: radio + required: false + options: + - value: 'true' + label: + en_US: 'Yes' + zh_Hans: 是 + - value: 'false' + label: + en_US: 'No' + zh_Hans: 否 + - variable: function_call_support + label: + zh_Hans: 是否支持函数调用 + en_US: Function call support + show_on: + - variable: __model_type + value: llm + default: 'false' + type: radio + required: false + options: + - value: 'true' + label: + en_US: 'Yes' + zh_Hans: 是 + - value: 'false' + label: + en_US: 'No' + zh_Hans: 否 +address: https://ollama.ai diff --git a/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_en.svg new file mode 100644 index 00000000..b34a9914 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_en.svg @@ -0,0 +1,23 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_zh.svg b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_zh.svg new file mode 100644 index 00000000..65808aa3 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_l_zh.svg @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_s_en.svg new file mode 100644 index 00000000..5de409c0 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/volcengine_maas/assets/icon_s_en.svg @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/volcengine_maas/volcengine_maas.yaml b/ai-provider/model-runtime/model-providers/volcengine_maas/volcengine_maas.yaml new file mode 100644 index 00000000..1a970c63 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/volcengine_maas/volcengine_maas.yaml @@ -0,0 +1,342 @@ +provider: volcengine_maas +label: + en_US: Volcengine +description: + en_US: Volcengine Ark models. + zh_Hans: 火山方舟提供的模型,例如 Doubao-pro-4k、Doubao-pro-32k 和 Doubao-pro-128k。 +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg + zh_Hans: icon_l_zh.svg +background: "#F9FAFB" +help: + title: + en_US: Get your Access Key and Secret Access Key from Volcengine Console + zh_Hans: 从火山引擎控制台获取您的 Access Key 和 Secret Access Key + url: + en_US: https://console.volcengine.com/iam/keymanage/ +supported_model_types: + - llm + - text-embedding +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your Model Name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: auth_method + required: true + label: + en_US: Authentication Method + zh_Hans: 鉴权方式 + type: select + default: aksk + options: + - label: + en_US: API Key + value: api_key + - label: + en_US: Access Key / Secret Access Key + value: aksk + placeholder: + en_US: Enter your Authentication Method + zh_Hans: 选择鉴权方式 + - variable: volc_access_key_id + required: true + show_on: + - variable: auth_method + value: aksk + label: + en_US: Access Key + zh_Hans: Access Key + type: secret-input + placeholder: + en_US: Enter your Access Key + zh_Hans: 输入您的 Access Key + - variable: volc_secret_access_key + required: true + show_on: + - variable: auth_method + value: aksk + label: + en_US: Secret Access Key + zh_Hans: Secret Access Key + type: secret-input + placeholder: + en_US: Enter your Secret Access Key + zh_Hans: 输入您的 Secret Access Key + - variable: volc_api_key + required: true + show_on: + - variable: auth_method + value: api_key + label: + en_US: API Key + type: secret-input + placeholder: + en_US: Enter your API Key + zh_Hans: 输入您的 API Key + - variable: volc_region + required: true + label: + en_US: Volcengine Region + zh_Hans: 火山引擎地域 + type: text-input + default: cn-beijing + placeholder: + en_US: Enter Volcengine Region + zh_Hans: 输入火山引擎地域 + - variable: api_endpoint_host + required: true + label: + en_US: API Endpoint Host + zh_Hans: API Endpoint Host + type: text-input + default: https://ark.cn-beijing.volces.com/api/v3 + placeholder: + en_US: Enter your API Endpoint Host + zh_Hans: 输入 API Endpoint Host + - variable: endpoint_id + required: true + label: + en_US: Endpoint ID + zh_Hans: Endpoint ID + type: text-input + placeholder: + en_US: Enter your Endpoint ID + zh_Hans: 输入您的 Endpoint ID + - variable: base_model_name + label: + en_US: Base Model + zh_Hans: 基础模型 + type: select + required: true + options: + - label: + en_US: DeepSeek-R1-Distill-Qwen-32B + value: DeepSeek-R1-Distill-Qwen-32B + show_on: + - variable: __model_type + value: llm + - label: + en_US: DeepSeek-R1-Distill-Qwen-7B + value: DeepSeek-R1-Distill-Qwen-7B + show_on: + - variable: __model_type + value: llm + - label: + en_US: DeepSeek-R1 + value: DeepSeek-R1 + show_on: + - variable: __model_type + value: llm + - label: + en_US: DeepSeek-V3 + value: DeepSeek-V3 + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-1.5-vision-pro-32k + value: Doubao-1.5-vision-pro-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-1.5-pro-32k + value: Doubao-1.5-pro-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-1.5-lite-32k + value: Doubao-1.5-lite-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-1.5-pro-256k + value: Doubao-1.5-pro-256k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-vision-pro-32k + value: Doubao-vision-pro-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-vision-lite-32k + value: Doubao-vision-lite-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-pro-4k + value: Doubao-pro-4k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-lite-4k + value: Doubao-lite-4k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-pro-32k + value: Doubao-pro-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-lite-32k + value: Doubao-lite-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-pro-128k + value: Doubao-pro-128k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-lite-128k + value: Doubao-lite-128k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-pro-256k + value: Doubao-pro-256k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Llama3-8B + value: Llama3-8B + show_on: + - variable: __model_type + value: llm + - label: + en_US: Llama3-70B + value: Llama3-70B + show_on: + - variable: __model_type + value: llm + - label: + en_US: Moonshot-v1-8k + value: Moonshot-v1-8k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Moonshot-v1-32k + value: Moonshot-v1-32k + show_on: + - variable: __model_type + value: llm + - label: + en_US: Moonshot-v1-128k + value: Moonshot-v1-128k + show_on: + - variable: __model_type + value: llm + - label: + en_US: GLM3-130B + value: GLM3-130B + show_on: + - variable: __model_type + value: llm + - label: + en_US: GLM3-130B-Fin + value: GLM3-130B-Fin + show_on: + - variable: __model_type + value: llm + - label: + en_US: Mistral-7B + value: Mistral-7B + show_on: + - variable: __model_type + value: llm + - label: + en_US: Doubao-embedding + value: Doubao-embedding + show_on: + - variable: __model_type + value: text-embedding + - label: + en_US: Doubao-embedding-large + value: Doubao-embedding-large + show_on: + - variable: __model_type + value: text-embedding + - label: + en_US: Custom + zh_Hans: 自定义 + value: Custom + - variable: mode + required: true + show_on: + - variable: __model_type + value: llm + - variable: base_model_name + value: Custom + label: + zh_Hans: 模型类型 + en_US: Completion Mode + type: select + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select Completion Mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: context_size + required: true + show_on: + - variable: base_model_name + value: Custom + label: + zh_Hans: 模型上下文长度 + en_US: Model Context Size + type: text-input + default: "4096" + placeholder: + zh_Hans: 输入您的模型上下文长度 + en_US: Enter your Model Context Size + - variable: max_tokens + required: true + show_on: + - variable: __model_type + value: llm + - variable: base_model_name + value: Custom + label: + zh_Hans: 最大 token 上限 + en_US: Upper Bound for Max Tokens + default: "4096" + type: text-input + placeholder: + zh_Hans: 输入您的模型最大 token 上限 + en_US: Enter your model Upper Bound for Max Tokens +address: https://open.volcengine.com +model_config: + access_configuration_status: true + access_configuration_demo: "{\"endpoint\": \"https://196.1.1.2:3824\"}" \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/xinference/assets/icon_l_en.svg b/ai-provider/model-runtime/model-providers/xinference/assets/icon_l_en.svg new file mode 100644 index 00000000..5c601c8a --- /dev/null +++ b/ai-provider/model-runtime/model-providers/xinference/assets/icon_l_en.svg @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/xinference/assets/icon_s_en.svg b/ai-provider/model-runtime/model-providers/xinference/assets/icon_s_en.svg new file mode 100644 index 00000000..efe03479 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/xinference/assets/icon_s_en.svg @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/ai-provider/model-runtime/model-providers/xinference/xinference.yaml b/ai-provider/model-runtime/model-providers/xinference/xinference.yaml new file mode 100644 index 00000000..6d284644 --- /dev/null +++ b/ai-provider/model-runtime/model-providers/xinference/xinference.yaml @@ -0,0 +1,79 @@ +provider: xinference +label: + en_US: Xorbits Inference +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#FAF5FF" +help: + title: + en_US: How to deploy Xinference + zh_Hans: 如何部署 Xinference + url: + en_US: https://github.com/xorbitsai/inference +supported_model_types: + - llm + - text-embedding + - rerank + - speech2text + - tts +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: server_url + label: + zh_Hans: 服务器URL + en_US: Server url + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入Xinference的服务器地址,如 http://192.168.1.100:9997 + en_US: Enter the url of your Xinference, e.g. http://192.168.1.100:9997 + - variable: model_uid + label: + zh_Hans: 模型UID + en_US: Model uid + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的Model UID + en_US: Enter the model uid + - variable: api_key + label: + zh_Hans: API密钥 + en_US: API key + type: secret-input + required: false + placeholder: + zh_Hans: 在此输入您的API密钥 + en_US: Enter the api key + - variable: invoke_timeout + label: + zh_Hans: 调用超时时间 (单位:秒) + en_US: invoke timeout (unit:second) + type: text-input + required: true + default: '60' + placeholder: + zh_Hans: 在此输入调用超时时间 + en_US: Enter invoke timeout value + - variable: max_retries + label: + zh_Hans: 调用重试次数 + en_US: max retries + type: text-input + required: true + default: '3' + placeholder: + zh_Hans: 在此输入调用重试次数 + en_US: Enter max retries +address: https://xinference.ai \ No newline at end of file diff --git a/ai-provider/model-runtime/model.go b/ai-provider/model-runtime/model.go index b1d0e235..ecdfcd81 100644 --- a/ai-provider/model-runtime/model.go +++ b/ai-provider/model-runtime/model.go @@ -3,32 +3,84 @@ package model_runtime import ( "encoding/json" "github.com/APIParkLab/APIPark/ai-provider/model-runtime/entity" + "github.com/APIParkLab/APIPark/common" "gopkg.in/yaml.v3" "strconv" ) type IModel interface { ID() string + Name() string Logo() string + Source() string + SetLogo(logo string) + AccessConfiguration() string + ModelParameters() string IConfig } type Model struct { - id string - logo string + id string + logo string + name string + accessConfiguration string + modelParameters string + // default: ""/"system", "customize" + source string //defaultConfig string IConfig //validator IParamValidator } +func (m *Model) SetLogo(logo string) { + m.logo = logo +} + +func (m *Model) Name() string { + return m.name +} + +type CustomizeProviderConfig struct { + ApiEndpointUrl string `json:"api_endpoint_url"` + ApiKey string `json:"api_key"` +} + func (m *Model) ID() string { return m.id } +func (m *Model) Source() string { + return m.source +} + func (m *Model) Logo() string { return m.logo } +func (m *Model) AccessConfiguration() string { + return m.accessConfiguration +} + +func (m *Model) ModelParameters() string { + return m.modelParameters +} + +func NewCustomizeModel(id string, name string, logo string, accessConfiguration string, modelParameters string) (IModel, error) { + if logo == "" { + logo = GetCustomizeLogo() + } + // handle access_config & model_config + config := common.MergeJSON(accessConfiguration, modelParameters) + return &Model{ + id: id, + name: name, + logo: logo, + source: "customize", + accessConfiguration: accessConfiguration, + IConfig: NewConfig(config, nil), + }, nil +} + func NewModel(data string, logo string) (IModel, error) { var cfg entity.AIModel err := yaml.Unmarshal([]byte(data), &cfg) @@ -100,8 +152,10 @@ func NewModel(data string, logo string) (IModel, error) { return nil, err } return &Model{ - id: cfg.Model, - logo: logo, - IConfig: NewConfig(string(dCfg), params), + id: cfg.Model, + name: cfg.Model, + logo: logo, + accessConfiguration: "", + IConfig: NewConfig(string(dCfg), params), }, nil } diff --git a/ai-provider/model-runtime/provider.go b/ai-provider/model-runtime/provider.go index 666fc11f..dde582ea 100644 --- a/ai-provider/model-runtime/provider.go +++ b/ai-provider/model-runtime/provider.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/url" + "strings" yaml "gopkg.in/yaml.v3" @@ -17,6 +18,10 @@ const ( type IProvider interface { IProviderInfo + GetModelConfig() ModelConfig + SetModelsByType(modelType string, models []IModel) + SetModel(id string, model IModel) + SetDefaultModel(modelType string, model IModel) GetModel(name string) (IModel, bool) Models() []IModel ModelsByType(modelType string) ([]IModel, bool) @@ -41,6 +46,58 @@ type IProviderInfo interface { URI() IProviderURI } +func GetCustomizeLogo() string { + logo, _ := providerDir.ReadFile("customize/assets/icon_s_en.svg") + + return string(logo) +} + +func NewCustomizeProvider(id string, name string, models []IModel, defaultModel string, config string) (IProvider, error) { + var providerCfg CustomizeProviderConfig + if strings.TrimSpace(config) != "" { + err := json.Unmarshal([]byte(config), &providerCfg) + if err != nil { + return nil, err + } + } + uri, err := newProviderUri(providerCfg.ApiEndpointUrl) + if err != nil { + return nil, err + } + + provider := &Provider{ + id: id, + name: name, + logo: GetCustomizeLogo(), + helpUrl: "", + models: eosc.BuildUntyped[string, IModel](), + defaultModels: eosc.BuildUntyped[string, IModel](), + modelsByType: eosc.BuildUntyped[string, []IModel](), + maskKeys: make([]string, 0), + recommend: false, + sort: 0, + uri: uri, + modelConfig: ModelConfig{ + AccessConfigurationStatus: false, + AccessConfigurationDemo: "", + }, + } + provider.IConfig = NewConfig("", nil) + + for _, model := range models { + provider.SetModel(model.ID(), model) + if defaultModel == "" { + defaultModel = model.ID() + } + if model.ID() == defaultModel { + provider.SetDefaultModel(name, model) + } + } + provider.SetModelsByType(ModelTypeLLM, models) + + return provider, nil +} + func NewProvider(providerData string, modelContents map[string]eosc.Untyped[string, string]) (IProvider, error) { var providerCfg entity.Provider err := yaml.Unmarshal([]byte(providerData), &providerCfg) @@ -77,6 +134,10 @@ func NewProvider(providerData string, modelContents map[string]eosc.Untyped[stri recommend: providerCfg.Recommend, sort: providerCfg.Sort, uri: uri, + modelConfig: ModelConfig{ + AccessConfigurationStatus: providerCfg.ModelConfig.AccessConfigurationStatus, + AccessConfigurationDemo: providerCfg.ModelConfig.AccessConfigurationDemo, + }, } defaultCfg := make(map[string]string) params := make(ParamValidator, 0, len(providerCfg.ProviderCredentialSchema.CredentialFormSchemas)) @@ -132,9 +193,19 @@ type Provider struct { uri IProviderURI sort int recommend bool + modelConfig ModelConfig IConfig } +type ModelConfig struct { + AccessConfigurationStatus bool + AccessConfigurationDemo string +} + +func (p *Provider) GetModelConfig() ModelConfig { + return p.modelConfig +} + func (p *Provider) Sort() int { return p.sort } @@ -202,6 +273,10 @@ func (p *Provider) SetModel(id string, model IModel) { p.models.Set(id, model) } +func (p *Provider) RemoveModel(id string) { + p.models.Del(id) +} + func (p *Provider) SetModelsByType(modelType string, models []IModel) { p.modelsByType.Set(modelType, models) } diff --git a/common/common.go b/common/common.go new file mode 100644 index 00000000..dfc001d8 --- /dev/null +++ b/common/common.go @@ -0,0 +1,36 @@ +package common + +import ( + "encoding/json" + "strings" +) + +func MergeJSON(json1, json2 string) string { + var data1, data2 map[string]interface{} + if strings.TrimSpace(json1) != "" { + if err := json.Unmarshal([]byte(json1), &data1); err != nil { + return "" + } + } + if strings.TrimSpace(json2) != "" { + if err := json.Unmarshal([]byte(json2), &data2); err != nil { + return "" + } + } + + merged := make(map[string]interface{}) + // copy data1 + for k, v := range data1 { + merged[k] = v + } + // merge data2 & cover same key + for k, v := range data2 { + merged[k] = v + } + // transfer to json string + result, err := json.Marshal(merged) + if err != nil { + return "" + } + return string(result) +} diff --git a/controller/ai-model/controller.go b/controller/ai-model/controller.go new file mode 100644 index 00000000..84cefe5e --- /dev/null +++ b/controller/ai-model/controller.go @@ -0,0 +1,20 @@ +package ai_model + +import ( + model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto" + "github.com/eolinker/go-common/autowire" + "github.com/gin-gonic/gin" + "reflect" +) + +type IProviderModelController interface { + AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) + UpdateProviderModel(ctx *gin.Context, provider string, input *model_dto.EditModel) error + DeleteProviderModel(ctx *gin.Context, provider string, id string) error +} + +func init() { + autowire.Auto[IProviderModelController](func() reflect.Value { + return reflect.ValueOf(&imlProviderModelController{}) + }) +} diff --git a/controller/ai-model/iml.go b/controller/ai-model/iml.go new file mode 100644 index 00000000..04d7341a --- /dev/null +++ b/controller/ai-model/iml.go @@ -0,0 +1,67 @@ +package ai_model + +import ( + "encoding/json" + "fmt" + ai_model "github.com/APIParkLab/APIPark/module/ai-model" + model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto" + "github.com/gin-gonic/gin" + "strings" +) + +var ( + _ IProviderModelController = (*imlProviderModelController)(nil) +) + +type imlProviderModelController struct { + module ai_model.IProviderModelModule `autowired:""` +} + +func (i *imlProviderModelController) UpdateProviderModel(ctx *gin.Context, provider string, input *model_dto.EditModel) error { + if strings.TrimSpace(input.Name) == "" { + return fmt.Errorf("name is empty") + } + if strings.TrimSpace(input.Id) == "" { + return fmt.Errorf("id is empty") + } + if strings.TrimSpace(provider) == "" { + return fmt.Errorf("provider is empty") + } + // check access config & model parameter is json format + if strings.TrimSpace(input.AccessConfiguration) != "" && !json.Valid([]byte(input.AccessConfiguration)) { + return fmt.Errorf("access configuration is not json format") + } + if strings.TrimSpace(input.ModelParameters) != "" && !json.Valid([]byte(input.ModelParameters)) { + return fmt.Errorf("model parameters is not json format") + } + + return i.module.UpdateProviderModel(ctx, provider, input) +} + +func (i *imlProviderModelController) DeleteProviderModel(ctx *gin.Context, provider string, id string) error { + if strings.TrimSpace(id) == "" { + return fmt.Errorf("id is empty") + } + if strings.TrimSpace(provider) == "" { + return fmt.Errorf("provider is empty") + } + + return i.module.DeleteProviderModel(ctx, provider, id) +} + +func (i *imlProviderModelController) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) { + if strings.TrimSpace(input.Name) == "" { + return nil, fmt.Errorf("name is empty") + } + if strings.TrimSpace(provider) == "" { + return nil, fmt.Errorf("provider is empty") + } + // check access config & model parameter is json format + if strings.TrimSpace(input.AccessConfiguration) != "" && !json.Valid([]byte(input.AccessConfiguration)) { + return nil, fmt.Errorf("access configuration is not json format") + } + if strings.TrimSpace(input.ModelParameters) != "" && !json.Valid([]byte(input.ModelParameters)) { + return nil, fmt.Errorf("model parameters is not json format") + } + return i.module.AddProviderModel(ctx, provider, input) +} diff --git a/controller/ai/controller.go b/controller/ai/controller.go index e4b2a5a5..f6eb5689 100644 --- a/controller/ai/controller.go +++ b/controller/ai/controller.go @@ -22,6 +22,7 @@ type IProviderController interface { UpdateProviderDefaultLLM(ctx *gin.Context, id string, input *ai_dto.UpdateLLM) error Delete(ctx *gin.Context, id string) error //Sort(ctx *gin.Context, input *ai_dto.Sort) error + AddProvider(ctx *gin.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error) } type IStatisticController interface { diff --git a/controller/ai/iml.go b/controller/ai/iml.go index c7eee3b9..ae37239b 100644 --- a/controller/ai/iml.go +++ b/controller/ai/iml.go @@ -2,7 +2,9 @@ package ai import ( "encoding/json" + "fmt" "strconv" + "strings" "github.com/APIParkLab/APIPark/module/ai" ai_dto "github.com/APIParkLab/APIPark/module/ai/dto" @@ -21,6 +23,13 @@ func (i *imlProviderController) Delete(ctx *gin.Context, id string) error { return i.module.Delete(ctx, id) } +func (i *imlProviderController) AddProvider(ctx *gin.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error) { + if strings.TrimSpace(input.Name) == "" { + return nil, fmt.Errorf("name is empty") + } + return i.module.AddProvider(ctx, input) +} + //func (i *imlProviderController) Sort(ctx *gin.Context, input *ai_dto.Sort) error { // return i.module.Sort(ctx, input) //} @@ -67,6 +76,9 @@ func (i *imlProviderController) Disable(ctx *gin.Context, id string) error { } func (i *imlProviderController) UpdateProviderConfig(ctx *gin.Context, id string, input *ai_dto.UpdateConfig) error { + if strings.TrimSpace(id) == "" { + return fmt.Errorf("id is empty") + } return i.module.UpdateProviderConfig(ctx, id, input) } diff --git a/go.sum b/go.sum index 0322d23f..8e7dcdc0 100644 --- a/go.sum +++ b/go.sum @@ -40,7 +40,6 @@ github.com/getkin/kin-openapi v0.127.0 h1:Mghqi3Dhryf3F8vR370nN67pAERW+3a95vomb3 github.com/getkin/kin-openapi v0.127.0/go.mod h1:OZrfXzUfGrNbsKj+xmFBx6E5c6yH3At/tAKSc2UszXM= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/gin-contrib/gzip v1.0.1 h1:HQ8ENHODeLY7a4g1Au/46Z92bdGFl74OhxcZble9WJE= github.com/gin-contrib/gzip v1.0.1/go.mod h1:njt428fdUNRvjuJf16tZMYZ2Yl+WQB53X5wmhDwXvC4= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= diff --git a/module/ai-model/dto/input.go b/module/ai-model/dto/input.go new file mode 100644 index 00000000..97f554e8 --- /dev/null +++ b/module/ai-model/dto/input.go @@ -0,0 +1,12 @@ +package model_dto + +type Model struct { + Name string `json:"name"` + AccessConfiguration string `json:"access_configuration"` + ModelParameters string `json:"model_parameters"` +} + +type EditModel struct { + Id string `json:"id"` + Model +} diff --git a/module/ai-model/dto/output.go b/module/ai-model/dto/output.go new file mode 100644 index 00000000..936845fc --- /dev/null +++ b/module/ai-model/dto/output.go @@ -0,0 +1,6 @@ +package model_dto + +type SimpleModel struct { + Id string `json:"id"` + Name string `json:"name"` +} diff --git a/module/ai-model/iml.go b/module/ai-model/iml.go new file mode 100644 index 00000000..2196b36e --- /dev/null +++ b/module/ai-model/iml.go @@ -0,0 +1,111 @@ +package ai_model + +import ( + "fmt" + model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime" + model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto" + "github.com/APIParkLab/APIPark/service/ai" + ai_model "github.com/APIParkLab/APIPark/service/ai-model" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + + "github.com/eolinker/go-common/store" +) + +var ( + _ IProviderModelModule = (*imlProviderModelModule)(nil) +) + +type imlProviderModelModule struct { + providerService ai.IProviderService `autowired:""` + providerModelService ai_model.IProviderModelService `autowired:""` + transaction store.ITransaction `autowired:""` +} + +func (i imlProviderModelModule) UpdateProviderModel(ctx *gin.Context, provider string, input *model_dto.EditModel) error { + p, has := model_runtime.GetProvider(provider) + if !has { + return fmt.Errorf("ai provider not found") + } + // check provider exist + providerInfo, err := i.providerService.Get(ctx, provider) + if err != nil { + return err + } + if providerInfo == nil { + return fmt.Errorf("provider not found") + } + modelInfo, _ := i.providerModelService.Get(ctx, input.Id) + if modelInfo == nil || modelInfo.Provider != provider { + return fmt.Errorf("model not found") + } + // check model name duplicate + if has := i.providerModelService.CheckNameDuplicate(ctx, provider, input.Name, input.Id); has { + return fmt.Errorf("model name: `%s` duplicate", input.Name) + } + if err := i.providerModelService.Save(ctx, input.Id, &ai_model.Model{ + Name: &input.Name, + AccessConfiguration: &input.AccessConfiguration, + ModelParameters: &input.ModelParameters, + }); err != nil { + return err + } + // update provider model + iModel, _ := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters) + p.SetModel(input.Id, iModel) + + return nil +} + +func (i imlProviderModelModule) DeleteProviderModel(ctx *gin.Context, provider string, id string) error { + // check provider exist + providerInfo, err := i.providerService.Get(ctx, provider) + if err != nil { + return err + } + if providerInfo == nil { + return fmt.Errorf("provider not found") + } + modelInfo, _ := i.providerModelService.Get(ctx, id) + if modelInfo == nil || modelInfo.Provider != provider { + return fmt.Errorf("model not found") + } + return i.providerModelService.Delete(ctx, id) +} + +func (i imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) { + p, has := model_runtime.GetProvider(provider) + if !has { + return nil, fmt.Errorf("ai provider not found") + } + // check provider exist + providerInfo, err := i.providerService.Get(ctx, provider) + if err != nil { + return nil, err + } + if providerInfo == nil { + return nil, fmt.Errorf("provider not found") + } + // check model name duplicate + if has := i.providerModelService.CheckNameDuplicate(ctx, provider, input.Name, ""); has { + return nil, fmt.Errorf("model name: `%s` duplicate", input.Name) + } + id := uuid.New().String() + typeValue := "chat" + if err := i.providerModelService.Save(ctx, id, &ai_model.Model{ + Name: &input.Name, + Type: &typeValue, + Provider: &provider, + AccessConfiguration: &input.AccessConfiguration, + ModelParameters: &input.ModelParameters, + }); err != nil { + return nil, err + } + // update provider model + iModel, _ := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters) + p.SetModel(id, iModel) + return &model_dto.SimpleModel{ + Id: id, + Name: input.Name, + }, nil +} diff --git a/module/ai-model/module.go b/module/ai-model/module.go new file mode 100644 index 00000000..6f9f22d7 --- /dev/null +++ b/module/ai-model/module.go @@ -0,0 +1,22 @@ +package ai_model + +import ( + model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto" + "github.com/gin-gonic/gin" + "reflect" + + "github.com/eolinker/go-common/autowire" +) + +type IProviderModelModule interface { + AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) + UpdateProviderModel(ctx *gin.Context, provider string, input *model_dto.EditModel) error + DeleteProviderModel(ctx *gin.Context, provider string, id string) error +} + +func init() { + autowire.Auto[IProviderModelModule](func() reflect.Value { + module := new(imlProviderModelModule) + return reflect.ValueOf(module) + }) +} diff --git a/module/ai/dto/input.go b/module/ai/dto/input.go index 93a112f8..adeb7286 100644 --- a/module/ai/dto/input.go +++ b/module/ai/dto/input.go @@ -14,3 +14,7 @@ type UpdateConfig struct { type Sort struct { Providers []string `json:"providers"` } + +type NewProvider struct { + Name string `json:"name"` +} diff --git a/module/ai/dto/output.go b/module/ai/dto/output.go index 27d07eae..263b09ad 100644 --- a/module/ai/dto/output.go +++ b/module/ai/dto/output.go @@ -4,6 +4,11 @@ import ( "github.com/eolinker/go-common/auto" ) +type SimpleModel struct { + Id string `json:"id"` + Name string `json:"name"` +} + type SimpleProvider struct { Id string `json:"id"` Name string `json:"name"` @@ -20,8 +25,14 @@ type Provider struct { DefaultLLM string `json:"default_llm"` DefaultLLMConfig string `json:"-"` //Priority int `json:"priority"` - Status ProviderStatus `json:"status"` - Configured bool `json:"configured"` + Status ProviderStatus `json:"status"` + Configured bool `json:"configured"` + ModelConfig ModelConfig `json:"model_config"` +} + +type ModelConfig struct { + AccessConfigurationStatus bool `json:"access_configuration_status"` + AccessConfigurationDemo string `json:"access_configuration_demo"` } type ConfiguredProviderItem struct { @@ -32,6 +43,7 @@ type ConfiguredProviderItem struct { Status ProviderStatus `json:"status"` APICount int64 `json:"api_count"` KeyCount int64 `json:"key_count"` + ModelCount int64 `json:"model_count"` CanDelete bool `json:"can_delete"` } @@ -48,6 +60,7 @@ type ProviderItem struct { Logo string `json:"logo"` DefaultLLM string `json:"default_llm"` Sort int `json:"-"` + Type int `json:"type"` // 0:default 1:customize } type SimpleProviderItem struct { @@ -69,10 +82,15 @@ type BackupProvider struct { } type LLMItem struct { - Id string `json:"id"` - Logo string `json:"logo"` - Config string `json:"config"` - Scopes []string `json:"scopes"` + Id string `json:"id"` + Logo string `json:"logo"` + Config string `json:"config"` + AccessConfiguration string `json:"access_configuration"` + ModelParameters string `json:"model_parameters"` + Scopes []string `json:"scopes"` + Type string `json:"type"` + IsSystem bool `json:"is_system"` + ApiCount int64 `json:"api_count"` } type APIItem struct { diff --git a/module/ai/iml.go b/module/ai/iml.go index d4c06d89..658cf0f0 100644 --- a/module/ai/iml.go +++ b/module/ai/iml.go @@ -4,12 +4,12 @@ import ( "context" "errors" "fmt" + ai_model "github.com/APIParkLab/APIPark/service/ai-model" + "github.com/google/uuid" "net/http" "sort" "time" - "github.com/google/uuid" - ai_provider_local "github.com/APIParkLab/APIPark/ai-provider/local" "github.com/eolinker/go-common/register" @@ -63,12 +63,13 @@ func newKey(key *ai_key.Key) *gateway.DynamicRelease { var _ IProviderModule = (*imlProviderModule)(nil) type imlProviderModule struct { - providerService ai.IProviderService `autowired:""` - clusterService cluster.IClusterService `autowired:""` - aiAPIService ai_api.IAPIService `autowired:""` - aiKeyService ai_key.IKeyService `autowired:""` - aiBalanceService ai_balance.IBalanceService `autowired:""` - transaction store.ITransaction `autowired:""` + providerService ai.IProviderService `autowired:""` + providerModelService ai_model.IProviderModelService `autowired:""` + clusterService cluster.IClusterService `autowired:""` + aiAPIService ai_api.IAPIService `autowired:""` + aiKeyService ai_key.IKeyService `autowired:""` + aiBalanceService ai_balance.IBalanceService `autowired:""` + transaction store.ITransaction `autowired:""` } func (i *imlProviderModule) OnInit() { @@ -79,6 +80,33 @@ func (i *imlProviderModule) OnInit() { if err != nil { return } + // register provider + for _, p := range list { + // get customize models + models, _ := i.providerModelService.Search(ctx, "", map[string]interface{}{"provider": p.Id}, "update_at desc") + iModels := make([]model_runtime.IModel, 0, len(models)) + if models != nil { + for _, model := range models { + // parse access_config & model_parameters + iModel, _ := model_runtime.NewCustomizeModel(model.Id, model.Name, model_runtime.GetCustomizeLogo(), model.AccessConfiguration, model.ModelParameters) + iModels = append(iModels, iModel) + } + } + // default provider + if p.Type == 0 { + runtimeProvider, _ := model_runtime.GetProvider(p.Id) + for _, tmpIModel := range iModels { + tmpIModel.SetLogo(runtimeProvider.Logo()) + if p.DefaultLLM == tmpIModel.ID() { + runtimeProvider.SetDefaultModel(model_runtime.ModelTypeLLM, tmpIModel) + } + runtimeProvider.SetModel(tmpIModel.ID(), tmpIModel) + } + } else { + provider, _ := model_runtime.NewCustomizeProvider(p.Id, p.Name, iModels, p.DefaultLLM, p.Config) + model_runtime.Register(p.Id, provider) + } + } i.transaction.Transaction(ctx, func(ctx context.Context) error { for _, l := range list { if l.Priority < 1 { @@ -145,6 +173,8 @@ func (i *imlProviderModule) Delete(ctx context.Context, id string) error { if err != nil { return err } + // delete register provider + model_runtime.Remove(id) releases := make([]*gateway.DynamicRelease, 0, len(keys)) for _, key := range keys { releases = append(releases, newKey(key)) @@ -164,6 +194,30 @@ func (i *imlProviderModule) Delete(ctx context.Context, id string) error { }) } +func (i *imlProviderModule) AddProvider(ctx context.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error) { + if has := i.providerService.CheckNameDuplicate(ctx, input.Name); has { + return nil, fmt.Errorf("provider `%s` duplicate", input.Name) + } + id := uuid.New().String() + config, defaultLLM := "{\"api_endpoint_url\": \"http://127.0.0.1\", \"api_key\": \"\"}", "" + typeValue := 1 + if err := i.providerService.Save(ctx, id, &ai.SetProvider{ + Name: &input.Name, + DefaultLLM: &defaultLLM, + Config: &config, + Type: &typeValue, + }); err != nil { + return nil, err + } + // register provider + iProvider, _ := model_runtime.NewCustomizeProvider(id, input.Name, []model_runtime.IModel{}, "", "") + model_runtime.Register(id, iProvider) + return &ai_dto.SimpleProvider{ + Id: id, + Name: input.Name, + }, nil +} + func (i *imlProviderModule) SimpleProvider(ctx context.Context, id string) (*ai_dto.SimpleProvider, error) { p, has := model_runtime.GetProvider(id) if !has { @@ -231,6 +285,7 @@ func (i *imlProviderModule) ConfiguredProviders(ctx context.Context, keyword str APICount: apiCount, KeyCount: keyMap[l.Id], CanDelete: apiCount < 1, + ModelCount: int64(len(p.Models())), }) } @@ -394,8 +449,9 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr } defaultLLM, has := p.DefaultModel(model_runtime.ModelTypeLLM) if !has { - return nil, fmt.Errorf("ai provider llm not found") + defaultLLM, _ = model_runtime.NewCustomizeModel("", "", "", "", "") } + providerModelConfig := p.GetModelConfig() return &ai_dto.Provider{ Id: p.ID(), Name: p.Name(), @@ -405,13 +461,17 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr DefaultLLMConfig: defaultLLM.Logo(), Status: ai_dto.ProviderDisabled, //Priority: maxPriority, + ModelConfig: ai_dto.ModelConfig{ + AccessConfigurationStatus: providerModelConfig.AccessConfigurationStatus, + AccessConfigurationDemo: providerModelConfig.AccessConfigurationDemo, + }, }, nil } defaultLLM, has := p.GetModel(info.DefaultLLM) if !has { model, has := p.DefaultModel(model_runtime.ModelTypeLLM) if !has { - return nil, fmt.Errorf("ai provider llm not found") + defaultLLM, _ = model_runtime.NewCustomizeModel("", "", "", "", "") } defaultLLM = model } @@ -426,6 +486,10 @@ func (i *imlProviderModule) Provider(ctx context.Context, id string) (*ai_dto.Pr //Priority: info.Priority, Status: ai_dto.ToProviderStatus(info.Status), Configured: true, + ModelConfig: ai_dto.ModelConfig{ + AccessConfigurationStatus: false, + AccessConfigurationDemo: "", + }, }, nil } @@ -439,16 +503,21 @@ func (i *imlProviderModule) LLMs(ctx context.Context, driver string) ([]*ai_dto. if !has { return nil, nil, fmt.Errorf("ai provider not found") } - + modelApiCountMap, _ := i.aiAPIService.CountMapByModel(ctx, "", map[string]interface{}{"provider": driver}) items := make([]*ai_dto.LLMItem, 0, len(llms)) for _, v := range llms { items = append(items, &ai_dto.LLMItem{ - Id: v.ID(), - Logo: v.Logo(), - Config: v.DefaultConfig(), + Id: v.ID(), + Logo: v.Logo(), + Config: v.DefaultConfig(), + AccessConfiguration: v.AccessConfiguration(), + ModelParameters: v.ModelParameters(), Scopes: []string{ "chat", }, + Type: "chat", + IsSystem: v.Source() != "customize", + ApiCount: modelApiCountMap[v.ID()], }) } info, err := i.providerService.Get(ctx, driver) @@ -523,6 +592,11 @@ func (i *imlProviderModule) UpdateProviderConfig(ctx context.Context, id string, if err != nil { return err } + if input.DefaultLLM != "" { + if defaultLLM, has := p.GetModel(input.DefaultLLM); has { + p.SetDefaultModel(model_runtime.ModelTypeLLM, defaultLLM) + } + } status := 0 if input.Enable != nil && *input.Enable { status = 1 diff --git a/module/ai/module.go b/module/ai/module.go index 650de0ff..f036b048 100644 --- a/module/ai/module.go +++ b/module/ai/module.go @@ -19,6 +19,7 @@ type IProviderModule interface { LLMs(ctx context.Context, driver string) ([]*ai_dto.LLMItem, *ai_dto.ProviderItem, error) UpdateProviderConfig(ctx context.Context, id string, input *ai_dto.UpdateConfig) error Delete(ctx context.Context, id string) error + AddProvider(ctx context.Context, input *ai_dto.NewProvider) (*ai_dto.SimpleProvider, error) } type IAIAPIModule interface { diff --git a/plugins/core/ai.go b/plugins/core/ai.go index 6bf7af69..2cc76219 100644 --- a/plugins/core/ai.go +++ b/plugins/core/ai.go @@ -19,6 +19,10 @@ func (p *plugin) aiAPIs() []pm3.Api { pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/ai/provider/llms", []string{"context", "query:provider"}, []string{"llms", "provider"}, p.aiProviderController.LLMs), pm3.CreateApiWidthDoc(http.MethodDelete, "/api/v1/ai/provider", []string{"context", "query:provider"}, nil, p.aiProviderController.Delete), pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/ai/provider/config", []string{"context", "query:provider", "body"}, nil, p.aiProviderController.UpdateProviderConfig, access.SystemSettingsAiProviderManager), + pm3.CreateApiWidthDoc(http.MethodPost, "/api/v1/ai/provider", []string{"context", "body"}, []string{"provider"}, p.aiProviderController.AddProvider, access.SystemSettingsAiProviderManager), + pm3.CreateApiWidthDoc(http.MethodPost, "/api/v1/ai/provider/model", []string{"context", "query:provider", "body"}, []string{"model"}, p.aiProviderModelController.AddProviderModel, access.SystemSettingsAiProviderManager), + pm3.CreateApiWidthDoc(http.MethodPut, "/api/v1/ai/provider/model", []string{"context", "query:provider", "body"}, nil, p.aiProviderModelController.UpdateProviderModel, access.SystemSettingsAiProviderManager), + pm3.CreateApiWidthDoc(http.MethodDelete, "/api/v1/ai/provider/model", []string{"context", "query:provider", "query:id"}, nil, p.aiProviderModelController.DeleteProviderModel, access.SystemSettingsAiProviderManager), pm3.CreateApiWidthDoc(http.MethodGet, "/api/v1/ai/apis", []string{"context", "query:keyword", "query:provider", "query:start", "query:end", "query:page", "query:page_size", "query:sort", "query:asc", "query:models", "query:services"}, []string{"apis", "condition", "total"}, p.aiStatisticController.APIs), } diff --git a/plugins/core/core.go b/plugins/core/core.go index be4e3f00..b3afb4b1 100644 --- a/plugins/core/core.go +++ b/plugins/core/core.go @@ -1,6 +1,7 @@ package core import ( + ai_model "github.com/APIParkLab/APIPark/controller/ai-model" "net/http" ai_balance "github.com/APIParkLab/APIPark/controller/ai-balance" @@ -97,6 +98,7 @@ type plugin struct { exportConfigController system.IExportConfigController `autowired:""` importConfigController system.IImportConfigController `autowired:""` aiProviderController ai.IProviderController `autowired:""` + aiProviderModelController ai_model.IProviderModelController `autowired:""` settingController system.ISettingController `autowired:""` initController system.IInitController `autowired:""` logController log.ILogController `autowired:""` diff --git a/service/ai-model/iml.go b/service/ai-model/iml.go new file mode 100644 index 00000000..7a594556 --- /dev/null +++ b/service/ai-model/iml.go @@ -0,0 +1,84 @@ +package ai_model + +import ( + "context" + "errors" + "github.com/APIParkLab/APIPark/service/universally" + "github.com/APIParkLab/APIPark/stores/ai" + "github.com/eolinker/go-common/utils" + "gorm.io/gorm" + "time" +) + +var _ IProviderModelService = (*imlProviderModelService)(nil) + +type imlProviderModelService struct { + universally.IServiceGet[ProviderModel] + universally.IServiceCreate[ProviderModel] + universally.IServiceDelete + store ai.IProviderModelStore `autowired:""` +} + +func (i *imlProviderModelService) CountMapByProvider(ctx context.Context, conditions map[string]interface{}) (map[string]int64, error) { + return i.store.CountByGroup(ctx, "", conditions, "provider") +} + +func (i *imlProviderModelService) Save(ctx context.Context, id string, model *Model) error { + userId := utils.UserId(ctx) + now := time.Now() + info, err := i.store.First(ctx, map[string]interface{}{"uuid": id}) + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if model.Name == nil || model.Provider == nil { + return errors.New("invalid params") + } + info = &ai.ProviderModel{ + UUID: id, + Name: *model.Name, + Type: *model.Type, + AccessConfiguration: *model.AccessConfiguration, + ModelParameters: *model.ModelParameters, + Provider: *model.Provider, + Creator: userId, + Updater: userId, + CreateAt: now, + UpdateAt: now, + } + } else { + if model.Name != nil { + info.Name = *model.Name + } + if model.Type != nil { + info.Type = *model.Type + } + if model.Provider != nil { + info.Provider = *model.Provider + } + if model.AccessConfiguration != nil { + info.AccessConfiguration = *model.AccessConfiguration + } + if model.ModelParameters != nil { + info.AccessConfiguration = *model.ModelParameters + } + info.Updater = userId + info.UpdateAt = now + } + return i.store.Save(ctx, info) +} + +func (i *imlProviderModelService) CheckNameDuplicate(ctx context.Context, provider string, name string, excludeId string) bool { + v, _ := i.store.First(ctx, map[string]interface{}{"provider": provider, "name": name}) + if v != nil { + return true + } else if excludeId != "" && v.UUID != excludeId { + return true + } + return false +} + +func (i *imlProviderModelService) OnComplete() { + i.IServiceGet = universally.NewGet[ProviderModel, ai.ProviderModel](i.store, FromEntity) + i.IServiceDelete = universally.NewDelete[ai.ProviderModel](i.store) +} diff --git a/service/ai-model/model.go b/service/ai-model/model.go new file mode 100644 index 00000000..96a5df41 --- /dev/null +++ b/service/ai-model/model.go @@ -0,0 +1,51 @@ +package ai_model + +import ( + "encoding/base64" + "github.com/APIParkLab/APIPark/stores/ai" + "time" +) + +type ProviderModel struct { + Id string // provider model:uuid + Name string + Type string + AccessConfiguration string + ModelParameters string + Provider string + Creator string + Updater string + CreateAt time.Time + UpdateAt time.Time +} + +func FromEntity(e *ai.ProviderModel) *ProviderModel { + accessConfiguration, err := base64.RawStdEncoding.DecodeString(e.AccessConfiguration) + modelParameters, err := base64.RawStdEncoding.DecodeString(e.ModelParameters) + if err != nil { + accessConfiguration = []byte(e.AccessConfiguration) + } + if err != nil { + modelParameters = []byte(e.ModelParameters) + } + return &ProviderModel{ + Id: e.UUID, + Name: e.Name, + Type: e.Type, + AccessConfiguration: string(accessConfiguration), + ModelParameters: string(modelParameters), + Provider: e.Provider, + Creator: e.Creator, + Updater: e.Updater, + CreateAt: e.CreateAt, + UpdateAt: e.UpdateAt, + } +} + +type Model struct { + Name *string + Provider *string + Type *string + AccessConfiguration *string + ModelParameters *string +} diff --git a/service/ai-model/service.go b/service/ai-model/service.go new file mode 100644 index 00000000..78c58cf4 --- /dev/null +++ b/service/ai-model/service.go @@ -0,0 +1,22 @@ +package ai_model + +import ( + "context" + "github.com/APIParkLab/APIPark/service/universally" + "github.com/eolinker/go-common/autowire" + "reflect" +) + +type IProviderModelService interface { + universally.IServiceGet[ProviderModel] + universally.IServiceDelete + CountMapByProvider(ctx context.Context, conditions map[string]interface{}) (map[string]int64, error) + Save(ctx context.Context, id string, cfg *Model) error + CheckNameDuplicate(ctx context.Context, provider string, name string, excludeId string) bool +} + +func init() { + autowire.Auto[IProviderModelService](func() reflect.Value { + return reflect.ValueOf(new(imlProviderModelService)) + }) +} diff --git a/service/ai/iml.go b/service/ai/iml.go index 6eda08a9..7e62c2d3 100644 --- a/service/ai/iml.go +++ b/service/ai/iml.go @@ -70,6 +70,14 @@ type imlProviderService struct { // return i.store.Save(ctx, info) //} +func (i *imlProviderService) CheckNameDuplicate(ctx context.Context, name string) bool { + v, _ := i.store.First(ctx, map[string]interface{}{"name": name}) + if v != nil { + return true + } + return false +} + func (i *imlProviderService) GetLabels(ctx context.Context, ids ...string) map[string]string { if len(ids) == 0 { return nil diff --git a/service/ai/model.go b/service/ai/model.go index 263a79fd..a10fe7a5 100644 --- a/service/ai/model.go +++ b/service/ai/model.go @@ -15,6 +15,7 @@ type Provider struct { Creator string Updater string Status int + Type int Priority int CreateAt time.Time UpdateAt time.Time @@ -34,6 +35,7 @@ type SetProvider struct { Config *string Priority *int Status *int + Type *int } func FromEntity(e *ai.Provider) *Provider { @@ -52,5 +54,6 @@ func FromEntity(e *ai.Provider) *Provider { UpdateAt: e.UpdateAt, Status: e.Status, Priority: e.Priority, + Type: e.Type, } } diff --git a/service/ai/service.go b/service/ai/service.go index c4201db6..c7eb0071 100644 --- a/service/ai/service.go +++ b/service/ai/service.go @@ -1,6 +1,7 @@ package ai import ( + "context" "reflect" "github.com/APIParkLab/APIPark/service/universally" @@ -14,6 +15,7 @@ type IProviderService interface { universally.IServiceDelete //Save(ctx context.Context, id string, cfg *SetProvider) error //MaxPriority(ctx context.Context) (int, error) + CheckNameDuplicate(ctx context.Context, name string) bool } func init() { diff --git a/stores/ai/model.go b/stores/ai/model.go index 6d607d16..0eb2788b 100644 --- a/stores/ai/model.go +++ b/stores/ai/model.go @@ -10,6 +10,7 @@ type Provider struct { Config string `gorm:"type:text;not null;column:config;comment:配置信息"` Status int `gorm:"type:tinyint(1);not null;column:status;comment:状态,0:停用;1:启用,2:异常;default:1"` Priority int `gorm:"type:int;not null;column:priority;comment:优先级,值越小优先级越高"` + Type int `gorm:"type:tinyint(1);not null;column:type;comment:type 0:default 1:customize"` Creator string `gorm:"size:36;not null;column:creator;comment:创建人;index:creator" aovalue:"creator"` // 创建人 Updater string `gorm:"size:36;not null;column:updater;comment:更新人;index:updater" aovalue:"updater"` // 更新人 CreateAt time.Time `gorm:"type:timestamp;NOT NULL;DEFAULT:CURRENT_TIMESTAMP;column:create_at;comment:创建时间"` @@ -164,3 +165,26 @@ func (i *LocalModelCache) TableName() string { func (i *LocalModelCache) IdValue() int64 { return i.Id } + +type ProviderModel struct { + Id int64 `gorm:"column:id;type:BIGINT(20);AUTO_INCREMENT;NOT NULL;comment:id;primary_key;comment:PRIMARY ID;"` + UUID string `gorm:"type:varchar(36);not null;column:uuid;uniqueIndex:uuid;comment:UUID;"` + Name string `gorm:"type:varchar(100);not null;column:name;comment:name;index:name"` + Type string `gorm:"type:varchar(100);not null;column:type;comment:type:chat"` + AccessConfiguration string `gorm:"type:text;not null;column:access_configuration;comment:access_configuration json"` + ModelParameters string `gorm:"type:text;not null;column:model_parameters;comment:model_parameters json"` + Provider string `gorm:"type:varchar(36);not null;column:provider;comment:ai_provider:uuid;index:provider"` + Creator string `gorm:"size:36;not null;column:creator;comment:creator;index:creator" aovalue:"creator"` + Updater string `gorm:"size:36;not null;column:updater;comment:updater;index:updater" aovalue:"updater"` + CreateAt time.Time `gorm:"type:timestamp;NOT NULL;DEFAULT:CURRENT_TIMESTAMP;column:create_at;comment:create_at"` + UpdateAt time.Time `gorm:"type:timestamp;NOT NULL;DEFAULT:CURRENT_TIMESTAMP;column:update_at;comment:update_at"` +} + +func (i *ProviderModel) TableName() string { + return "ai_provider_model" +} + +func (i *ProviderModel) IdValue() int64 { + return i.Id +} + diff --git a/stores/ai/store.go b/stores/ai/store.go index 8c462f42..3ab6a36e 100644 --- a/stores/ai/store.go +++ b/stores/ai/store.go @@ -71,6 +71,14 @@ type imlLocalModelCacheStore struct { store.Store[LocalModelCache] } +type IProviderModelStore interface { + store.ISearchStore[ProviderModel] +} + +type imlProviderModelStore struct { + store.SearchStore[ProviderModel] +} + func init() { autowire.Auto[IProviderStore](func() reflect.Value { return reflect.ValueOf(new(imlProviderStore)) @@ -103,4 +111,8 @@ func init() { autowire.Auto[ILocalModelCacheStore](func() reflect.Value { return reflect.ValueOf(new(imlLocalModelCacheStore)) }) + + autowire.Auto[IProviderModelStore](func() reflect.Value { + return reflect.ValueOf(new(imlProviderModelStore)) + }) }