mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-04 10:14:00 +08:00
1415 lines
49 KiB
Go
1415 lines
49 KiB
Go
package gaia
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"crypto/aes"
|
||
"crypto/rand"
|
||
"crypto/rsa"
|
||
"crypto/sha1"
|
||
"crypto/x509"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"encoding/pem"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"os"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||
gaiaRequest "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||
gaiaResponse "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
|
||
"github.com/flipped-aurora/gin-vue-admin/server/model/system"
|
||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||
"go.gnd.pw/crypto/eax"
|
||
"go.uber.org/zap"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// fetchAdminToken 查询一个管理员用户,生成 Dify Console API 兼容的 JWT。
|
||
// 结果缓存到 Redis,TTL 50 分钟(JWT 有效期内复用,避免频繁生成)。
|
||
func (s *ModelProviderService) fetchAdminToken() (token string, err error) {
|
||
ctx := context.Background()
|
||
|
||
// 优先从 Redis 读取缓存
|
||
if cached, e := global.GVA_REDIS.Get(ctx, gaia.RedisKeyGaiaAdminConsoleToken).Result(); e == nil && cached != "" {
|
||
return cached, nil
|
||
}
|
||
|
||
// 查询一个活跃管理员
|
||
var adminUser system.SysUser
|
||
if err = global.GVA_DB.Where("authority_id = ? AND enable = ?",
|
||
system.AdminAuthorityId, system.UserActive).First(&adminUser).Error; err != nil {
|
||
return "", fmt.Errorf("找不到可用的管理员账号:%w", err)
|
||
}
|
||
|
||
token, _, _, err = utils.LoginTokenWithCSRF(&adminUser)
|
||
if err != nil {
|
||
return "", fmt.Errorf("生成管理员 token 失败:%w", err)
|
||
}
|
||
|
||
// 缓存 50 分钟(JWT 缓冲时间内有效)
|
||
global.GVA_REDIS.Set(ctx, gaia.RedisKeyGaiaAdminConsoleToken, token, 50*time.Minute)
|
||
return token, nil
|
||
}
|
||
|
||
// fetchModelPricingFromDify 通过 Dify Console API 拉取 LLM 模型定价,结果按 model 名缓存到 Redis(TTL 1 小时)。
|
||
// Dify Console API:GET /console/api/workspaces/current/models/model-types/llm
|
||
// 响应结构:{"data": [{"models": [{"model": "gpt-4o", "fetch_from": "...", "pricing": {"input":"0.005","output":"0.015","unit":"0.001","currency":"USD"}}]}]}
|
||
func (s *ModelProviderService) fetchModelPricingFromDify(modelName string) (*gaia.ModelPricing, error) {
|
||
const redisTTL = time.Hour
|
||
cacheKey := gaia.RedisKeyGaiaModelPricingPrefix + modelName
|
||
ctx := context.Background()
|
||
|
||
// 先查 Redis
|
||
if cached, err := global.GVA_REDIS.Get(ctx, cacheKey).Result(); err == nil && cached != "" {
|
||
var p gaia.ModelPricing
|
||
if json.Unmarshal([]byte(cached), &p) == nil {
|
||
return &p, nil
|
||
}
|
||
}
|
||
|
||
// 获取管理员 token
|
||
token, err := s.fetchAdminToken()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 调用 Dify Console API
|
||
apiURL := strings.TrimSuffix(global.GVA_CONFIG.Gaia.Url, "/") +
|
||
"/console/api/workspaces/current/models/model-types/llm"
|
||
req, err := http.NewRequest(http.MethodGet, apiURL, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("构建定价请求失败:%w", err)
|
||
}
|
||
req.Header.Set("Authorization", "Bearer "+token)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
client := &http.Client{Timeout: 15 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求 Dify 定价接口失败:%w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
respBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取定价响应失败:%w", err)
|
||
}
|
||
if resp.StatusCode != http.StatusOK {
|
||
return nil, fmt.Errorf("dify 定价接口返回 %d:%s", resp.StatusCode, string(respBody))
|
||
}
|
||
|
||
// 解析响应,批量缓存所有模型的定价
|
||
var apiResp gaia.DifyModelsResponse
|
||
if err = json.Unmarshal(respBody, &apiResp); err != nil {
|
||
return nil, fmt.Errorf("解析定价响应失败:%w", err)
|
||
}
|
||
|
||
var targetPricing *gaia.ModelPricing
|
||
for _, providerData := range apiResp.Data {
|
||
for _, m := range providerData.Models {
|
||
if m.Pricing == nil {
|
||
continue
|
||
}
|
||
p := gaia.ModelPricing{Currency: m.Pricing.Currency}
|
||
_, _ = fmt.Sscanf(m.Pricing.Input, "%f", &p.Input)
|
||
_, _ = fmt.Sscanf(m.Pricing.Output, "%f", &p.Output)
|
||
_, _ = fmt.Sscanf(m.Pricing.Unit, "%f", &p.Unit)
|
||
if p.Unit == 0 {
|
||
p.Unit = 0.001 // 默认按千 token 计费
|
||
}
|
||
|
||
// 缓存每个模型的定价
|
||
if b, e := json.Marshal(p); e == nil {
|
||
global.GVA_REDIS.Set(ctx, gaia.RedisKeyGaiaModelPricingPrefix+m.Model, string(b), redisTTL)
|
||
}
|
||
if m.Model == modelName {
|
||
cp := p
|
||
targetPricing = &cp
|
||
}
|
||
}
|
||
}
|
||
|
||
if targetPricing != nil {
|
||
return targetPricing, nil
|
||
}
|
||
// 未找到该模型的定价,写入空标记避免反复请求(TTL 10 分钟)
|
||
global.GVA_REDIS.Set(ctx, cacheKey, "{}", 10*time.Minute)
|
||
return nil, nil
|
||
}
|
||
|
||
// rmbToUSD 将人民币金额按固定汇率换算为 USD。
|
||
func rmbToUSD(rmb float64) float64 {
|
||
return rmb / gaia.RmbToUSDRate
|
||
}
|
||
|
||
// resolvePricing 返回模型定价:优先用从 Dify 拉取的 pricing,
|
||
// 其次查内置兜底定价表(BuiltinModelPricing),最后返回 nil。
|
||
func resolvePricing(pricing *gaia.ModelPricing, modelName string) *gaia.ModelPricing {
|
||
if pricing != nil && pricing.Unit > 0 {
|
||
return pricing
|
||
}
|
||
// 内置定价表精确匹配
|
||
if p, ok := gaia.BuiltinModelPricing[modelName]; ok {
|
||
return &p
|
||
}
|
||
// 前缀模糊匹配(如 "qwen3.5-plus-xxx" 匹配 "qwen3.5-plus")
|
||
lower := strings.ToLower(modelName)
|
||
for k, p := range gaia.BuiltinModelPricing {
|
||
if strings.HasPrefix(lower, strings.ToLower(k)) {
|
||
cp := p
|
||
return &cp
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// calcQuotaDelta 根据定价和 token 用量计算本次消耗的配额金额(统一以 USD 计)。
|
||
// Dify pricing 字段语义:input/output 为每「unit」个 token 的价格,unit 通常为 0.001(千分之一),
|
||
// 即 input=0.0014, unit=0.001 表示每千 token ¥0.0014 × (tokens/1000)。
|
||
// 公式:cost = tokens × input × unit(因为 unit=1/1000,等价于 tokens/1000 × input)。
|
||
// 若货币为 RMB/CNY,则除以汇率 7.26 换算为 USD,与 account_money_extend.used_quota 存储单位保持一致。
|
||
// 若 Dify 未返回定价则查内置兜底表;均未命中时按极小默认值记账,避免多扣。
|
||
func calcQuotaDelta(pricing *gaia.ModelPricing, modelName string, promptTokens, completionTokens int) float64 {
|
||
p := resolvePricing(pricing, modelName)
|
||
if p == nil {
|
||
// 兜底:仅做记账占位,不应大量触发
|
||
global.GVA_LOG.Warn("calcQuotaDelta 未找到模型定价,使用兜底值",
|
||
zap.String("model", modelName),
|
||
zap.Int("prompt_tokens", promptTokens),
|
||
zap.Int("completion_tokens", completionTokens),
|
||
)
|
||
return float64(promptTokens+completionTokens) * gaia.DefaultQuotaFallbackUSDPerToken
|
||
}
|
||
inputCost := float64(promptTokens) * p.Input * p.Unit
|
||
outputPrice := p.Output
|
||
if outputPrice == 0 {
|
||
outputPrice = p.Input
|
||
}
|
||
outputCost := float64(completionTokens) * outputPrice * p.Unit
|
||
total := inputCost + outputCost
|
||
|
||
// RMB/CNY 定价统一换算为 USD 后再扣费,与 used_quota 存储单位保持一致
|
||
if strings.EqualFold(p.Currency, "RMB") || strings.EqualFold(p.Currency, "CNY") {
|
||
total = rmbToUSD(total)
|
||
}
|
||
return total
|
||
}
|
||
|
||
// CheckAccountQuota 检查用户是否还有可用余额(total_quota - used_quota > 0)。
|
||
// total_quota = 0 视为"未设置限额",不拦截;total_quota > 0 时才做余额校验。
|
||
func (s *ModelProviderService) CheckAccountQuota(userID string) error {
|
||
var row gaiaResponse.CheckAccountQuotaRow
|
||
err := global.GVA_DB.Table("account_money_extend").
|
||
Select("total_quota, used_quota").
|
||
Where("account_id = ?::uuid", userID).
|
||
First(&row).Error
|
||
if err != nil {
|
||
// 记录未找到:可能尚未初始化,放行
|
||
return nil
|
||
}
|
||
// total_quota = 0 表示不限额,放行
|
||
if row.TotalQuota <= 0 {
|
||
return nil
|
||
}
|
||
if row.UsedQuota >= row.TotalQuota {
|
||
return fmt.Errorf("余额不足,已用 %.6f / 总额 %.6f USD,请联系管理员充值", row.UsedQuota, row.TotalQuota)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// deductAccountQuota 将消耗配额计入 account_money_extend.used_quota(原子累加)。
|
||
func deductAccountQuota(userID string, delta float64) {
|
||
if delta <= 0 {
|
||
return
|
||
}
|
||
if err := global.GVA_DB.Exec(
|
||
`UPDATE account_money_extend SET used_quota = used_quota + ?, updated_at = NOW() WHERE account_id = ?::uuid`,
|
||
delta, userID,
|
||
).Error; err != nil {
|
||
global.GVA_LOG.Warn("deductAccountQuota 失败",
|
||
zap.String("user_id", userID), zap.Float64("delta", delta), zap.Error(err))
|
||
}
|
||
}
|
||
|
||
// ModelProviderService groups the model-provider operations of this package:
// provider configuration, credential retrieval, available-model discovery and
// chat request proxying. The type has no fields.
type ModelProviderService struct{}
|
||
|
||
// GetProviderList 获取提供商配置列表
|
||
// @Tags System Integrated
|
||
// @Summary 获取提供商配置列表
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce application/json
|
||
//
|
||
// 只展示三种逻辑提供商:openai(OpenAI)、tongyi(千问/通义)、google(Google)。
|
||
// Dify 里插件名为 langgenius/openai/openai、langgenius/tongyi/tongyi 等,与上述一一对应,不单独成行。
|
||
// 匹配规则:
|
||
// - 列表项 provider_name 固定为短名:openai / tongyi / google
|
||
// - 启用/已选模型:来自 admin 表 model_provider_config,按短名存储(provider_name = openai 等)
|
||
// - 可用模型:通过各提供商官方 API 拉取(OpenAI/通义兼容 GET /v1/models),不再使用 Dify provider_models
|
||
// - 凭证:来自 Dify providers + provider_credentials,按候选名查(见 difyProviderNameCandidates)
|
||
func (s *ModelProviderService) GetProviderList() ([]gaiaResponse.ProviderListItem, error) {
|
||
var configs []gaia.ModelProviderConfig
|
||
if err := global.GVA_DB.Find(&configs).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 只展示三种逻辑提供商;langgenius/openai/openai 等视为 openai 的数据来源,不单独列出
|
||
result := make([]gaiaResponse.ProviderListItem, len(gaia.SupportedProviders))
|
||
for i, providerName := range gaia.SupportedProviders {
|
||
var config *gaia.ModelProviderConfig
|
||
for j := range configs {
|
||
if configs[j].ProviderName == providerName {
|
||
config = &configs[j]
|
||
break
|
||
}
|
||
}
|
||
|
||
item := gaiaResponse.ProviderListItem{
|
||
ProviderName: providerName,
|
||
Enabled: false,
|
||
Models: []string{},
|
||
AvailableModels: []gaiaResponse.ModelInfo{},
|
||
}
|
||
if config != nil {
|
||
item.Enabled = config.Enabled
|
||
if config.Models != "" {
|
||
json.Unmarshal([]byte(config.Models), &item.Models)
|
||
}
|
||
}
|
||
result[i] = item
|
||
}
|
||
|
||
// 异步并发拉取各提供商的可用模型
|
||
var wg sync.WaitGroup
|
||
for i, providerName := range gaia.SupportedProviders {
|
||
wg.Add(1)
|
||
go func(idx int, name string) {
|
||
defer wg.Done()
|
||
availableModels, err := s.GetAvailableModelsFromDify(name)
|
||
if err != nil {
|
||
global.GVA_LOG.Warn("获取提供商可用模型失败", zap.String("provider", name), zap.Error(err))
|
||
} else {
|
||
result[idx].AvailableModels = availableModels
|
||
}
|
||
}(i, providerName)
|
||
}
|
||
wg.Wait()
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// UpdateProviderConfig 更新指定提供商的启用状态及已选模型列表。
|
||
// @Tags System Integrated
|
||
// @Summary 更新提供商配置
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce application/json
|
||
//
|
||
// 参数:
|
||
// - providerName: 提供商短名(openai/tongyi/google)
|
||
// - enabled: 是否启用
|
||
// - models: 已选模型 ID 列表
|
||
func (s *ModelProviderService) UpdateProviderConfig(providerName string, enabled bool, models []string) error {
|
||
modelsJSON, err := json.Marshal(models)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
var config gaia.ModelProviderConfig
|
||
err = global.GVA_DB.Where("provider_name = ?", providerName).First(&config).Error
|
||
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
// 创建新记录
|
||
config = gaia.ModelProviderConfig{
|
||
ProviderName: providerName,
|
||
Enabled: enabled,
|
||
Models: string(modelsJSON),
|
||
}
|
||
return global.GVA_DB.Create(&config).Error
|
||
}
|
||
return err
|
||
}
|
||
|
||
// 更新现有记录
|
||
config.Enabled = enabled
|
||
config.Models = string(modelsJSON)
|
||
return global.GVA_DB.Save(&config).Error
|
||
}
|
||
|
||
// GetEnabledModels 获取所有已启用提供商的已选模型,以 OpenAI /v1/models 响应格式返回。
|
||
// @Tags System Integrated
|
||
// @Summary 获取已启用的模型列表
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce application/json
|
||
func (s *ModelProviderService) GetEnabledModels() (gaiaResponse.OpenAIModelsResponse, error) {
|
||
var configs []gaia.ModelProviderConfig
|
||
if err := global.GVA_DB.Where("enabled = ?", true).Find(&configs).Error; err != nil {
|
||
return gaiaResponse.OpenAIModelsResponse{}, err
|
||
}
|
||
|
||
resp := gaiaResponse.OpenAIModelsResponse{
|
||
Data: []gaiaResponse.ModelInfo{},
|
||
}
|
||
|
||
for _, config := range configs {
|
||
var models []string
|
||
if config.Models != "" {
|
||
if err := json.Unmarshal([]byte(config.Models), &models); err != nil {
|
||
continue
|
||
}
|
||
}
|
||
|
||
for _, modelID := range models {
|
||
resp.Data = append(resp.Data, gaiaResponse.ModelInfo{
|
||
ID: modelID,
|
||
Name: modelID,
|
||
})
|
||
}
|
||
}
|
||
|
||
return resp, nil
|
||
}
|
||
|
||
// getAvailableModelsFromProviderModelCredentials 从 Dify provider_model_credentials 表拉取指定提供商的可用模型列表。
|
||
// 返回的模型 ID/Name 均为表内 model_name;实际请求 GPT 时由 API 侧根据 encrypted_config 中的 base_model_name 调用。
|
||
// 用于 admin 第三方模型列表展示(如 azure_openai 展示 model_name,调用时用 base_model_name)。
|
||
func (s *ModelProviderService) getAvailableModelsFromProviderModelCredentials(providerName string) ([]gaiaResponse.ModelInfo, error) {
|
||
var firstTenant gaia.Tenants
|
||
tenantID := firstTenant.GetSuperAdminTenantId()
|
||
if tenantID == "" {
|
||
return nil, nil
|
||
}
|
||
var modelNames []string
|
||
err := global.GVA_DB.Table("provider_model_credentials").
|
||
Where("tenant_id = ? AND provider_name LIKE ?", tenantID, fmt.Sprintf("%%%s%%", providerName)).
|
||
Distinct("model_name").
|
||
Pluck("model_name", &modelNames).Error
|
||
if err != nil {
|
||
global.GVA_LOG.Warn("从 provider_model_credentials 拉取模型列表失败", zap.String(
|
||
"provider", providerName), zap.Error(err))
|
||
return nil, nil
|
||
}
|
||
list := make([]gaiaResponse.ModelInfo, 0, len(modelNames))
|
||
for _, name := range modelNames {
|
||
if name != "" {
|
||
list = append(list, gaiaResponse.ModelInfo{ID: name, Name: name})
|
||
}
|
||
}
|
||
return list, nil
|
||
}
|
||
|
||
// GetAvailableModelsFromDify 获取提供商的可用模型列表。
|
||
// - Azure:仅从 provider_model_credentials 表拉取,列表展示 model_name,实际请求 GPT 时由 API 侧用 encrypted_config 的 base_model_name。
|
||
// - OpenAI / 通义 / Google:与原先一致,通过各提供商官方 API 拉取可用模型。未配置凭证时返回空列表且不报错。
|
||
//
|
||
// 参数 providerName 为短名(openai/tongyi/google/azure)。
|
||
func (s *ModelProviderService) GetAvailableModelsFromDify(providerName string) ([]gaiaResponse.ModelInfo, error) {
|
||
if providerName == gaia.ProviderAzure {
|
||
return s.getAvailableModelsFromProviderModelCredentials(providerName)
|
||
}
|
||
|
||
creds, err := s.GetDifyProviderCredentials(providerName)
|
||
if err != nil || creds.APIKey == "" {
|
||
return nil, nil
|
||
}
|
||
|
||
client := &http.Client{Timeout: 15 * time.Second}
|
||
switch providerName {
|
||
case gaia.ProviderOpenai:
|
||
base := creds.Endpoint
|
||
if base == "" {
|
||
base = "https://api.openai.com"
|
||
}
|
||
return s.fetchOpenAICompatibleModels(client, base, creds.APIKey)
|
||
case gaia.ProviderTongyi:
|
||
return s.fetchOpenAICompatibleModels(
|
||
client, "https://dashscope.aliyuncs.com/api", creds.APIKey)
|
||
case gaia.ProviderGoogle:
|
||
base := creds.Endpoint
|
||
if base == "" {
|
||
base = gaia.DefaultAPIBase[gaia.ProviderGoogle]
|
||
}
|
||
return s.fetchGeminiModels(client, base, creds.APIKey)
|
||
case gaia.ProviderAnthropic:
|
||
return nil, nil
|
||
case gaia.ProviderAWS:
|
||
// AWS Bedrock 没有统一的 OpenAI 兼容 /v1/models 接口,模型由前端 allow-create 手输
|
||
return nil, nil
|
||
default:
|
||
if creds.Endpoint != "" {
|
||
return s.fetchOpenAICompatibleModels(client, creds.Endpoint, creds.APIKey)
|
||
}
|
||
return nil, nil
|
||
}
|
||
}
|
||
|
||
// fetchOpenAICompatibleModels 调用 OpenAI 兼容的 GET /v1/models,解析为 ModelInfo 列表。
|
||
// 兼容两种响应格式:
|
||
// 1) OpenAI: { "data": [ { "id": "..." }, ... ] }
|
||
// 2) 通义: { "success": true, "output": { "models": [ { "model": "...", "name": "..." }, ... ] } }
|
||
func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client, baseURL, apiKey string) (
|
||
[]gaiaResponse.ModelInfo, error) {
|
||
url := strings.TrimSuffix(baseURL, "/") + "/v1/models"
|
||
req, err := http.NewRequest("GET", url, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.Body.Close()
|
||
body, _ := io.ReadAll(resp.Body)
|
||
if resp.StatusCode != http.StatusOK {
|
||
global.GVA_LOG.Warn("拉取模型列表接口非 200", zap.String("url", url), zap.Int(
|
||
"status", resp.StatusCode), zap.String("body", string(body)))
|
||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||
}
|
||
|
||
// 先尝试 OpenAI 格式
|
||
var listResp gaiaResponse.OpenAIModelsListResponse
|
||
if err = json.Unmarshal(body, &listResp); err == nil && len(listResp.Data) > 0 {
|
||
list := make([]gaiaResponse.ModelInfo, 0, len(listResp.Data))
|
||
for _, m := range listResp.Data {
|
||
if m.ID != "" {
|
||
list = append(list, gaiaResponse.ModelInfo{ID: m.ID, Name: m.ID})
|
||
}
|
||
}
|
||
return list, nil
|
||
}
|
||
|
||
// 再尝试通义格式:success + output.models
|
||
var tongyiResp gaiaResponse.TongyiModelsListResponse
|
||
if err = json.Unmarshal(body, &tongyiResp); err != nil {
|
||
return nil, fmt.Errorf("解析模型列表失败(非 OpenAI 也非通义格式): %w", err)
|
||
}
|
||
if !tongyiResp.Success || len(tongyiResp.Output.Models) == 0 {
|
||
return nil, fmt.Errorf("通义接口返回无模型或 success 不为 true")
|
||
}
|
||
list := make([]gaiaResponse.ModelInfo, 0, len(tongyiResp.Output.Models))
|
||
for _, m := range tongyiResp.Output.Models {
|
||
if m.Model != "" {
|
||
name := m.Name
|
||
if name == "" {
|
||
name = m.Model
|
||
}
|
||
list = append(list, gaiaResponse.ModelInfo{ID: m.Model, Name: name})
|
||
}
|
||
}
|
||
return list, nil
|
||
}
|
||
|
||
// fetchGeminiModels 调用 Google Gemini GET /v1beta/models?key=API_KEY,解析 models[],支持分页。
|
||
// 认证使用 query 参数 key,响应格式:{ "models": [ { "name": "models/xxx", "baseModelId": "xxx", "displayName": "..." } ], "nextPageToken": "..." }
|
||
func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, apiKey string) (
|
||
[]gaiaResponse.ModelInfo, error) {
|
||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||
all := make([]gaiaResponse.ModelInfo, 0)
|
||
pageToken := ""
|
||
|
||
for {
|
||
url := baseURL + "/v1beta/models?key=" + apiKey
|
||
if pageToken != "" {
|
||
url += "&pageToken=" + pageToken
|
||
}
|
||
req, err := http.NewRequest("GET", url, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
body, _ := io.ReadAll(resp.Body)
|
||
resp.Body.Close()
|
||
if resp.StatusCode != http.StatusOK {
|
||
global.GVA_LOG.Warn("拉取 Gemini 模型列表非 200", zap.String(
|
||
"url", baseURL+"/v1beta/models"), zap.Int("status", resp.StatusCode), zap.String(
|
||
"body", string(body)))
|
||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||
}
|
||
|
||
var listResp gaiaResponse.GeminiModelsListResponse
|
||
if err = json.Unmarshal(body, &listResp); err != nil {
|
||
return nil, fmt.Errorf("解析 Gemini 模型列表失败: %w", err)
|
||
}
|
||
|
||
for _, m := range listResp.Models {
|
||
// 请求时使用 baseModelId(如 gemini-1.5-flash),无则用 name 去掉 "models/" 前缀
|
||
id := m.BaseModelID
|
||
if id == "" && m.Name != "" {
|
||
id = strings.TrimPrefix(m.Name, "models/")
|
||
}
|
||
if id == "" {
|
||
continue
|
||
}
|
||
name := m.DisplayName
|
||
if name == "" {
|
||
name = id
|
||
}
|
||
all = append(all, gaiaResponse.ModelInfo{ID: id, Name: name})
|
||
}
|
||
|
||
pageToken = listResp.NextPageToken
|
||
if pageToken == "" {
|
||
break
|
||
}
|
||
}
|
||
|
||
return all, nil
|
||
}
|
||
|
||
// fetchAzureOpenAIModels 调用 Azure OpenAI GET {endpoint}/openai/models?api-version={version},解析 data[]。
|
||
// 认证使用 api-key 请求头,响应格式:{ "data": [ { "id": "...", "object": "model" } ] }
|
||
func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseURL, apiKey, apiVersion string) (
|
||
[]gaiaResponse.ModelInfo, error) {
|
||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||
if apiVersion == "" {
|
||
apiVersion = "2024-08-01-preview" // 默认 API 版本
|
||
}
|
||
|
||
url := fmt.Sprintf("%s/openai/models?api-version=%s", baseURL, apiVersion)
|
||
req, err := http.NewRequest("GET", url, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("api-key", apiKey)
|
||
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, _ := io.ReadAll(resp.Body)
|
||
if resp.StatusCode != http.StatusOK {
|
||
global.GVA_LOG.Warn("拉取 Azure OpenAI 模型列表非 200", zap.String("url", url), zap.Int(
|
||
"status", resp.StatusCode), zap.String("body", string(body)))
|
||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||
}
|
||
|
||
// Azure OpenAI 返回格式与 OpenAI 类似
|
||
var listResp gaiaResponse.OpenAIModelsListResponse
|
||
if err = json.Unmarshal(body, &listResp); err != nil {
|
||
return nil, fmt.Errorf("解析 Azure OpenAI 模型列表失败: %w", err)
|
||
}
|
||
|
||
list := make([]gaiaResponse.ModelInfo, 0, len(listResp.Data))
|
||
for _, m := range listResp.Data {
|
||
if m.ID != "" {
|
||
list = append(list, gaiaResponse.ModelInfo{ID: m.ID, Name: m.ID})
|
||
}
|
||
}
|
||
|
||
return list, nil
|
||
}
|
||
|
||
// GetDifyProviderCredentials 从 Dify 数据库读取指定提供商的凭证,支持缓存与解密。
|
||
// 查询优先级:
|
||
// 1. providers + provider_credentials 表(传统方式)
|
||
// 2. provider_model_credentials 表(多凭证方式,按 updated_at 倒序取最新)
|
||
//
|
||
// @Tags System Integrated
|
||
// @Summary 获取提供商凭证
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce application/json
|
||
func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
|
||
creds *gaiaResponse.ProviderCredentials, err error) {
|
||
creds = &gaiaResponse.ProviderCredentials{}
|
||
|
||
// 首先尝试从Redis缓存获取(按请求的 providerName 缓存)
|
||
var cached string
|
||
var firstTenant gaia.Tenants
|
||
tenantID := firstTenant.GetSuperAdminTenantId()
|
||
cacheKey := gaia.RedisKeyModelProviderCredentialsPrefix + providerName
|
||
if cached, err = global.GVA_Dify_REDIS.Get(context.Background(), cacheKey).Result(); err == nil {
|
||
if err = json.Unmarshal([]byte(cached), &creds); err == nil {
|
||
return creds, nil
|
||
}
|
||
}
|
||
|
||
// 尝试方式1: 从 providers + provider_credentials 表查询
|
||
var row gaia.ProviderCredential
|
||
err = global.GVA_DB.Table("providers").
|
||
Select("provider_credentials.encrypted_config, providers.tenant_id").
|
||
Joins("LEFT JOIN provider_credentials ON providers.credential_id = provider_credentials.id").
|
||
Where("providers.tenant_id = ? AND providers.provider_name LIKE ? AND providers.provider_type = ? AND providers.is_valid = ?",
|
||
tenantID, fmt.Sprintf("%%%s%%", providerName), gaia.DifyProviderTypeCustom, true).
|
||
Order("provider_credentials.updated_at DESC").
|
||
First(&row).Error
|
||
|
||
// 如果方式1 未找到记录,尝试方式2: 从 provider_model_credentials 表查询
|
||
if err != nil || row.EncryptedConfig == "" {
|
||
var pmcRow gaia.ProviderCredential
|
||
if pmcErr := global.GVA_DB.Table("provider_model_credentials").
|
||
Select("encrypted_config, tenant_id, provider_name, updated_at").
|
||
Where("tenant_id = ? AND provider_name LIKE ?", tenantID, fmt.Sprintf("%%%s%%", providerName)).
|
||
Order("updated_at DESC"). // 按 updated_at 倒序,取最新的凭证
|
||
First(&pmcRow).Error; pmcErr == nil && pmcRow.EncryptedConfig != "" {
|
||
row = pmcRow
|
||
err = nil
|
||
}
|
||
}
|
||
|
||
if err != nil || row.EncryptedConfig == "" {
|
||
return creds, fmt.Errorf("未找到提供商 %s 的凭证配置", providerName)
|
||
}
|
||
|
||
// 兼容两种存储:1) 明文 JSON(如 {"openai_api_key":"...", "openai_api_base":"..."});2) Dify RSA+AES-EAX 加密后再 base64
|
||
var base, apiVersion string
|
||
var configMap map[string]interface{}
|
||
if err = json.Unmarshal([]byte(row.EncryptedConfig), &configMap); err == nil {
|
||
// 解密函数用于处理加密的值
|
||
if config, ok := configMap[gaia.ConfigKeyOpenaiAPIKey]; ok {
|
||
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
|
||
if base, ok = configMap[gaia.ConfigKeyOpenaiAPIBase].(string); ok && strings.TrimSpace(base) != "" {
|
||
creds.Endpoint = strings.TrimSuffix(strings.TrimSpace(base), "/")
|
||
}
|
||
// 提取 API 版本(Azure 使用)
|
||
if apiVersion, ok = configMap[gaia.ConfigKeyOpenaiAPIVersion].(string); ok && strings.TrimSpace(apiVersion) != "" {
|
||
creds.APIVersion = strings.TrimSpace(apiVersion)
|
||
}
|
||
} else if config, ok = configMap[gaia.ConfigKeyDashScopeAPIKey]; ok {
|
||
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
|
||
} else if config, ok = configMap[gaia.ConfigKeyAPIKey]; ok {
|
||
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
|
||
} else {
|
||
// 尝试从备选字段中查找
|
||
for _, key := range gaia.CredentialKeyFallback {
|
||
var v string
|
||
if v, ok = configMap[key].(string); ok && v != "" {
|
||
if creds.APIKey, err = s.decryptConfig(v, row.TenantID); err == nil && creds.APIKey != "" {
|
||
break
|
||
}
|
||
}
|
||
}
|
||
if base, ok = configMap[gaia.ConfigKeyOpenaiAPIBase].(string); ok && strings.TrimSpace(base) != "" {
|
||
creds.Endpoint = strings.TrimSuffix(strings.TrimSpace(base), "/")
|
||
}
|
||
// 提取 API 版本(Azure 使用)
|
||
if apiVersion, ok = configMap[gaia.ConfigKeyOpenaiAPIVersion].(string); ok && strings.TrimSpace(apiVersion) != "" {
|
||
creds.APIVersion = strings.TrimSpace(apiVersion)
|
||
}
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("解密凭证失败: %w", err)
|
||
}
|
||
}
|
||
if creds.APIKey == "" && creds.AWSAccessKeyID == "" {
|
||
return nil, fmt.Errorf("未能从配置中提取API Key(也未找到 AWS 凭证)")
|
||
}
|
||
|
||
// 缓存凭证(1小时)
|
||
var cacheJSON []byte
|
||
if cacheJSON, err = json.Marshal(creds); err == nil {
|
||
global.GVA_Dify_REDIS.Set(context.Background(), cacheKey, cacheJSON, time.Hour)
|
||
}
|
||
|
||
return creds, nil
|
||
}
|
||
|
||
// decryptConfig 解密Dify的加密配置(RSA + AES-EAX 混合加密)
|
||
// Dify 使用 RSA 2048 + AES-EAX 混合加密,密文格式为:
|
||
// Base64( "HYBRID:" + enc_aes_key(256字节) + nonce(16字节) + tag(16字节) + ciphertext )
|
||
func (s *ModelProviderService) decryptConfig(encryptedConfig string, tenantID string) (string, error) {
|
||
// 1. Base64 解码
|
||
encrypted, err := base64.StdEncoding.DecodeString(encryptedConfig)
|
||
if err != nil {
|
||
return "", fmt.Errorf("base64 decode failed: %w", err)
|
||
}
|
||
|
||
// 2. 检查并去除 "HYBRID:" 前缀
|
||
prefix := []byte("HYBRID:")
|
||
if !bytes.HasPrefix(encrypted, prefix) {
|
||
// 如果没有 HYBRID 前缀,可能是明文或其他格式,直接返回原值
|
||
return encryptedConfig, nil
|
||
}
|
||
encrypted = encrypted[len(prefix):]
|
||
|
||
// 3. 读取 tenant 私钥
|
||
privateKey, err := s.loadPrivateKey(tenantID)
|
||
if err != nil {
|
||
return "", fmt.Errorf("load private key failed: %w", err)
|
||
}
|
||
|
||
// 4. 解析密文结构
|
||
// RSA 2048 = 256 字节密钥
|
||
rsaKeySize := privateKey.Size() // 通常是 256
|
||
if len(encrypted) < rsaKeySize+32 {
|
||
return "", errors.New("encrypted data too short")
|
||
}
|
||
|
||
encAESKey := encrypted[:rsaKeySize]
|
||
nonce := encrypted[rsaKeySize : rsaKeySize+16]
|
||
tag := encrypted[rsaKeySize+16 : rsaKeySize+32]
|
||
ciphertext := encrypted[rsaKeySize+32:]
|
||
|
||
// 5. RSA OAEP 解密 AES 密钥(使用 SHA-1,与 Dify Python 实现一致)
|
||
aesKey, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, privateKey, encAESKey, nil)
|
||
if err != nil {
|
||
return "", fmt.Errorf("RSA decrypt failed: %w", err)
|
||
}
|
||
|
||
// 6. AES-EAX 解密数据
|
||
plaintext, err := s.aesEAXDecrypt(aesKey, nonce, ciphertext, tag)
|
||
if err != nil {
|
||
return "", fmt.Errorf("AES-EAX decrypt failed: %w", err)
|
||
}
|
||
|
||
return string(plaintext), nil
|
||
}
|
||
|
||
// loadPrivateKey 从配置的存储路径加载指定 tenant 的 RSA 私钥(PEM 文件)。
|
||
// 若该 tenant 的私钥文件不存在,则回退到「第一个创建的空间」(tenants 表 created_at 最早)的私钥路径,与 Dify 默认空间约定一致。
|
||
func (s *ModelProviderService) loadPrivateKey(tenantID string) (*rsa.PrivateKey, error) {
|
||
key, err := s.loadPrivateKeyFromPath(tenantID)
|
||
if err != nil {
|
||
// 若错误为文件不存在,尝试使用第一个创建的空间的私钥
|
||
if errors.Is(err, os.ErrNotExist) || strings.Contains(err.Error(), "no such file or directory") {
|
||
var firstTenant gaia.Tenants
|
||
firstID := firstTenant.GetSuperAdminTenantId()
|
||
if firstID != "" && firstID != tenantID {
|
||
key, fallbackErr := s.loadPrivateKeyFromPath(firstID)
|
||
if fallbackErr == nil {
|
||
return key, nil
|
||
}
|
||
}
|
||
}
|
||
return nil, err
|
||
}
|
||
return key, nil
|
||
}
|
||
|
||
// loadPrivateKeyFromPath 根据 tenantID 解析私钥文件路径并读取、解析 PEM,不做回退。
|
||
func (s *ModelProviderService) loadPrivateKeyFromPath(tenantID string) (*rsa.PrivateKey, error) {
|
||
storagePath := global.GVA_CONFIG.Gaia.StoragePath
|
||
if storagePath == "" {
|
||
storagePath = "/app/storage"
|
||
}
|
||
|
||
filepath := fmt.Sprintf("%s/privkeys/%s/private.pem", storagePath, tenantID)
|
||
|
||
if _, err := os.Stat(filepath); os.IsNotExist(err) && storagePath == "/app/storage" {
|
||
localPath := fmt.Sprintf("../../api/storage/privkeys/%s/private.pem", tenantID)
|
||
if _, err := os.Stat(localPath); err == nil {
|
||
filepath = localPath
|
||
}
|
||
}
|
||
|
||
pemData, err := os.ReadFile(filepath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("read private key file failed: %w", err)
|
||
}
|
||
|
||
block, _ := pem.Decode(pemData)
|
||
if block == nil {
|
||
return nil, errors.New("failed to decode PEM block")
|
||
}
|
||
|
||
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||
if err != nil {
|
||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("parse private key failed: %w", err)
|
||
}
|
||
var ok bool
|
||
privateKey, ok = key.(*rsa.PrivateKey)
|
||
if !ok {
|
||
return nil, errors.New("private key is not RSA key")
|
||
}
|
||
}
|
||
|
||
return privateKey, nil
|
||
}
|
||
|
||
// aesEAXDecrypt 使用 AES-EAX 解密数据
|
||
// EAX 模式是一种认证加密模式,使用第三方库 go.gnd.pw/crypto/eax 实现
|
||
func (s *ModelProviderService) aesEAXDecrypt(key, nonce, ciphertext, tag []byte) ([]byte, error) {
|
||
block, err := aes.NewCipher(key)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 创建 EAX AEAD 实例
|
||
aead, err := eax.NewEAX(block)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("create EAX cipher failed: %w", err)
|
||
}
|
||
|
||
// EAX 的 Open 方法需要 nonce 和 ciphertext+tag 的组合
|
||
// Python pycryptodome 的格式: ciphertext 和 tag 是分开的
|
||
// Go EAX 库的 Open 期望格式: ciphertext || tag
|
||
combined := make([]byte, len(ciphertext)+len(tag))
|
||
copy(combined, ciphertext)
|
||
copy(combined[len(ciphertext):], tag)
|
||
|
||
// 解密并验证
|
||
plaintext, err := aead.Open(nil, nonce, combined, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("EAX decrypt failed: %w", err)
|
||
}
|
||
|
||
return plaintext, nil
|
||
}
|
||
|
||
// ProxyChat 将聊天请求代理到上游提供商,校验模型已开启并写入流式/非流式响应到 writer,并记录代理日志。
|
||
// @Tags System Integrated
|
||
// @Summary 代理聊天请求
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce application/json
|
||
func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequest, writer io.Writer) error {
|
||
// 按“已选模型”解析实际渠道(如 gpt-5-chat 若只在 Azure 下勾选则走 azure)
|
||
providerName, err := s.resolveProviderByModel(req.Model)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 获取提供商凭证
|
||
creds, err := s.GetDifyProviderCredentials(providerName)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 获取上游端点
|
||
endpoint := s.getUpstreamEndpoint(providerName)
|
||
|
||
// 构建请求
|
||
reqBody, err := json.Marshal(req)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
httpReq, err := http.NewRequest("POST", endpoint, bytes.NewReader(reqBody))
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 设置请求头
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", creds.APIKey))
|
||
|
||
// 发送请求
|
||
client := &http.Client{
|
||
Timeout: 5 * time.Minute,
|
||
}
|
||
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
body, _ := io.ReadAll(resp.Body)
|
||
return fmt.Errorf("上游返回错误: %d %s", resp.StatusCode, string(body))
|
||
}
|
||
|
||
// 记录开始时间(用于日志)
|
||
startTime := time.Now()
|
||
var requestTokens, responseTokens int
|
||
status := "success"
|
||
var errorMsg string
|
||
|
||
defer func() {
|
||
// 记录日志
|
||
global.GVA_DB.Create(&gaia.ModelProxyLog{
|
||
UserId: userID,
|
||
ProviderName: providerName,
|
||
ModelName: req.Model,
|
||
RequestTokens: requestTokens,
|
||
ResponseTokens: responseTokens,
|
||
Status: status,
|
||
ErrorMessage: errorMsg,
|
||
CreatedAt: startTime,
|
||
})
|
||
}()
|
||
|
||
// 处理流式响应
|
||
if req.Stream {
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
if _, err = writer.Write([]byte(line + "\n")); err != nil {
|
||
status = "error"
|
||
errorMsg = err.Error()
|
||
return err
|
||
}
|
||
// Flush if writer supports it
|
||
if flusher, ok := writer.(http.Flusher); ok {
|
||
flusher.Flush()
|
||
}
|
||
}
|
||
if err = scanner.Err(); err != nil {
|
||
status = "error"
|
||
errorMsg = err.Error()
|
||
return err
|
||
}
|
||
} else {
|
||
// 非流式响应
|
||
if _, err = io.Copy(writer, resp.Body); err != nil {
|
||
status = "error"
|
||
errorMsg = err.Error()
|
||
return err
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// getProviderCandidatesByModel 返回可能服务该模型的提供商短名列表(用于按“已选模型”解析实际渠道)。
|
||
// 例如 gpt 系列可能走 openai 或 azure,返回 [azure, openai] 以便优先匹配用户在 admin 里配置的渠道。
|
||
func (s *ModelProviderService) getProviderCandidatesByModel(modelName string) []string {
|
||
modelLower := strings.ToLower(modelName)
|
||
if strings.HasPrefix(modelLower, "gpt") || strings.Contains(modelLower, "openai") {
|
||
return []string{gaia.ProviderAzure, gaia.ProviderOpenai}
|
||
}
|
||
if strings.Contains(modelLower, "azure") {
|
||
return []string{gaia.ProviderAzure}
|
||
}
|
||
if strings.HasPrefix(modelLower, "qwen") || strings.Contains(modelLower, "tongyi") {
|
||
return []string{gaia.ProviderTongyi}
|
||
}
|
||
if strings.HasPrefix(modelLower, "gemini") || strings.Contains(modelLower, "google") {
|
||
return []string{gaia.ProviderGoogle}
|
||
}
|
||
if strings.Contains(modelLower, "claude") || strings.Contains(modelLower, "anthropic") {
|
||
// 顺序即优先级:anthropic 直连优先,未开启则回落到 AWS Bedrock;都开则走 anthropic
|
||
return []string{gaia.ProviderAnthropic, gaia.ProviderAWS}
|
||
}
|
||
// Kimi / Moonshot 系列经由 tongyi(百炼)渠道转发
|
||
if strings.HasPrefix(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
||
return []string{gaia.ProviderTongyi}
|
||
}
|
||
// GLM/智谱 可能配置在 tongyi(统一入口)或 zhipuai 下,先试 tongyi
|
||
if strings.HasPrefix(modelLower, "glm") || strings.Contains(modelLower, "zhipu") || strings.Contains(modelLower, "chatglm") {
|
||
return []string{gaia.ProviderTongyi, gaia.ProviderZhipuai}
|
||
}
|
||
// 支持 "minimax" 或 "MiniMax/MiniMax-M2.5" 等形式;可能配置在 tongyi(统一入口)或 minimax 下,先试 tongyi
|
||
if strings.HasPrefix(modelLower, "minimax") || strings.Contains(modelLower, "abab") {
|
||
return []string{gaia.ProviderTongyi, gaia.ProviderMinimax}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// resolveProviderByModel 根据模型名解析实际使用的提供商:在“可能服务该模型的”渠道中,取第一个已启用且已选模型列表包含该模型的渠道。
|
||
// 这样当模型名是 gpt-5-chat 且用户只在 Azure 渠道下勾选了该模型时,会正确走 azure 而不是 openai。
|
||
func (s *ModelProviderService) resolveProviderByModel(modelName string) (string, error) {
|
||
candidates := s.getProviderCandidatesByModel(modelName)
|
||
if len(candidates) == 0 {
|
||
global.GVA_LOG.Warn("resolveProviderByModel 无法识别提供商", zap.String("model", modelName), zap.String("model_lower", strings.ToLower(modelName)))
|
||
return "", fmt.Errorf("无法识别模型 %s 的提供商", modelName)
|
||
}
|
||
for _, p := range candidates {
|
||
if s.isModelEnabled(p, modelName) {
|
||
return p, nil
|
||
}
|
||
}
|
||
return "", fmt.Errorf("模型 %s 未开启", modelName)
|
||
}
|
||
|
||
// getProviderByModel 仅根据模型名称推断提供商短名(不查配置表)。代理校验“是否开启”请用 resolveProviderByModel。
|
||
func (s *ModelProviderService) getProviderByModel(modelName string) (string, error) {
|
||
modelLower := strings.ToLower(modelName)
|
||
if strings.HasPrefix(modelLower, "gpt") || strings.Contains(modelLower, "openai") {
|
||
return gaia.ProviderOpenai, nil
|
||
}
|
||
if strings.HasPrefix(modelLower, "qwen") || strings.Contains(modelLower, "tongyi") {
|
||
return gaia.ProviderTongyi, nil
|
||
}
|
||
if strings.HasPrefix(modelLower, "gemini") || strings.Contains(modelLower, "google") {
|
||
return gaia.ProviderGoogle, nil
|
||
}
|
||
if strings.Contains(modelLower, "claude") || strings.Contains(modelLower, "anthropic") {
|
||
// 仅按名字推断时默认 anthropic;实际渠道(含 AWS Bedrock)由 resolveProviderByModel 决定
|
||
return gaia.ProviderAnthropic, nil
|
||
}
|
||
// Kimi / Moonshot 默认走 tongyi(百炼)渠道
|
||
if strings.HasPrefix(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
||
return gaia.ProviderTongyi, nil
|
||
}
|
||
if strings.Contains(modelLower, "azure") {
|
||
return gaia.ProviderAzure, nil
|
||
}
|
||
// GLM 类模型可能走 tongyi 或 zhipuai,仅推断时默认 tongyi
|
||
if strings.HasPrefix(modelLower, "glm") || strings.Contains(modelLower, "zhipu") || strings.Contains(modelLower, "chatglm") {
|
||
return gaia.ProviderTongyi, nil
|
||
}
|
||
// MiniMax 类模型可能走 tongyi 或 minimax,仅推断时默认 tongyi(实际以 resolveProviderByModel + 已选模型为准)
|
||
if strings.HasPrefix(modelLower, "minimax") || strings.Contains(modelLower, "abab") {
|
||
return gaia.ProviderTongyi, nil
|
||
}
|
||
return "", fmt.Errorf("无法识别模型 %s 的提供商", modelName)
|
||
}
|
||
|
||
// isModelEnabled 检查指定提供商下该模型是否在已启用且已选模型列表中。
|
||
func (s *ModelProviderService) isModelEnabled(providerName, modelName string) bool {
|
||
var config gaia.ModelProviderConfig
|
||
if err := global.GVA_DB.Where("provider_name = ? AND enabled = ?", providerName, true).First(&config).Error; err != nil {
|
||
return false
|
||
}
|
||
|
||
var models []string
|
||
if err := json.Unmarshal([]byte(config.Models), &models); err != nil {
|
||
return false
|
||
}
|
||
|
||
suffixOfRequest := modelName
|
||
if idx := strings.LastIndex(modelName, "/"); idx >= 0 && idx < len(modelName)-1 {
|
||
suffixOfRequest = modelName[idx+1:]
|
||
}
|
||
for _, m := range models {
|
||
if m == modelName {
|
||
return true
|
||
}
|
||
// 请求 model 为 "MiniMax/MiniMax-M2.5" 时,与配置中的 "MiniMax-M2.5" 视为同一模型
|
||
if m == suffixOfRequest {
|
||
return true
|
||
}
|
||
suffixOfConfig := m
|
||
if idx := strings.LastIndex(m, "/"); idx >= 0 && idx < len(m)-1 {
|
||
suffixOfConfig = m[idx+1:]
|
||
}
|
||
if suffixOfRequest == suffixOfConfig {
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// getUpstreamEndpoint 根据提供商短名返回聊天补全接口的上游 URL。
|
||
func (s *ModelProviderService) getUpstreamEndpoint(providerName string) string {
|
||
if endpoint, ok := gaia.DefaultChatCompletionsEndpoints[providerName]; ok {
|
||
return endpoint
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// getUpstreamBase 返回提供商的上游根地址(用于通用代理)。优先使用 provider_credentials 的 openai_api_base(如 "https://yunwu.ai"),便于计费与多租户区分。
|
||
func (s *ModelProviderService) getUpstreamBase(providerName string, creds *gaiaResponse.ProviderCredentials) string {
|
||
if creds != nil && strings.TrimSpace(creds.Endpoint) != "" {
|
||
return strings.TrimSuffix(strings.TrimSpace(creds.Endpoint), "/")
|
||
}
|
||
if base, ok := gaia.DefaultAPIBase[providerName]; ok {
|
||
return strings.TrimSuffix(base, "/")
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// ProxyRequest 将任意路径的请求转发到上游(anthropic /v1/messages、gemini /v1beta/...、openai /v1/chat/completions、/v1/images/generations、/v1/embeddings 等)。
|
||
// @Tags System Integrated
|
||
// @Summary 通用代理请求
|
||
// @Security ApiKeyAuth
|
||
// @accept application/json
|
||
// @Produce application/json
|
||
// provider 可通过 X-Gaia-Provider 头、query provider= 或 body 中的 model 字段推断;上游 base 优先使用 creds.Endpoint(openai_api_base)。
|
||
func (s *ModelProviderService) ProxyRequest(
|
||
userID, path, method string, reqHeader http.Header, body []byte, writer io.Writer) (err error) {
|
||
// init
|
||
var providerName string
|
||
if path = strings.TrimPrefix(path, "/"); path == "" {
|
||
return fmt.Errorf("代理路径不能为空")
|
||
}
|
||
|
||
// 解析 provider:头 > query 已在 handler 传入;此处从 body 取 model 仅当 body 为 JSON 且含 model 时用于推断
|
||
xGaiaProvider := reqHeader.Get("X-Gaia-Provider")
|
||
global.GVA_LOG.Info("ProxyRequest 解析 provider", zap.String("path", path), zap.String(
|
||
"X-Gaia-Provider", xGaiaProvider), zap.Int("body_len", len(body)))
|
||
if p := xGaiaProvider; p != "" {
|
||
providerName = strings.TrimSpace(strings.ToLower(p))
|
||
}
|
||
if providerName == "" && len(body) > 0 {
|
||
var obj map[string]interface{}
|
||
if err = json.Unmarshal(body, &obj); err == nil {
|
||
if m, ok := obj["model"].(string); ok && m != "" {
|
||
global.GVA_LOG.Info("ProxyRequest 从 body 解析 model", zap.String("model", m))
|
||
// 按“已选模型”解析实际渠道(如 gpt-5-chat 若只在 Azure 下勾选则走 azure)
|
||
providerName, err = s.resolveProviderByModel(m)
|
||
if err != nil {
|
||
global.GVA_LOG.Error("ProxyRequest resolveProviderByModel 失败", zap.String(
|
||
"model", m), zap.Error(err))
|
||
return err
|
||
}
|
||
global.GVA_LOG.Info("ProxyRequest 解析得到 provider", zap.String("provider", providerName))
|
||
}
|
||
}
|
||
}
|
||
if providerName == "" {
|
||
return fmt.Errorf("请指定 provider:设置请求头 X-Gaia-Provider 或 query provider=,或在 body 中提供 model 字段")
|
||
}
|
||
|
||
// 若未从 body model 解析出 provider,则只校验该提供商已启用
|
||
if !s.isProviderEnabled(providerName) {
|
||
return fmt.Errorf("提供商 %s 未开启", providerName)
|
||
}
|
||
|
||
var base string
|
||
var bodyReader io.Reader
|
||
var creds *gaiaResponse.ProviderCredentials
|
||
if creds, err = s.GetDifyProviderCredentials(providerName); err != nil {
|
||
return err
|
||
}
|
||
|
||
if base = s.getUpstreamBase(providerName, creds); base == "" {
|
||
return fmt.Errorf("提供商 %s 无可用上游地址", providerName)
|
||
}
|
||
|
||
// 若 body 是 JSON 且含 stream: true,注入 stream_options.include_usage = true
|
||
// 这样上游会在 SSE 末尾的 data 行返回 usage,供后续计费解析使用。
|
||
if len(body) > 0 {
|
||
var bodyObj map[string]interface{}
|
||
if json.Unmarshal(body, &bodyObj) == nil {
|
||
if streamVal, ok := bodyObj["stream"].(bool); ok && streamVal {
|
||
if _, hasOpt := bodyObj["stream_options"]; !hasOpt {
|
||
bodyObj["stream_options"] = map[string]interface{}{"include_usage": true}
|
||
if injected, e := json.Marshal(bodyObj); e == nil {
|
||
body = injected
|
||
}
|
||
}
|
||
}
|
||
}
|
||
bodyReader = bytes.NewReader(body)
|
||
}
|
||
|
||
// 构建请求 URL,Azure 需要特殊处理
|
||
var requestURL string
|
||
if providerName == gaia.ProviderAzure {
|
||
// Azure OpenAI 有两种 API 格式:
|
||
// 1. 新版 v1 API(2025年8月后):/openai/v1/... 不需要 api-version 参数
|
||
// 2. 传统 API:/openai/deployments/{deployment}/... 需要 api-version 参数
|
||
// 参考:https://learn.microsoft.com/en-us/azure/ai-foundry/openai/api-version-lifecycle
|
||
//
|
||
// 当请求路径以 "v1/" 开头时,使用新版 v1 API,不添加 api-version
|
||
if strings.HasPrefix(path, "v1/") {
|
||
// 新版 v1 API:/openai/v1/chat/completions(不需要 api-version)
|
||
requestURL = base + "/openai/" + path
|
||
} else if strings.HasPrefix(path, "openai/v1/") {
|
||
// 已经包含完整的 openai/v1 前缀
|
||
requestURL = base + "/" + path
|
||
} else if strings.HasPrefix(path, "openai/") {
|
||
// 其他 openai 路径(如 openai/deployments/...),可能需要 api-version
|
||
requestURL = base + "/" + path
|
||
if creds.APIVersion != "" {
|
||
requestURL += "?api-version=" + creds.APIVersion
|
||
}
|
||
} else {
|
||
// 其他路径,添加 openai 前缀并使用传统 API
|
||
requestURL = base + "/openai/" + path
|
||
if creds.APIVersion != "" {
|
||
requestURL += "?api-version=" + creds.APIVersion
|
||
}
|
||
}
|
||
} else {
|
||
requestURL = base + "/" + path
|
||
}
|
||
|
||
httpReq, err := http.NewRequest(method, requestURL, bodyReader)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 复制常用请求头,Azure 使用 api-key 头,其他使用 Authorization Bearer
|
||
if providerName == gaia.ProviderAzure {
|
||
httpReq.Header.Set("api-key", creds.APIKey)
|
||
} else {
|
||
httpReq.Header.Set("Authorization", "Bearer "+creds.APIKey)
|
||
}
|
||
if ct := reqHeader.Get("Content-Type"); ct != "" {
|
||
httpReq.Header.Set("Content-Type", ct)
|
||
}
|
||
if accept := reqHeader.Get("Accept"); accept != "" {
|
||
httpReq.Header.Set("Accept", accept)
|
||
}
|
||
// 流式请求
|
||
if reqHeader.Get("Accept") == "text/event-stream" || reqHeader.Get("Accept") == "" {
|
||
// 不强制覆盖,上游可能根据 body 的 stream 返回 SSE
|
||
}
|
||
|
||
client := &http.Client{Timeout: 5 * time.Minute}
|
||
resp, err := client.Do(httpReq)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 记录代理日志(用于计费时可区分 openai_api_base)
|
||
startTime := time.Now()
|
||
modelOrPath := path
|
||
if len(body) > 0 {
|
||
var obj map[string]interface{}
|
||
if json.Unmarshal(body, &obj) == nil {
|
||
if m, _ := obj["model"].(string); m != "" {
|
||
modelOrPath = m
|
||
}
|
||
}
|
||
}
|
||
var logStatus, logError string
|
||
var promptTokens, completionTokens int
|
||
defer func() {
|
||
if logStatus == "" {
|
||
logStatus = "success"
|
||
}
|
||
global.GVA_DB.Create(&gaia.ModelProxyLog{
|
||
UserId: userID,
|
||
ProviderName: providerName,
|
||
ModelName: modelOrPath,
|
||
RequestTokens: promptTokens,
|
||
ResponseTokens: completionTokens,
|
||
Status: logStatus,
|
||
ErrorMessage: logError,
|
||
CreatedAt: startTime,
|
||
})
|
||
|
||
// 计费:仅成功时扣费
|
||
if logStatus == "success" {
|
||
if promptTokens > 0 || completionTokens > 0 {
|
||
// LLM 类型:按 token 计费
|
||
pricing, _ := s.fetchModelPricingFromDify(modelOrPath)
|
||
delta := calcQuotaDelta(pricing, modelOrPath, promptTokens, completionTokens)
|
||
deductAccountQuota(userID, delta)
|
||
} else if isImageOrPerRequestPath(path) {
|
||
// 图片生成等无 usage 的接口:按请求次数计费,默认单价见 gaia.DefaultImageGenerationPriceUSD
|
||
pricing, _ := s.fetchModelPricingFromDify(modelOrPath)
|
||
var delta float64
|
||
if pricing != nil && pricing.Input > 0 {
|
||
// 若定价表有配置,input 字段用作每次请求单价
|
||
delta = pricing.Input
|
||
} else {
|
||
delta = gaia.DefaultImageGenerationPriceUSD
|
||
}
|
||
deductAccountQuota(userID, delta)
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 写回状态码与响应头(流式由上游 Content-Type 决定)
|
||
if w, ok := writer.(http.ResponseWriter); ok {
|
||
for k, v := range resp.Header {
|
||
for _, vv := range v {
|
||
w.Header().Add(k, vv)
|
||
}
|
||
}
|
||
w.WriteHeader(resp.StatusCode)
|
||
}
|
||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||
_, _ = io.Copy(writer, resp.Body)
|
||
return nil
|
||
}
|
||
|
||
// extractUsage 从 OpenAI 格式的 JSON 对象中提取 usage 字段
|
||
extractUsage := func(data []byte) {
|
||
var obj gaia.ModelUsageResponse
|
||
if json.Unmarshal(data, &obj) == nil && obj.Usage != nil {
|
||
if obj.Usage.PromptTokens > 0 {
|
||
promptTokens = obj.Usage.PromptTokens
|
||
}
|
||
if obj.Usage.CompletionTokens > 0 {
|
||
completionTokens = obj.Usage.CompletionTokens
|
||
}
|
||
}
|
||
}
|
||
|
||
// 流式响应:按行扫描,顺带从最后一条含 usage 的 data 行中提取 token 数
|
||
if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
|
||
if flusher, ok := writer.(http.Flusher); ok {
|
||
scanner := bufio.NewScanner(resp.Body)
|
||
for scanner.Scan() {
|
||
line := scanner.Text()
|
||
if _, err = writer.Write([]byte(line + "\n")); err != nil {
|
||
logStatus, logError = "error", err.Error()
|
||
return err
|
||
}
|
||
flusher.Flush()
|
||
// 解析 SSE data 行中的 usage(stream_options.include_usage=true 时上游会附带)
|
||
if strings.HasPrefix(line, "data:") && strings.Contains(line, `"usage"`) {
|
||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||
extractUsage([]byte(payload))
|
||
}
|
||
}
|
||
if err = scanner.Err(); err != nil {
|
||
logStatus, logError = "error", err.Error()
|
||
return err
|
||
}
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// 非流式响应:TeeReader 同时转发给客户端并读取 body 用于解析 usage
|
||
var buf bytes.Buffer
|
||
tee := io.TeeReader(resp.Body, &buf)
|
||
_, err = io.Copy(writer, tee)
|
||
if err != nil {
|
||
logStatus, logError = "error", err.Error()
|
||
} else {
|
||
extractUsage(buf.Bytes())
|
||
}
|
||
return err
|
||
}
|
||
|
||
// GetProxyLogs 分页查询代理日志(model_proxy_log 表)。
|
||
func (s *ModelProviderService) GetProxyLogs(info gaiaRequest.GetProxyLogsReq) (list []map[string]interface{}, total int64, err error) {
|
||
page, pageSize := info.Page, info.PageSize
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
if pageSize < 1 || pageSize > 100 {
|
||
pageSize = 20
|
||
}
|
||
db := global.GVA_DB.Table("model_proxy_log")
|
||
if err = db.Count(&total).Error; err != nil {
|
||
err = fmt.Errorf("查询日志总数失败:%w", err)
|
||
return
|
||
}
|
||
offset := (page - 1) * pageSize
|
||
if err = db.Order("created_at DESC").Limit(pageSize).Offset(offset).Find(&list).Error; err != nil {
|
||
err = fmt.Errorf("查询日志列表失败:%w", err)
|
||
return
|
||
}
|
||
return list, total, nil
|
||
}
|
||
|
||
// isProviderEnabled 检查该提供商是否已启用(未校验具体模型列表,用于通用代理)。
|
||
func (s *ModelProviderService) isProviderEnabled(providerName string) bool {
|
||
var config gaia.ModelProviderConfig
|
||
if err := global.GVA_DB.Where("provider_name = ? AND enabled = ?", providerName, true).First(&config).Error; err != nil {
|
||
return false
|
||
}
|
||
return true
|
||
}
|
||
|
||
// isImageOrPerRequestPath reports whether path belongs to an endpoint that is
// charged per request (image generation, speech and the like, which return no
// usage object). The check is a case-insensitive substring match against the
// known images/* and audio/* endpoint paths.
func isImageOrPerRequestPath(path string) bool {
	lowered := strings.ToLower(path)
	switch {
	case strings.Contains(lowered, "images/generations"),
		strings.Contains(lowered, "images/edits"),
		strings.Contains(lowered, "images/variations"),
		strings.Contains(lowered, "audio/speech"),
		strings.Contains(lowered, "audio/transcriptions"),
		strings.Contains(lowered, "audio/translations"):
		return true
	}
	return false
}
|