49 Commits

Author SHA1 Message Date
npc0-hue cb0233f7b6 fix: 打印原文 2026-04-25 10:17:36 +08:00
npc0-hue 6c0c6ba1fe fix: client_ip 打印客户端ip 2026-04-25 10:11:09 +08:00
npc0-hue 3240e6e4e5 fix: 模型名称自动加us.anthropic之类的,方便维护 2026-04-24 18:09:06 +08:00
npc0-hue 8d823787b7 fix: 构建 Bedrock 请求详情打印 2026-04-24 17:24:45 +08:00
npc0-hue 84fc7bcb43 fix: 构建 Bedrock 请求详情打印 2026-04-24 16:52:49 +08:00
npc0-hue e08b1b079c fix: 构建 Bedrock URL处理 2026-04-24 16:38:41 +08:00
npc0-hue 3f6ef97148 fix: anthropic.* 前缀是 AWS Bedrock 专用模型 ID 格式(如 anthropic.claude-sonnet-4-6), 2026-04-24 14:42:15 +08:00
npc0-hue 33a34e4181 fix: 后端转发修改转发时长 2026-04-24 14:32:06 +08:00
npc0-hue 347e48cef6 fix: 添加DeepSeek支持 2026-04-24 12:12:28 +08:00
npc0-hue d2a7ade1b0 fix: 补充index 2026-04-24 11:04:57 +08:00
npc0-hue efc25217dc fix(bedrock): 代理改为反向代理模式,兼容不支持 CONNECT 的代理
- 取消 http.Transport.Proxy(要求代理支持 HTTP CONNECT 隧道)
- 改为直接向代理地址发 HTTP 请求(host:port),路径同 Bedrock 原生 API
- httpReq.Host 设为真实 Bedrock host,SigV4 仍针对真实 host 签名后复制头部
- 支持凭证内 bedrock_proxy_url 与全局 BEDROCK_PROXY 环境变量(优先凭证)
- config.Gaia 增加 BedrockProxy 字段,overrideGaiaFromEnv 从 BEDROCK_PROXY 读取
2026-04-24 10:58:58 +08:00
npc0-hue 618a355ec8 feat(bedrock): 全局 BEDROCK_PROXY 环境变量支持
凭证里 bedrock_proxy_url 优先,未配置则回落到 BEDROCK_PROXY env。
解决 admin-server 在受限地区(如中国)直连 bedrock-runtime 被 Anthropic 地区拦截的问题。
与 Dify Python 侧 BEDROCK_PROXY 含义一致,docker-compose 设置一次即可全局生效。
2026-04-24 10:48:36 +08:00
npc0-hue d474f15673 fix(routing): Claude 模型路由优先 aws,未开启再回落 anthropic
原来写死 anthropic 优先,导致只要 anthropic provider 配置过就直连
api.anthropic.com,在地区受限时报 geo-block 错误。
改为:哪个 provider 开启且包含该模型就用哪个(aws 先试)。
2026-04-24 10:15:47 +08:00
npc0-hue 3bad30bff1 feat(bedrock): 支持 bedrock_proxy_url 代理配置
从 Dify provider_credentials 的 encrypted_config 中读取 bedrock_proxy_url 字段,
若非空则将 HTTP 请求经该代理(host:port 或 http://host:port)转发到 AWS Bedrock,
不再强制直连 bedrock-runtime.{region}.amazonaws.com。
变更:
- ProviderCredentials 新增 BedrockProxyURL 字段
- ConfigKeyBedrockProxyURL 常量
- GetDifyProviderCredentials 提取 bedrock_proxy_url(明文,不解密)
- proxyBedrockRequest 根据 BedrockProxyURL 配置 http.Transport.Proxy
2026-04-24 09:59:56 +08:00
npc0-hue ea77171028 fix(aws): 修复 provider_credentials 表 LIKE 查询无法匹配 bedrock_claude
Dify 内部将 AWS Bedrock provider 存储为 'langgenius/bedrock_claude/bedrock_claude',
而非含 'aws' 关键字的名称,导致 LIKE '%aws%' 查询返回空结果。
新增 difyProviderLikePattern() 辅助方法:
- ProviderAWS → LIKE '%bedrock%'(兼容 bedrock_claude 和 bedrock 两种插件包)
- 其他 provider 保持原有 LIKE '%<name>%' 逻辑
2026-04-23 17:53:10 +08:00
npc0-hue bb1db4ca99 fix(aws): 前端 & 后端测试凭证支持 AWS Bedrock
- TestProviderCredentials: AWS 渠道返回 access_key_id(脱敏)/region/session_token 状态
- modelManagement/index.vue: 测试凭证弹窗适配 AWS 字段展示
- 模型列表提示区分 AWS/Anthropic/其他,附上常用 Bedrock 模型 ID 示例
2026-04-23 15:44:09 +08:00
npc0-hue 5618c89721 fix(aws): 打通 AWS Bedrock 转发链路
- model_provider_constants_extend.go: 新增 AWS 凭证配置 key 常量
  (ConfigKeyAWSAccessKeyID/SecretAccessKey/SessionToken/Region)
- GetDifyProviderCredentials: 解析 Dify bedrock provider 的 AWS 凭证字段
- GetAvailableModelsFromDify: AWS/Anthropic 在 credentials 之前提前返回,
  避免因 APIKey 为空触发误报
- ProxyRequest: 在获取凭证后新增 AWS 分支,直接调用 proxyBedrockRequest
  完成 SigV4 签名直连 Bedrock 原生 API(bedrock_extend.go 中已实现,此前未被调用)
2026-04-23 15:41:53 +08:00
npc0-hue 3a02769e4a fix: 转发添加新模型和模型提供商 2026-04-23 14:49:29 +08:00
npc0-hue 0af791f56c fix: admin初始化修改 2026-03-31 15:26:21 +08:00
npc0-hue 0b20a17074 fix: 移除转发body 2026-03-31 15:02:04 +08:00
npc0-hue d13e083f37 fix: 调整docker初始化地址 2026-03-30 16:43:00 +08:00
npc0-hue 1cc7f4bc7b frea: admin初始化修改 2026-03-30 16:24:02 +08:00
npc0-hue 8341905b21 Merge pull request #111 from YFGaia/dependabot/npm_and_yarn/web/eslint-react/eslint-plugin-2.7.4
chore(deps-dev): bump @eslint-react/eslint-plugin from 2.7.0 to 2.7.4 in /web
2026-03-27 15:38:55 +08:00
npc0-hue b5af9263f8 Merge pull request #110 from YFGaia/dependabot/npm_and_yarn/web/elkjs-0.11.0
chore(deps): bump elkjs from 0.9.3 to 0.11.0 in /web
2026-03-27 15:37:25 +08:00
npc0-hue 950ef2d13e fix: 批量工作流刷新效果修改 2026-03-27 15:25:21 +08:00
npc0-hue 8257113c50 fix: 批处理修改 2026-03-27 14:51:11 +08:00
npc0-hue fff9543a37 fix: An “Internal Server Error” occurs on the monitoring page of the Chatflow application.
https://github.com/YFGaia/dify-plus/issues/116
2026-03-27 14:50:09 +08:00
npc0-hue 02e568c6d5 fix: 邮箱查询用户 2026-03-25 11:16:31 +08:00
npc0-hue 2bc3e4dc39 fix: fastopenapi路由认证 2026-03-25 09:53:24 +08:00
npc0-hue d7b77bee2e fix: 数据库密码莫名被改,注释掉数据库端口和redis端口 2026-03-23 17:25:07 +08:00
npc0-hue ce82b5f776 fix: docker-compose.middleware.yaml修复 2026-03-23 16:02:19 +08:00
npc0-hue f328825da7 fear: 恢复api密钥token使用量控制 2026-03-23 16:01:42 +08:00
npc0-hue 892a5f9127 fix: 批处理设置超时 2026-03-18 11:30:11 +08:00
npc0-hue 1d6e41829a fix: 调用第三方邮箱 API和镜像修改 2026-03-13 11:44:39 +08:00
npc0-hue 0484655f13 fix: 规范化处理代码 2026-03-12 14:47:02 +08:00
npc0-hue 786920c7e3 fix: 规范化处理代码 2026-03-12 11:42:02 +08:00
npc0-hue b84e94250f fix: 规范化处理代码 2026-03-12 11:29:37 +08:00
npc0-hue 9591795b10 fix: 千问3.5-plus计费不正确修复 2026-03-12 11:11:47 +08:00
npc0-hue 8df2e46658 fix: 余额不足直接拦截 2026-03-12 10:48:01 +08:00
npc0-hue 22d01c3c55 fix: 测试连接出现双重 /admin 2026-03-12 09:43:46 +08:00
npc0-hue d1b32f4310 fix: 综合修复(请求头打印、邮箱/用户名匹配、密钥显示、签名校验、删除报错等)
Made-with: Cursor
2026-03-12 09:21:21 +08:00
npc0-hue 4b5e2eaf35 fix: 计费完善 2026-03-11 12:05:53 +08:00
npc0-hue e283aa4055 feat: 前段完善 2026-03-10 11:58:14 +08:00
npc0-hue b5aba401e5 feat: 初始化token和前段修改 2026-03-10 10:17:57 +08:00
npc0-hue bc2edcdde6 feat: 钉钉机器人转发(未测试)
fix: admin初始化出错
2026-03-09 22:37:08 +08:00
npc0-hue 5962b9b518 feat: 钉钉机器人转发(未测试)
fix: admin初始化出错
2026-03-09 22:34:02 +08:00
npc0-hue 1b447b7b0b fix: admin初始化有误修复 2026-03-06 12:50:00 +08:00
dependabot[bot] e0cf5e2e27 chore(deps-dev): bump @eslint-react/eslint-plugin in /web
Bumps [@eslint-react/eslint-plugin](https://github.com/Rel1cx/eslint-react/tree/HEAD/packages/plugins/eslint-plugin) from 2.7.0 to 2.7.4.
- [Release notes](https://github.com/Rel1cx/eslint-react/releases)
- [Changelog](https://github.com/Rel1cx/eslint-react/blob/main/CHANGELOG.md)
- [Commits](https://github.com/Rel1cx/eslint-react/commits/v2.7.4/packages/plugins/eslint-plugin)

---
updated-dependencies:
- dependency-name: "@eslint-react/eslint-plugin"
  dependency-version: 2.7.4
  dependency-type: direct:development
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-11 11:12:55 +00:00
dependabot[bot] 2520715b8a chore(deps): bump elkjs from 0.9.3 to 0.11.0 in /web
Bumps [elkjs](https://github.com/kieler/elkjs) from 0.9.3 to 0.11.0.
- [Release notes](https://github.com/kieler/elkjs/releases)
- [Commits](https://github.com/kieler/elkjs/compare/0.9.3...0.11.0)

---
updated-dependencies:
- dependency-name: elkjs
  dependency-version: 0.11.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-11 11:12:47 +00:00
69 changed files with 4654 additions and 966 deletions
+3
View File
@@ -209,6 +209,9 @@ api/.vscode
.history
.idea/
.claude/
.cursor/
openspec/
# pnpm
/.pnpm-store
+9
View File
@@ -0,0 +1,9 @@
# Admin (Go Backend) Agent Guide
## Rules (must follow)
### 禁止匿名 struct
- **禁止在代码中出现匿名 struct**。不得使用 `var x []struct { ... }``var x struct { ... }` 或字面量 `struct { A int }{1}` 等匿名结构体。
- 所有用于 GORM 查询扫描、缓存结构、API 请求/响应的结构体必须定义为**具名类型**,放在合适的 model 包(如 `model/gaia/request``model/gaia/response`)或当前包顶部,便于复用和规范约束。
- 示例:用 `[]response.AppQuotaRankingRow` 替代 `[]struct { AppID string; TotalCost float64; ... }`;用 `response.AppQuotaRankingCache` 替代 `struct { List ...; Total int64 }`
+1 -1
View File
@@ -1,4 +1,4 @@
FROM golang:alpine as builder
FROM golang:alpine AS builder
RUN mkdir /app
WORKDIR /app
+1
View File
@@ -12,6 +12,7 @@ type ApiGroup struct {
BatchWorkflowApi
AppVersionApi
ModelProviderApi
ForwardProxyApi
}
var (
+108
View File
@@ -0,0 +1,108 @@
package gaia
import (
"crypto/sha256"
"fmt"
"github.com/flipped-aurora/gin-vue-admin/server/global"
gaiaModel "github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"net/http"
)
type ForwardProxyApi struct{}
// ForwardProxy 转发代理入口:免 JWT,通过 forwarding token + ding_id 鉴权并计费
// @Tags ForwardProxy
// @Summary GPT 转发代理(钉钉入口,无需 JWT)
// @Param X-Forward-Token header string false "转发 Token"
// @Param X-Ding-Id header string false "钉钉 ID"
// @Param forward_token query string false "转发 TokenHeader 优先)"
// @Param ding_id query string false "钉钉 IDHeader 优先)"
// @Param path path string true "上游路径"
// @Router /gaia/forward/proxy/{path} [get,post,put,patch,delete]
func (f *ForwardProxyApi) ForwardProxy(c *gin.Context) {
// 打印请求 Header,便于排查转发问题
global.GVA_LOG.Info("ForwardProxy 请求头",
zap.Any("headers", c.Request.Header),
zap.String("method", c.Request.Method),
zap.String("path", c.Request.URL.Path),
)
// 1. 读取转发配置
integrate := systemIntegratedService.GetIntegratedConfig(gaiaModel.SystemIntegrationDingTalk)
configMap, err := systemIntegratedService.ParseDingTalkConfig(integrate.Config)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": "配置解析失败"}})
return
}
// 2. 获取并校验 forwarding token(存在有效 Token 即视为开启转发能力)
dingId := c.GetHeader("X-Ding-Id")
apiKey := c.GetHeader("X-Api-Key")
bearer := c.GetHeader("Authorization")
token := c.GetHeader("X-Forward-Token")
if (len(bearer) > gaiaModel.BearerLength || len(apiKey) > gaiaModel.BearerLength) && len(dingId) == 0 {
if len(bearer) > gaiaModel.BearerLength {
if bearer[:gaiaModel.BearerLength] == "Bearer " {
bearer = bearer[gaiaModel.BearerLength:]
}
} else if len(apiKey) > gaiaModel.BearerLength {
if apiKey[:gaiaModel.BearerLength] == "Bearer " {
bearer = apiKey[gaiaModel.BearerLength:]
} else {
bearer = apiKey
}
}
if dingId, err = systemIntegratedService.ParseForwardToken(bearer, configMap.ForwardConfig.Tokens); err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{"message": "Token 验证失败: " + err.Error()}})
return
}
} else {
if token == "" {
token = c.Query("forward_token")
}
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{"message": "缺少转发 Token"}})
return
}
if !validateForwardToken(token, configMap.ForwardConfig.Tokens) {
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{"message": "无效的转发 Token"}})
return
}
// 4. 获取 ding_id
if dingId == "" {
dingId = c.Query("ding_id")
}
if dingId == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "缺少 ding_id"}})
return
}
}
// 5. 解析 account_id
accountId, err := systemIntegratedService.ResolveAccountByDingId(dingId, configMap.EmailApi)
if err != nil {
global.GVA_LOG.Warn("ForwardProxy 用户解析失败", zap.String("ding_id", dingId), zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "无法解析用户:" + err.Error()}})
return
}
// 6. 复用与 Proxy 相同的转发逻辑(path/body/ProxyRequest
proxyWithAccountId(c, accountId)
}
// validateForwardToken 校验 token 是否在转发 Token 列表中(SHA256 比对)
func validateForwardToken(token string, tokens []request.ForwardToken) bool {
hash := fmt.Sprintf("%x", sha256.Sum256([]byte(token)))
for _, t := range tokens {
if t.TokenHash == hash {
return true
}
}
return false
}
+118 -124
View File
@@ -2,13 +2,13 @@ package gaia
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
gaiaReq "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
"github.com/flipped-aurora/gin-vue-admin/server/service"
"github.com/flipped-aurora/gin-vue-admin/server/utils"
"github.com/gin-gonic/gin"
@@ -31,7 +31,7 @@ func (m *ModelProviderApi) GetProviderList(c *gin.Context) {
list, err := modelProviderService.GetProviderList()
if err != nil {
global.GVA_LOG.Error("获取提供商配置列表失败", zap.Error(err))
response.FailWithMessage("获取失败: "+err.Error(), c)
response.FailWithMessage("获取失败:"+err.Error(), c)
return
}
response.OkWithData(list, c)
@@ -54,20 +54,20 @@ func (m *ModelProviderApi) UpdateProviderConfig(c *gin.Context) {
}
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage("参数错误: "+err.Error(), c)
response.FailWithMessage("参数错误:"+err.Error(), c)
return
}
if err := modelProviderService.UpdateProviderConfig(req.ProviderName, req.Enabled, req.Models); err != nil {
global.GVA_LOG.Error("更新提供商配置失败", zap.String("provider", req.ProviderName), zap.Error(err))
response.FailWithMessage("更新失败: "+err.Error(), c)
response.FailWithMessage("更新失败:"+err.Error(), c)
return
}
response.OkWithMessage("更新成功", c)
}
// GetModels 获取开启的模型列表(OpenAI格式
// GetModels 获取开启的模型列表(OpenAI 格式,供第三方兼容调用;成功时返回裸 JSON,错误时与项目统一使用 response)。
// @Tags ModelProvider
// @Summary 获取开启的模型列表
// @Security ApiKeyAuth
@@ -79,73 +79,73 @@ func (m *ModelProviderApi) GetModels(c *gin.Context) {
models, err := modelProviderService.GetEnabledModels()
if err != nil {
global.GVA_LOG.Error("获取模型列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": "获取模型列表失败: " + err.Error(),
},
})
response.FailWithMessage("获取失败:"+err.Error(), c)
return
}
c.JSON(http.StatusOK, models)
}
// Proxy 通用中转 API:将 /gaia/proxy/* 的请求按路径转发到上游(如 /v1/chat/completions、/v1/messages、/v1/images/generations、/v1/embeddings 等)
// 上游 base 优先使用 provider_credentials 的 openai_api_base(如 "https://yunwu.ai"),便于计费区分。
// proxyWithAccountId 通用代理逻辑:按路径转发到上游并计费
func proxyWithAccountId(c *gin.Context, accountId string) {
path := c.Param("path")
if path == "" || path == "/" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "代理路径不能为空"}})
return
}
reqHeader := c.Request.Header.Clone()
if q := strings.TrimSpace(c.Query("provider")); q != "" {
reqHeader.Set("X-Gaia-Provider", q)
}
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "读取请求体失败"}})
return
}
var bodyModel string
if len(body) > 0 {
var parseObj map[string]interface{}
if jsonErr := json.Unmarshal(body, &parseObj); jsonErr == nil {
if mv, ok := parseObj["model"].(string); ok {
bodyModel = mv
}
}
}
global.GVA_LOG.Info("Gaia代理请求入参",
zap.String("account_id", accountId),
zap.String("path", path),
zap.String("method", c.Request.Method),
zap.Int("body_len", len(body)),
zap.String("body_model", bodyModel),
zap.String("client_ip", c.ClientIP()),
)
// 余额前置检查:余额耗尽时直接拦截,不继续请求上游
if quotaErr := modelProviderService.CheckAccountQuota(accountId); quotaErr != nil {
c.JSON(http.StatusPaymentRequired, gin.H{"error": gin.H{"message": quotaErr.Error()}})
return
}
if err = modelProviderService.ProxyRequest(
accountId, path, c.Request.Method, reqHeader, body, c.Writer); err != nil {
global.GVA_LOG.Info("Gaia代理请求body",
zap.String("body", string(body)),
)
global.GVA_LOG.Error("代理请求失败", zap.String("account_id", accountId), zap.String("path", path), zap.Error(err))
if !c.Writer.Written() {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
}
}
}
// Proxy 通用中转 API:将 /gaia/proxy/* 的请求按路径转发到上游(需 JWT,account 来自当前登录用户)。
// @Tags ModelProvider
// @Summary 通用中转API(按路径转发)
// @Security ApiKeyAuth
// @Param path path string true "上游路径,如 v1/chat/completions、v1/messages"
// @Router /gaia/proxy/*path [get,post,put,patch,delete]
func (m *ModelProviderApi) Proxy(c *gin.Context) {
// init
var err error
var body []byte
path := c.Param("path")
userID := utils.GetUserUuid(c).String()
if path == "" || path == "/" {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "代理路径不能为空"}})
return
}
// 将 query provider 转为请求头,供 service 解析
reqHeader := c.Request.Header.Clone()
if q := strings.TrimSpace(c.Query("provider")); q != "" {
reqHeader.Set("X-Gaia-Provider", q)
}
if body, err = io.ReadAll(c.Request.Body); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "读取请求体失败"}})
return
}
// 打印传入参数便于排查
queryProvider := strings.TrimSpace(c.Query("provider"))
var bodyModel string
if len(body) > 0 {
var parseObj map[string]interface{}
if jsonErr := json.Unmarshal(body, &parseObj); jsonErr == nil {
if m, ok := parseObj["model"].(string); ok {
bodyModel = m
}
}
}
global.GVA_LOG.Info("Gaia代理请求入参",
zap.String("path", path),
zap.String("method", c.Request.Method),
zap.String("query_provider", queryProvider),
zap.Int("body_len", len(body)),
zap.String("body_model", bodyModel),
zap.String("body", string(body)),
)
if err = modelProviderService.ProxyRequest(
userID, path, c.Request.Method, reqHeader, body, c.Writer); err != nil {
global.GVA_LOG.Error("代理请求失败", zap.String("user_id", userID), zap.String(
"path", path), zap.Error(err))
if !c.Writer.Written() {
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
}
}
accountId := utils.GetUserUuid(c).String()
proxyWithAccountId(c, accountId)
}
// GetAvailableModels 获取提供商的可用模型
@@ -160,17 +160,15 @@ func (m *ModelProviderApi) Proxy(c *gin.Context) {
func (m *ModelProviderApi) GetAvailableModels(c *gin.Context) {
providerName := c.Query("provider_name")
if providerName == "" {
response.FailWithMessage("参数错误: provider_name不能为空", c)
response.FailWithMessage("参数错误:provider_name不能为空", c)
return
}
models, err := modelProviderService.GetAvailableModelsFromDify(providerName)
if err != nil {
global.GVA_LOG.Error("获取可用模型失败", zap.String("provider", providerName), zap.Error(err))
response.FailWithMessage("获取失败: "+err.Error(), c)
response.FailWithMessage("获取失败:"+err.Error(), c)
return
}
response.OkWithData(models, c)
}
@@ -186,35 +184,55 @@ func (m *ModelProviderApi) GetAvailableModels(c *gin.Context) {
func (m *ModelProviderApi) TestProviderCredentials(c *gin.Context) {
providerName := c.Query("provider_name")
if providerName == "" {
response.FailWithMessage("参数错误: provider_name不能为空", c)
response.FailWithMessage("参数错误:provider_name不能为空", c)
return
}
creds, err := modelProviderService.GetDifyProviderCredentials(providerName)
if err != nil {
global.GVA_LOG.Error("获取提供商凭证失败", zap.String("provider", providerName), zap.Error(err))
response.FailWithMessage("获取凭证失败: "+err.Error(), c)
response.FailWithMessage("获取凭证失败:"+err.Error(), c)
return
}
// 隐藏API Key的大部分内容
maskedKey := ""
if len(creds.APIKey) > 8 {
maskedKey = creds.APIKey[:4] + "****" + creds.APIKey[len(creds.APIKey)-4:]
} else {
maskedKey = "****"
}
var result map[string]interface{}
result := map[string]interface{}{
"provider": providerName,
"has_api_key": creds.APIKey != "",
"api_key": maskedKey,
// AWS Bedrock:展示 access key 信息
if creds.AWSAccessKeyID != "" {
maskedKey := "****"
if len(creds.AWSAccessKeyID) > 8 {
maskedKey = creds.AWSAccessKeyID[:4] + "****" + creds.AWSAccessKeyID[len(creds.AWSAccessKeyID)-4:]
}
region := creds.AWSRegion
if region == "" {
region = "us-east-1(默认)"
}
result = map[string]interface{}{
"provider": providerName,
"has_api_key": true,
"api_key": maskedKey,
"aws_access_key_id": maskedKey,
"aws_region": region,
"has_session_token": creds.AWSSessionToken != "",
}
} else {
// 隐藏API Key的大部分内容
maskedKey := ""
if len(creds.APIKey) > 8 {
maskedKey = creds.APIKey[:4] + "****" + creds.APIKey[len(creds.APIKey)-4:]
} else {
maskedKey = "****"
}
result = map[string]interface{}{
"provider": providerName,
"has_api_key": creds.APIKey != "",
"api_key": maskedKey,
}
}
response.OkWithData(result, c)
}
// GetProxyLogs 获取代理日志
// GetProxyLogs 获取代理日志(分页)
// @Tags ModelProvider
// @Summary 获取代理日志
// @Security ApiKeyAuth
@@ -222,54 +240,30 @@ func (m *ModelProviderApi) TestProviderCredentials(c *gin.Context) {
// @Produce application/json
// @Param page query int false "页码"
// @Param page_size query int false "每页数量"
// @Success 200 {object} response.Response{data=map[string]interface{},msg=string} "获取成功"
// @Success 200 {object} response.Response{data=response.PageResult,msg=string} "获取成功"
// @Router /gaia/model-provider/logs [get]
func (m *ModelProviderApi) GetProxyLogs(c *gin.Context) {
page := c.DefaultQuery("page", "1")
pageSize := c.DefaultQuery("page_size", "20")
var pageInt, pageSizeInt int
if _, err := fmt.Sscanf(page, "%d", &pageInt); err != nil {
pageInt = 1
}
if _, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt); err != nil {
pageSizeInt = 20
}
if pageInt < 1 {
pageInt = 1
}
if pageSizeInt < 1 || pageSizeInt > 100 {
pageSizeInt = 20
}
var logs []map[string]interface{}
var total int64
db := global.GVA_DB.Table("model_proxy_log")
// 获取总数
if err := db.Count(&total).Error; err != nil {
global.GVA_LOG.Error("获取日志总数失败", zap.Error(err))
response.FailWithMessage("获取失败: "+err.Error(), c)
var req gaiaReq.GetProxyLogsReq
if err := c.ShouldBindQuery(&req); err != nil {
response.FailWithMessage("参数错误:"+err.Error(), c)
return
}
// 分页查询
offset := (pageInt - 1) * pageSizeInt
if err := db.Order("created_at DESC").Limit(pageSizeInt).Offset(
offset).Find(&logs).Error; err != nil {
global.GVA_LOG.Error("获取日志列表失败", zap.Error(err))
response.FailWithMessage("获取失败: "+err.Error(), c)
if req.Page < 1 {
req.Page = 1
}
if req.PageSize < 1 || req.PageSize > 100 {
req.PageSize = 20
}
list, total, err := modelProviderService.GetProxyLogs(req)
if err != nil {
global.GVA_LOG.Error("获取代理日志失败", zap.Error(err))
response.FailWithMessage("获取失败:"+err.Error(), c)
return
}
result := map[string]interface{}{
"list": logs,
"total": total,
"page": pageInt,
"page_size": pageSizeInt,
}
response.OkWithData(result, c)
response.OkWithDetailed(response.PageResult{
List: list,
Total: total,
Page: req.Page,
PageSize: req.PageSize,
}, "获取成功", c)
}
+273
View File
@@ -2,10 +2,26 @@ package gaia
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strconv"
"strings"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
gaiaResp "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
"github.com/flipped-aurora/gin-vue-admin/server/model/system"
serviceGaia "github.com/flipped-aurora/gin-vue-admin/server/service/gaia"
"github.com/flipped-aurora/gin-vue-admin/server/utils"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"time"
)
type SystemApi struct{}
@@ -55,3 +71,260 @@ func (systemApi *SystemApi) SetDingTalk(c *gin.Context) {
}
response.OkWithData("ok", c)
}
// TestEmailApiConfig 测试第三方邮箱 API 配置
// @Tags System
// @Summary 测试第三方邮箱 API 配置
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body request.TestEmailApiConfigRequest true "测试配置请求"
// @Success 200 {object} response.Response{data=gaiaResp.TestEmailApiConfigResponse,msg=string} "测试结果"
// @Router /gaia/system/dingtalk/test-email-config [post]
func (systemApi *SystemApi) TestEmailApiConfig(c *gin.Context) {
var req request.TestEmailApiConfigRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
result, err := systemIntegratedService.TestEmailApiConfig(req.Config, req.TestDingID)
if err != nil {
response.FailWithMessage("测试失败:"+err.Error(), c)
return
}
response.OkWithData(result, c)
}
// GetDingTalkTestAuthURL 获取「测试连接」用的钉钉授权 URL,打开后扫码完成即视为连接成功
// @Router /gaia/system/dingtalk/test-auth-url [get]
func (systemApi *SystemApi) GetDingTalkTestAuthURL(c *gin.Context) {
origin := c.GetHeader("Referer")
if origin == "" {
origin = c.GetHeader("Origin")
}
if origin != "" {
if u, err := url.Parse(origin); err == nil {
origin = u.Scheme + "://" + u.Host + strings.TrimSuffix(u.Path, "/")
}
}
if origin == "" {
response.FailWithMessage("无法获取前端地址,请从配置页点击「测试连接」", c)
return
}
authURL, err := systemIntegratedService.GetDingTalkTestAuthURL(origin)
if err != nil {
response.FailWithMessage(err.Error(), c)
return
}
response.OkWithData(gin.H{"auth_url": authURL}, c)
}
// DingTalkTestCallback 测试连接回调:仅用 code 换 token 验证,不登录
// @Router /gaia/system/dingtalk/test-callback [post]
func (systemApi *SystemApi) DingTalkTestCallback(c *gin.Context) {
var req struct {
Code string `json:"code"`
}
if err := c.ShouldBindJSON(&req); err != nil || strings.TrimSpace(req.Code) == "" {
response.FailWithMessage("缺少授权码 code", c)
return
}
if err := systemIntegratedService.DingTalkTestCallback(req.Code); err != nil {
response.FailWithMessage("验证失败: "+err.Error(), c)
return
}
response.OkWithMessage("验证成功", c)
}
// GetForwardTokens 获取转发 Token 列表
// @Tags System
// @Summary 获取转发 Token 列表
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Success 200 {object} response.Response{data=gaiaResp.ForwardTokensResponse,msg=string} "查询成功"
// @Router /gaia/system/forward-tokens [get]
func (systemApi *SystemApi) GetForwardTokens(c *gin.Context) {
integrate := systemIntegratedService.GetIntegratedConfig(gaia.SystemIntegrationDingTalk)
var configMap request.DingTalkConfigRequest
if integrate.Config != "" {
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
response.FailWithMessage("解析配置失败:"+err.Error(), c)
return
}
}
tokens := make([]gaiaResp.ForwardTokenInfo, 0, len(configMap.ForwardConfig.Tokens))
for i, token := range configMap.ForwardConfig.Tokens {
tokens = append(tokens, gaiaResp.ForwardTokenInfo{
ID: utils.AddAsteriskToString(token.TokenSecret),
CreatedAt: token.CreatedAt,
Seq: i + 1,
})
}
response.OkWithData(gaiaResp.ForwardTokensResponse{Tokens: tokens, Count: len(tokens), Max: 20}, c)
}
// CreateForwardToken 新增转发 Token
// @Tags System
// @Summary 新增转发 Token
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param token body string true "Token 明文"
// @Success 200 {object} response.Response{data=request.ForwardToken,msg=string} "创建成功"
// @Router /gaia/system/forward-tokens [post]
func (systemApi *SystemApi) CreateForwardToken(c *gin.Context) {
var req struct {
Token string `json:"token"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
if req.Token == "" {
response.FailWithMessage("Token 不能为空", c)
return
}
integrate := systemIntegratedService.GetIntegratedConfig(gaia.SystemIntegrationDingTalk)
var configMap request.DingTalkConfigRequest
if integrate.Config != "" {
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
response.FailWithMessage("解析配置失败:"+err.Error(), c)
return
}
}
// 检查数量限制
if len(configMap.ForwardConfig.Tokens) >= 20 {
response.FailWithMessage("转发 Token 最多 20 个", c)
return
}
// 生成唯一 ID 和哈希
tokenID := "tok_" + uuid.New().String()
tokenHash := fmt.Sprintf("%x", sha256.Sum256([]byte(req.Token)))
// 生成 HMAC 签名密钥(仅创建时回传一次)
secretBytes := make([]byte, 32)
if _, err := rand.Read(secretBytes); err != nil {
response.FailWithMessage("生成 TokenSecret 失败:"+err.Error(), c)
return
}
tokenSecret := base64.RawURLEncoding.EncodeToString(secretBytes)
newToken := request.ForwardToken{
ID: tokenID,
TokenHash: tokenHash,
CreatedAt: time.Now(),
TokenSecret: tokenSecret,
}
// 添加到配置
configMap.ForwardConfig.Tokens = append(configMap.ForwardConfig.Tokens, newToken)
seq := len(configMap.ForwardConfig.Tokens) // 1..N
configJSON, _ := json.Marshal(configMap)
integrate.Config = string(configJSON)
// 保存配置
if err := systemIntegratedService.SetIntegratedConfig(integrate, "", false); err != nil {
response.FailWithMessage("保存失败:"+err.Error(), c)
return
}
// 返回明文 token(仅此次展示)
response.OkWithData(gin.H{
"seq": seq,
"token": req.Token,
"token_secret": tokenSecret,
"created_at": newToken.CreatedAt,
}, c)
}
// DeleteForwardToken 删除转发 Token
// @Tags System
// @Summary 删除转发 Token
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param seq path int true "Token 序列号(从列表获取,1..N"
// @Param password body string true "当前用户密码"
// @Success 200 {object} response.Response{msg=string} "删除成功"
// @Router /gaia/system/forward-tokens/:seq [delete]
func (systemApi *SystemApi) DeleteForwardToken(c *gin.Context) {
seqStr := c.Param("seq")
if seqStr == "" {
response.FailWithMessage("Token 序列号不能为空", c)
return
}
seq, err := strconv.Atoi(seqStr)
if err != nil || seq <= 0 {
response.FailWithMessage("Token 序列号非法", c)
return
}
var req struct {
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&req); err != nil {
response.FailWithMessage(err.Error(), c)
return
}
// 验证当前用户密码(使用 Dify account 密码体系)
userID := utils.GetUserUuid(c).String()
var user system.SysUser
if err := global.GVA_DB.Select("email").Where(
"uuid = ?", userID).First(&user).Error; err != nil {
response.FailWithMessage("查询用户失败:"+err.Error(), c)
return
}
account, err := user.GetAccount()
if err != nil {
response.FailWithMessage("查询账号失败:"+err.Error(), c)
return
}
var pwd serviceGaia.PasswdEncode
if ok, pwdErr := pwd.ComparePassword(
req.Password, account.Password, account.PasswordSalt); pwdErr != nil || !ok {
response.FailWithMessage("密码错误", c)
return
}
// 获取配置
integrate := systemIntegratedService.GetIntegratedConfig(gaia.SystemIntegrationDingTalk)
var configMap request.DingTalkConfigRequest
if integrate.Config != "" {
if err = json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
response.FailWithMessage("解析配置失败:"+err.Error(), c)
return
}
}
// 查找并删除 token
if seq > len(configMap.ForwardConfig.Tokens) {
response.FailWithMessage("Token 不存在", c)
return
}
idx := seq - 1
newTokens := make([]request.ForwardToken, 0, len(configMap.ForwardConfig.Tokens)-1)
newTokens = append(newTokens, configMap.ForwardConfig.Tokens[:idx]...)
newTokens = append(newTokens, configMap.ForwardConfig.Tokens[idx+1:]...)
// 更新配置
configMap.ForwardConfig.Tokens = newTokens
configJSON, _ := json.Marshal(configMap)
integrate.Config = string(configJSON)
if err = systemIntegratedService.SetIntegratedConfig(integrate, "", false); err != nil {
response.FailWithMessage("保存失败:"+err.Error(), c)
return
}
response.OkWithMessage("删除成功", c)
}
+7 -8
View File
@@ -5,9 +5,8 @@ import (
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
"github.com/flipped-aurora/gin-vue-admin/server/model/system/request"
"github.com/flipped-aurora/gin-vue-admin/server/service/system"
"go.uber.org/zap"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type DBApi struct{}
@@ -20,7 +19,8 @@ type DBApi struct{}
// @Success 200 {object} response.Response{data=string} "初始化用户数据库"
// @Router /init/initdb [post]
func (i *DBApi) InitDB(c *gin.Context) {
if global.GVA_DB != nil {
if initDBService.IfInit() {
global.GVA_LOG.Error("已存在数据库配置!")
response.FailWithMessage("已存在数据库配置", c)
return
@@ -51,12 +51,11 @@ func (i *DBApi) InitDB(c *gin.Context) {
// @Success 200 {object} response.Response{data=map[string]interface{},msg=string} "初始化用户数据库"
// @Router /init/checkdb [post]
func (i *DBApi) CheckDB(c *gin.Context) {
var (
message = "前往初始化数据库"
needInit = true
)
// init
var needInit = true
var message = "前往初始化数据库"
if global.GVA_DB != nil {
if initDBService.IfInit() {
message = "数据库无需初始化"
needInit = false
}
+4 -4
View File
@@ -113,10 +113,10 @@ pgsql:
prefix: ""
port: "5432"
config: sslmode=disable TimeZone=Asia/Shanghai
db-name:
username:
password:
path:
db-name: dify
username: postgres
password: difyai123456
path: db_postgres
engine: ""
log-mode: error
max-idle-conns: 10
+1 -1
View File
@@ -5,5 +5,5 @@ type Gaia struct {
LoginMaxErrorLimit int `mapstructure:"login_max_error_limit" json:"login_max_error_limit" yaml:"login_max_error_limit"`
SuperAdminAccountId string `mapstructure:"SUPER_ADMIN_ACCOUNT_ID" json:"SUPER_ADMIN_ACCOUNT_ID" yaml:"SUPER_ADMIN_ACCOUNT_ID"` // 超级管理员账号
SuperAdminTenantId string `mapstructure:"SUPER_ADMIN_TENANT_ID" json:"SUPER_ADMIN_TENANT_ID" yaml:"SUPER_ADMIN_TENANT_ID"` // 系统默认工作区
StoragePath string `mapstructure:"storage-path" json:"storage-path" yaml:"storage-path"` // Dify storage 目录路径,用于读取私钥
StoragePath string `mapstructure:"storage-path" json:"storage-path" yaml:"storage-path"` // Dify storage 目录路径,用于读取私钥
}
+126
View File
@@ -0,0 +1,126 @@
package core
import (
"fmt"
"os"
"strconv"
"github.com/flipped-aurora/gin-vue-admin/server/global"
)
// overrideDBFromEnv 从环境变量覆盖数据库配置
// 优先级:环境变量 > 配置文件
func overrideDBFromEnv() {
// 数据库类型
if dbType := os.Getenv("DB_TYPE"); dbType != "" {
switch dbType {
case "mysql":
global.GVA_CONFIG.System.DbType = "mysql"
overrideMysqlFromEnv()
case "postgresql", "postgres":
global.GVA_CONFIG.System.DbType = "pgsql"
overridePgsqlFromEnv()
default:
global.GVA_CONFIG.System.DbType = "pgsql"
overridePgsqlFromEnv()
}
fmt.Printf("Database type overridden from DB_TYPE environment variable: %s\n", dbType)
}
}
// overrideMysqlFromEnv 从环境变量覆盖 MySQL 配置
func overrideMysqlFromEnv() {
cfg := &global.GVA_CONFIG.Mysql
if host := os.Getenv("DB_HOST"); host != "" {
cfg.Path = host
}
if port := os.Getenv("DB_PORT"); port != "" {
cfg.Port = port
} else {
cfg.Port = "3306"
}
if username := os.Getenv("DB_USERNAME"); username != "" {
cfg.Username = username
}
if password := os.Getenv("DB_PASSWORD"); password != "" {
cfg.Password = password
}
if dbname := os.Getenv("DB_DATABASE"); dbname != "" {
cfg.Dbname = dbname
}
if config := os.Getenv("DB_CONFIG"); config != "" {
cfg.Config = config
}
}
// overridePgsqlFromEnv 从环境变量覆盖 PostgreSQL 配置
func overridePgsqlFromEnv() {
cfg := &global.GVA_CONFIG.Pgsql
if host := os.Getenv("DB_HOST"); host != "" {
cfg.Path = host
}
if port := os.Getenv("DB_PORT"); port != "" {
cfg.Port = port
} else {
cfg.Port = "5432"
}
if username := os.Getenv("DB_USERNAME"); username != "" {
cfg.Username = username
}
if password := os.Getenv("DB_PASSWORD"); password != "" {
cfg.Password = password
}
if dbname := os.Getenv("DB_DATABASE"); dbname != "" {
cfg.Dbname = dbname
}
if config := os.Getenv("DB_CONFIG"); config != "" {
cfg.Config = config
} else {
cfg.Config = "sslmode=disable TimeZone=Asia/Shanghai"
}
}
// overrideRedisFromEnv 从环境变量覆盖 Redis 配置
func overrideRedisFromEnv() {
// 覆盖主 Redis 配置
if host := os.Getenv("REDIS_HOST"); host != "" {
port := os.Getenv("REDIS_PORT")
if port == "" {
port = "6379"
}
global.GVA_CONFIG.Redis.Addr = host + ":" + port
}
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
global.GVA_CONFIG.Redis.Password = password
}
if db := os.Getenv("REDIS_DB"); db != "" {
if dbNum, err := strconv.Atoi(db); err == nil {
global.GVA_CONFIG.Redis.DB = dbNum
}
}
// 覆盖 Dify Redis 配置(与主 Redis 相同)
if host := os.Getenv("REDIS_HOST"); host != "" {
port := os.Getenv("REDIS_PORT")
if port == "" {
port = "6379"
}
global.GVA_CONFIG.DifyRedis.Addr = host + ":" + port
}
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
global.GVA_CONFIG.DifyRedis.Password = password
}
if db := os.Getenv("REDIS_DB"); db != "" {
if dbNum, err := strconv.Atoi(db); err == nil {
global.GVA_CONFIG.DifyRedis.DB = dbNum
}
}
fmt.Printf("Redis configuration overridden from environment variables: %s\n", global.GVA_CONFIG.Redis.Addr)
}
// overrideAllFromEnv 从环境变量覆盖所有配置
func overrideAllFromEnv() {
overrideDBFromEnv()
overrideRedisFromEnv()
}
+7 -2
View File
@@ -4,10 +4,11 @@
package core
import (
"time"
"github.com/flipped-aurora/gin-vue-admin/server/initialize"
"github.com/fvbock/endless"
"github.com/gin-gonic/gin"
"syscall"
"time"
)
func initServer(address string, router *gin.Engine) server {
@@ -15,5 +16,9 @@ func initServer(address string, router *gin.Engine) server {
s.ReadHeaderTimeout = 10 * time.Minute
s.WriteTimeout = 10 * time.Minute
s.MaxHeaderBytes = 1 << 20
// 优雅关闭:在收到 SIGTERM/SIGINT 时先停止工作池,再关闭 HTTP 服务,避免 goroutine 与连接未释放
stopPool := func() { initialize.StopWorkerPool() }
s.SignalHooks[endless.PRE_SIGNAL][syscall.SIGTERM] = append(s.SignalHooks[endless.PRE_SIGNAL][syscall.SIGTERM], stopPool)
s.SignalHooks[endless.PRE_SIGNAL][syscall.SIGINT] = append(s.SignalHooks[endless.PRE_SIGNAL][syscall.SIGINT], stopPool)
return s
}
+13 -9
View File
@@ -26,7 +26,7 @@ func overrideJWTSigningKeyFromEnv() {
}
// Viper //
// 优先级: 命令行 > 环境变量 > 默认值
// 优先级命令行 > 环境变量 > 默认值
// Author [SliverHorn](https://github.com/SliverHorn)
func Viper(path ...string) *viper.Viper {
var config string
@@ -45,17 +45,17 @@ func Viper(path ...string) *viper.Viper {
case gin.TestMode:
config = internal.ConfigTestFile
}
fmt.Printf("您正在使用gin模式的%s环境名称,config的路径为%s\n", gin.Mode(), config)
} else { // internal.ConfigEnv 常量存储的环境变量不为空 将值赋值于config
fmt.Printf("您正在使用 gin 模式的%s环境名称config 的路径为%s\n", gin.Mode(), config)
} else { // internal.ConfigEnv 常量存储的环境变量不为空 将值赋值于 config
config = configEnv
fmt.Printf("您正在使用%s环境变量,config的路径为%s\n", internal.ConfigEnv, config)
fmt.Printf("您正在使用%s环境变量config 的路径为%s\n", internal.ConfigEnv, config)
}
} else { // 命令行参数不为空 将值赋值于config
fmt.Printf("您正在使用命令行的-c参数传递的值,config的路径为%s\n", config)
} else { // 命令行参数不为空 将值赋值于 config
fmt.Printf("您正在使用命令行的 -c 参数传递的值config 的路径为%s\n", config)
}
} else { // 函数传递的可变参数的第一个值赋值于config
} else { // 函数传递的可变参数的第一个值赋值于 config
config = path[0]
fmt.Printf("您正在使用func Viper()传递的值,config的路径为%s\n", config)
fmt.Printf("您正在使用 func Viper() 传递的值config 的路径为%s\n", config)
}
v := viper.New()
@@ -82,7 +82,11 @@ func Viper(path ...string) *viper.Viper {
// Extend: Override JWT signing key from environment variable after initial load
overrideJWTSigningKeyFromEnv()
// root 适配性 根据root位置去找到对应迁移位置,保证root路径有效
// Extend: Override database and redis configuration from environment variables
// This allows admin-server to use the same configuration as docker-compose
overrideAllFromEnv()
// root 适配性 根据 root 位置去找到对应迁移位置,保证 root 路径有效
global.GVA_CONFIG.AutoCode.Root, _ = filepath.Abs("..")
return v
+3 -2
View File
@@ -21,11 +21,12 @@ func newWithSeconds() *cron.Cron {
func Corn() {
var lock bool
initDBService := system.InitDBService{}
c := newWithSeconds()
// 每分钟同步一次用户列表
if _, err := c.AddFunc("0 */1 * * * *", func() {
if global.GVA_DB == nil {
global.GVA_LOG.Info("【定时任务-每1分钟执行1次】同步用户列表任务,数据库没有初始化,暂未开始同步")
if global.GVA_DB == nil || !initDBService.IfInit() {
global.GVA_LOG.Info("【定时任务-每1分钟执行1次】同步用户列表任务,数据库没有初始化或尚未完成初始化,暂未开始同步")
return
}
+51 -1
View File
@@ -81,8 +81,58 @@ func Gorm() *gorm.DB {
}
func RegisterTables(db *gorm.DB) {
var err error
var count int64
var menu system.SysBaseMenuBtn
var authority system.SysAuthority
if err = global.GVA_DB.Model(&menu).Count(&count).Error; count == 0 {
if err = global.GVA_DB.Model(&authority).Count(&count).Error; count == 1 {
return
}
}
// auto
err = db.AutoMigrate(
system.SysApi{},
system.SysIgnoreApi{},
system.SysUser{},
system.SysBaseMenu{},
system.JwtBlacklist{},
system.SysAuthority{},
system.SysDictionary{},
system.SysOperationRecord{},
system.SysAutoCodeHistory{},
system.SysDictionaryDetail{},
system.SysBaseMenuParameter{},
system.SysBaseMenuBtn{},
system.SysAuthorityBtn{},
system.SysAutoCodePackage{},
system.SysExportTemplate{},
system.Condition{},
system.JoinTemplate{},
system.SysParams{},
err := db.AutoMigrate(tables)
example.ExaFile{},
example.ExaCustomer{},
example.ExaFileChunk{},
example.ExaFileUploadAndDownload{},
adapter.CasbinRule{},
// Extend gaia model
gaia.AccountDingTalkExtend{},
gaia.AppRequestTestBatch{},
gaia.AppRequestTest{},
gaia.SystemIntegration{}, // Extend System Integration
gaia.ForwardingExtend{}, // Extend Forwarding Extend
gaia.BatchWorkflow{}, // Extend Batch Workflow
gaia.BatchWorkflowTask{}, // Extend Batch Workflow Task
gaia.AppVersionConfig{}, // 应用版本全局配置(Token
gaia.AppVersionRelease{}, // 应用版本发布
gaia.AppVersionDownload{}, // 应用版本各平台安装包
gaia.ModelProviderConfig{}, // 模型提供商配置
gaia.ModelProxyLog{}, // 模型中转请求日志
system.SysUserGlobalCode{}, // Extend Global Code
)
if err != nil {
global.GVA_LOG.Error("register table failed", zap.Error(err))
+2 -1
View File
@@ -22,6 +22,7 @@ func initBizRouter(routers ...*gin.RouterGroup) {
gaiaRouter.InitSystemRouter(privateGroup)
gaiaRouter.InitWorkflowRouter(privateGroup)
gaiaRouter.InitAppVersionRouter(publicGroup, privateGroup)
gaiaRouter.InitModelProviderRouter(privateGroup) // 模型提供商路由
gaiaRouter.InitModelProviderRouter(privateGroup) // 模型提供商路由
gaiaRouter.InitForwardProxyRouter(publicGroup) // GPT 转发代理(免 JWT
}
}
+5
View File
@@ -25,6 +25,11 @@ func CasbinHandler() gin.HandlerFunc {
// 获取用户的角色
sub := strconv.Itoa(int(waitUse.AuthorityId))
e := casbinService.Casbin() // 判断策略中是否存在
if e == nil {
global.GVA_LOG.Warn("Casbin enforcer is nil, skipping permission check")
c.Next()
return
}
success, _ := e.Enforce(sub, obj, act)
if !success {
response.FailWithDetailed(gin.H{}, "权限不足", c)
+24 -2
View File
@@ -32,10 +32,32 @@ const (
// 批量工作流配置常量
const (
MaxTaskRetryCount = 3 // 最大任务重试次数
ErrorPenaltyThreshold = 50 // 错误惩罚阈值(每50个错误减少1个并发位)
MaxTaskRetryCount = 3 // 最大任务重试次数
ErrorPenaltyThreshold = 50 // 错误惩罚阈值(每50个错误减少1个并发位)
)
// 优先级分档阈值(按 total_rows 划分)
const (
PriorityTier1MaxRows = 300 // 第一优先级:≤300 行
PriorityTier2MaxRows = 800 // 第二优先级:≤800 行
PriorityTier3MaxRows = 3000 // 第三优先级:≤3000 行
// 第四优先级:>3000 行
)
// UserWorkerAllocation 用户工作器分配信息
type UserWorkerAllocation struct {
UserID uint `json:"user_id"`
Workers int `json:"workers"`
MaxLimit int `json:"max_limit"`
}
// UserErrorInfo 用户错误信息(用于工作器分配计算)
type UserErrorInfo struct {
UserID uint `json:"user_id"`
ErrorCount int `json:"error_count"`
TotalRows int `json:"total_rows"`
}
// BatchWorkflow 批量工作流处理
type BatchWorkflow struct {
ID string `json:"id" gorm:"primaryKey;comment:批量处理ID"`
@@ -2,6 +2,49 @@ package gaia
import "time"
// ModelPricing 从 Dify Console API 拉取的模型定价信息(对应 pricing 字段)
type ModelPricing struct {
Input float64 `json:"input"` // 每 unit 的输入单价
Output float64 `json:"output"` // 每 unit 的输出单价(0 表示与 Input 相同或不区分)
Unit float64 `json:"unit"` // 计费单位(通常 0.001,即每千 token
Currency string `json:"currency"` // 货币(USD / RMB
}
// ModelUsage OpenAI 格式响应中的 usage 字段(非流式及流式末尾行)
type ModelUsage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
}
// ModelUsageResponse OpenAI 格式响应体(仅用于提取 usage 字段)
type ModelUsageResponse struct {
Usage *ModelUsage `json:"usage"`
}
// DifyModelPricingRaw Dify Console API 返回的原始定价字段(值为字符串形式的数字)
type DifyModelPricingRaw struct {
Input string `json:"input"`
Output string `json:"output"`
Unit string `json:"unit"`
Currency string `json:"currency"`
}
// DifyModelItem Dify Console API 返回的单个模型信息
type DifyModelItem struct {
Model string `json:"model"`
Pricing *DifyModelPricingRaw `json:"pricing"`
}
// DifyProviderModels Dify Console API 返回的单个 provider 下的模型列表
type DifyProviderModels struct {
Models []DifyModelItem `json:"models"`
}
// DifyModelsResponse Dify Console API GET /models/model-types/llm 的响应结构
type DifyModelsResponse struct {
Data []DifyProviderModels `json:"data"`
}
// ModelProviderConfig 模型提供商配置表
type ModelProviderConfig struct {
Id uint `json:"id" form:"id" gorm:"primarykey;column:id;comment:id;"`
@@ -9,6 +9,8 @@ const (
ProviderAzure = "azure"
ProviderZhipuai = "zhipuai"
ProviderMinimax = "minimax"
ProviderAWS = "aws" // AWS Bedrock 渠道(用于转发 Claude 等 Anthropic 模型)
ProviderDeepSeek = "deepseek" // DeepSeek 渠道
)
// DifyProviderTypeCustom Dify providers 表 provider_type 枚举
@@ -16,23 +18,31 @@ const DifyProviderTypeCustom = "custom"
// 凭证配置中的 key 名
const (
ConfigKeyOpenaiAPIKey = "openai_api_key"
ConfigKeyOpenaiAPIBase = "openai_api_base"
ConfigKeyOpenaiAPIVersion = "openai_api_version"
ConfigKeyDashScopeAPIKey = "dashscope_api_key"
ConfigKeyAPIKey = "api_key"
ConfigKeyOpenaiAPIKey = "openai_api_key"
ConfigKeyOpenaiAPIBase = "openai_api_base"
ConfigKeyOpenaiAPIVersion = "openai_api_version"
ConfigKeyDashScopeAPIKey = "dashscope_api_key"
ConfigKeyAPIKey = "api_key"
// AWS Bedrock 凭证字段(Dify bedrock provider 的 encrypted_config 中使用)
ConfigKeyAWSAccessKeyID = "aws_access_key_id"
ConfigKeyAWSSecretAccessKey = "aws_secret_access_key"
ConfigKeyAWSSessionToken = "aws_session_token"
ConfigKeyAWSRegion = "aws_region"
ConfigKeyBedrockProxyURL = "bedrock_proxy_url" // 可选:HTTP 代理地址,格式 host:port 或 http://host:port
)
// SupportedProviders 列表展示的提供商顺序
var SupportedProviders = []string{ProviderOpenai, ProviderTongyi, ProviderGoogle, ProviderAnthropic, ProviderAzure, ProviderZhipuai, ProviderMinimax}
var SupportedProviders = []string{ProviderOpenai, ProviderTongyi, ProviderGoogle, ProviderAnthropic, ProviderAWS, ProviderAzure, ProviderZhipuai, ProviderMinimax, ProviderDeepSeek}
// DefaultChatCompletionsEndpoints 各提供商聊天接口默认完整 URL(兼容旧 ProxyChat
var DefaultChatCompletionsEndpoints = map[string]string{
ProviderOpenai: "https://api.openai.com/v1/chat/completions",
ProviderTongyi: "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
ProviderGoogle: "https://generativelanguage.googleapis.com/v1beta/chat/completions",
ProviderZhipuai: "https://open.bigmodel.cn/api/paas/v4/chat/completions",
ProviderMinimax: "https://api.minimax.chat/v1/text/chatcompletion_v2",
ProviderOpenai: "https://api.openai.com/v1/chat/completions",
ProviderTongyi: "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
ProviderGoogle: "https://generativelanguage.googleapis.com/v1beta/chat/completions",
ProviderZhipuai: "https://open.bigmodel.cn/api/paas/v4/chat/completions",
ProviderMinimax: "https://api.minimax.chat/v1/text/chatcompletion_v2",
ProviderDeepSeek: "https://api.deepseek.com/v1/chat/completions",
// Azure 需要动态构建 URL,不使用默认值
}
@@ -42,10 +52,90 @@ var DefaultAPIBase = map[string]string{
ProviderTongyi: "https://dashscope.aliyuncs.com/compatible-mode",
ProviderGoogle: "https://generativelanguage.googleapis.com",
ProviderAnthropic: "https://api.anthropic.com",
ProviderAWS: "https://bedrock-runtime.us-east-1.amazonaws.com",
ProviderZhipuai: "https://open.bigmodel.cn",
ProviderMinimax: "https://api.minimax.chat",
ProviderDeepSeek: "https://api.deepseek.com",
// Azure 的 base URL 来自 openai_api_base 配置,不设置默认值
}
// CredentialKeyFallback 未知提供商时依次尝试的配置 key
var CredentialKeyFallback = []string{ConfigKeyOpenaiAPIKey, ConfigKeyAPIKey, ConfigKeyDashScopeAPIKey}
// RmbToUSDRate 人民币兑美元汇率
const RmbToUSDRate = 7.26
// DefaultImageGenerationPriceUSD 图片生成等按次计费接口的默认单价(USD),无 usage 时使用
const DefaultImageGenerationPriceUSD = 0.04
// DefaultQuotaFallbackUSDPerToken 未命中定价时的兜底单价:每 token 的 USD 金额(仅做记账占位,约 $0.001/千 token
const DefaultQuotaFallbackUSDPerToken = 0.000001
// Gaia 相关 Redis KeyGVA_REDIS / GVA_Dify_REDIS
const (
RedisKeyGaiaAdminConsoleToken = "gaia:admin_console_token"
RedisKeyGaiaModelPricingPrefix = "gaia:model_pricing:"
RedisKeyGaiaForwardDingPrefix = "gaia:forward:ding:"
RedisKeyModelProviderCredentialsPrefix = "model_provider_credentials:"
)
// BuiltinModelPricing 内置兜底定价表(当 Dify Console API 未返回该模型定价时使用)。
// 价格单位:每千 token(与 ModelPricing.Unit=0.001 对应),货币为各模型实际结算货币。
// 通义/百炼模型官方定价(人民币,参考 https://help.aliyun.com/document_detail/2586379.html):
// - 输入/输出价格均为「每百万 token」,换算为每千 token 时除以 1000。
var BuiltinModelPricing = map[string]ModelPricing{
// ──── 通义千问 Qwen3 系列(RMB / 百万 token128K 档) ────
"qwen3-235b-a22b": {Input: 0.4 / 1000, Output: 1.6 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen3-30b-a3b": {Input: 0.11 / 1000, Output: 0.44 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen3-32b": {Input: 0.8 / 1000, Output: 3.2 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen3-14b": {Input: 0.3 / 1000, Output: 1.2 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen3-8b": {Input: 0.1 / 1000, Output: 0.4 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen3-4b": {Input: 0.04 / 1000, Output: 0.16 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen3-1.7b": {Input: 0.02 / 1000, Output: 0.08 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen3-0.6b": {Input: 0.01 / 1000, Output: 0.04 / 1000, Unit: 0.001, Currency: "RMB"},
// ──── 通义千问 Qwen3.5 系列(RMB / 百万 token128K 档) ────
"qwen3.5-plus": {Input: 0.8 / 1000, Output: 4.8 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen3.5-turbo": {Input: 0.3 / 1000, Output: 1.2 / 1000, Unit: 0.001, Currency: "RMB"},
// ──── 通义千问 Qwen2.5 系列(RMB / 百万 token ────
"qwen2.5-72b-instruct": {Input: 4.0 / 1000, Output: 12.0 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen2.5-32b-instruct": {Input: 3.5 / 1000, Output: 7.0 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen2.5-14b-instruct": {Input: 2.0 / 1000, Output: 6.0 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen2.5-7b-instruct": {Input: 1.0 / 1000, Output: 2.0 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen2.5-3b-instruct": {Input: 0.3 / 1000, Output: 0.6 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen-plus": {Input: 0.8 / 1000, Output: 2.0 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen-turbo": {Input: 0.3 / 1000, Output: 0.6 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen-max": {Input: 40.0 / 1000, Output: 120.0 / 1000, Unit: 0.001, Currency: "RMB"},
"qwen-long": {Input: 0.5 / 1000, Output: 2.0 / 1000, Unit: 0.001, Currency: "RMB"},
// ──── 月之暗面 Kimi 系列(RMB / 百万 token ────
// Kimi 走 tongyi(百炼)渠道转发,命名沿用 kimi 前缀;前缀匹配会让 kimi2-k2.6-xxx 也命中 kimi2-k2.6
"kimi2-k2.6": {Input: 4.0 / 1000, Output: 16.0 / 1000, Unit: 0.001, Currency: "RMB"},
"moonshot-v1-8k": {Input: 12.0 / 1000, Output: 12.0 / 1000, Unit: 0.001, Currency: "RMB"},
"moonshot-v1-32k": {Input: 24.0 / 1000, Output: 24.0 / 1000, Unit: 0.001, Currency: "RMB"},
"moonshot-v1-128k": {Input: 60.0 / 1000, Output: 60.0 / 1000, Unit: 0.001, Currency: "RMB"},
// ──── Anthropic Claude 系列(USD / 百万 token ────
// Claude 4.6 / 4.7 系列(Sonnet 与 Opus);anthropic 直连与 AWS Bedrock 走同一份定价
"claude-sonnet-4-6": {Input: 3.0 / 1000, Output: 15.0 / 1000, Unit: 0.001, Currency: "USD"},
"claude-sonnet-4-7": {Input: 3.0 / 1000, Output: 15.0 / 1000, Unit: 0.001, Currency: "USD"},
"claude-opus-4-6": {Input: 15.0 / 1000, Output: 75.0 / 1000, Unit: 0.001, Currency: "USD"},
"claude-opus-4-7": {Input: 15.0 / 1000, Output: 75.0 / 1000, Unit: 0.001, Currency: "USD"},
// AWS Bedrock 上常用的模型 ID 形式(带 anthropic. 前缀与 -v1:0 后缀),单独列出避免前缀匹配漂移
"anthropic.claude-sonnet-4-6-v1:0": {Input: 3.0 / 1000, Output: 15.0 / 1000, Unit: 0.001, Currency: "USD"},
"anthropic.claude-sonnet-4-7-v1:0": {Input: 3.0 / 1000, Output: 15.0 / 1000, Unit: 0.001, Currency: "USD"},
"anthropic.claude-opus-4-6-v1:0": {Input: 15.0 / 1000, Output: 75.0 / 1000, Unit: 0.001, Currency: "USD"},
"anthropic.claude-opus-4-7-v1:0": {Input: 15.0 / 1000, Output: 75.0 / 1000, Unit: 0.001, Currency: "USD"},
// ──── OpenAI 图片生成(按次计费,Input 字段表示「每次请求的 USD 单价」) ────
// 命中分支见 service/gaia/model_provider.go 中的 isImageOrPerRequestPath 与 ProxyRequest 计费逻辑
"gpt-image-1": {Input: 0.04, Currency: "USD"},
"gpt-image-2": {Input: 0.05, Currency: "USD"},
// ──── DeepSeek 系列(USD / 百万 token ────
// deepseek-v4-pro:旗舰推理模型,定价参考官方 https://platform.deepseek.com/api-docs/pricing
"deepseek-v4-pro": {Input: 2.19 / 1000, Output: 8.19 / 1000, Unit: 0.001, Currency: "USD"},
// deepseek-v4-flash:高速轻量模型
"deepseek-v4-flash": {Input: 0.27 / 1000, Output: 1.10 / 1000, Unit: 0.001, Currency: "USD"},
}
@@ -1,5 +1,11 @@
package request
// GetProxyLogsReq 代理日志分页请求
type GetProxyLogsReq struct {
Page int `form:"page"` // 页码,从 1 开始
PageSize int `form:"page_size"` // 每页条数,最大 100
}
// ChatRequest 聊天请求(OpenAI 兼容)
type ChatRequest struct {
Model string `json:"model"`
+84 -36
View File
@@ -1,5 +1,7 @@
package request
import "time"
// SystemOAuth2Error OAuth2 错误返回
type SystemOAuth2Error struct {
Code int `json:"code" gorm:"comment:分类"` // 错误代码
@@ -8,56 +10,102 @@ type SystemOAuth2Error struct {
// SystemOAuth2Request OAuth2 集成配置
type SystemOAuth2Request struct {
Classify uint `json:"classify" gorm:"comment:分类"` // 分类
Status bool `json:"status" gorm:"comment:状态"` // 状态
ServerURL string `json:"server_url" gorm:"comment:服务器地址"` // OAuth2 服务器地址
AuthorizeURL string `json:"authorize_url" gorm:"comment:申请认证的URL"` // 申请认证的URL
TokenURL string `json:"token_url" gorm:"comment:获取TokenURL"` // 获取TokenURL
UserinfoURL string `json:"userinfo_url" gorm:"comment:获取用户信息URL"` // 获取用户信息的URL
LogoutURL string `json:"logout_url" gorm:"comment:退出登录回调URL"` // 退出登录回调URL
DiscoveryURL string `json:"discovery_url" gorm:"comment:OIDC发现配置URL"` // OIDC 发现配置URL
AppID string `json:"app_id" gorm:"comment:Client ID"` // Client ID
AppSecret string `json:"app_secret" gorm:"comment:Client Secret"` // Client Secret
UserNameField string `json:"user_name_field" gorm:"comment:用户名字段"` // 用户名字段
UserEmailField string `json:"user_email_field" gorm:"comment:邮箱字段"` // 邮箱字段
UserIDField string `json:"user_id_field" gorm:"comment:用户唯一标识字段"` // 用户唯一标识字段
Scope string `json:"scope" gorm:"comment:授权范围scope"` // 授权范围
TokenAuthMethod string `json:"token_auth_method" gorm:"comment:令牌端点认证方式"` // client_secret_post|client_secret_basic
RedirectUri string `json:"redirect_uri" gorm:"comment:测试用回调地址"` // 测试用回调地址
Test bool `json:"test" gorm:"default:0;comment:是否测试链接联通性"` // 是否测试链接联通性
Code string `json:"code" gorm:"default:0;comment:code代码"` // code代码
Classify uint `json:"classify" gorm:"comment:分类"` // 分类
Status bool `json:"status" gorm:"comment:状态"` // 状态
ServerURL string `json:"server_url" gorm:"comment:服务器地址"` // OAuth2 服务器地址
AuthorizeURL string `json:"authorize_url" gorm:"comment:申请认证的 URL"` // 申请认证的 URL
TokenURL string `json:"token_url" gorm:"comment:获取 TokenURL"` // 获取 TokenURL
UserinfoURL string `json:"userinfo_url" gorm:"comment:获取用户信息 URL"` // 获取用户信息的 URL
LogoutURL string `json:"logout_url" gorm:"comment:退出登录回调 URL"` // 退出登录回调 URL
DiscoveryURL string `json:"discovery_url" gorm:"comment:OIDC 发现配置 URL"` // OIDC 发现配置 URL
AppID string `json:"app_id" gorm:"comment:Client ID"` // Client ID
AppSecret string `json:"app_secret" gorm:"comment:Client Secret"` // Client Secret
UserNameField string `json:"user_name_field" gorm:"comment:用户名字段"` // 用户名字段
UserEmailField string `json:"user_email_field" gorm:"comment:邮箱字段"` // 邮箱字段
UserIDField string `json:"user_id_field" gorm:"comment:用户唯一标识字段"` // 用户唯一标识字段
Scope string `json:"scope" gorm:"comment:授权范围 scope"` // 授权范围
TokenAuthMethod string `json:"token_auth_method" gorm:"comment:令牌端点认证方式"` // client_secret_post|client_secret_basic
RedirectUri string `json:"redirect_uri" gorm:"comment:测试用回调地址"` // 测试用回调地址
Test bool `json:"test" gorm:"default:0;comment:是否测试链接联通性"` // 是否测试链接联通性
Code string `json:"code" gorm:"default:0;comment:code 代码"` // code 代码
}
// ValueType 参数/字段值类型
const (
ValueTypeString = "string" // 字符串类型
ValueTypeInt = "int" // 整数类型
ValueTypeBool = "bool" // 布尔类型
ValueTypeDingID = "ding_id" // 钉钉 ID 类型(运行时自动替换)
)
// DingIDMarker Raw 模式下钉钉 ID 占位符
const DingIDMarker = "$<{[ding_id]}>"
// AuthorizationConfig 认证配置
type AuthorizationConfig struct {
Type string `json:"type"` // none | bearer | basic
Token string `json:"token"` // Bearer Token
Username string `json:"username"` // Basic Auth用户名
Password string `json:"password"` // Basic Auth密码
Username string `json:"username"` // Basic Auth 用户名
Password string `json:"password"` // Basic Auth 密码
}
// BodyData Body数据配置
// RequestParam URL 查询参数配置
type RequestParam struct {
Key string `json:"key"` // 参数名
ValueType string `json:"value_type"` // string | int | bool | ding_id
Value string `json:"value"` // 参数值(ding_id 类型时运行时自动替换)
}
// BodyField Body 字段配置(支持类型化)
type BodyField struct {
Key string `json:"key"` // 字段名
ValueType string `json:"value_type"` // string | int | bool | ding_id
Value string `json:"value"` // 字段值
}
// BodyData Body 数据配置
type BodyData struct {
FormData []map[string]string `json:"form_data"` // form-data格式数据
Urlencoded []map[string]string `json:"urlencoded"` // x-www-form-urlencoded格式数据
Raw string `json:"raw"` // raw JSON字符串
FormData []BodyField `json:"form_data"` // form-data 格式数据(新格式)
Urlencoded []BodyField `json:"urlencoded"` // x-www-form-urlencoded 格式数据(新格式)
Raw string `json:"raw"` // raw JSON 字符串
}
// EmailApiConfig 第三方邮箱API配置
// EmailApiConfig 第三方邮箱 API 配置
type EmailApiConfig struct {
Enabled bool `json:"enabled"` // 是否启用
URL string `json:"url"` // API地址
Method string `json:"method"` // HTTP方法
RequestParamField string `json:"request_param_field"` // 请求参数字段名
BodyType string `json:"body_type"` // Body类型: form-data | x-www-form-urlencoded | raw
Headers map[string]string `json:"headers"` // 请求头
Authorization AuthorizationConfig `json:"authorization"` // 认证配置
BodyData BodyData `json:"body_data"` // Body数据
ResponseEmailField string `json:"response_email_field"` // 响应邮箱字段路径
Enabled bool `json:"enabled"` // 是否启用
URL string `json:"url"` // API 地址
Method string `json:"method"` // HTTP 方法
RequestParamField string `json:"request_param_field"` // 请求参数字段名(旧格式兼容)
Params []RequestParam `json:"params"` // URL 查询参数列表(新格式)
BodyType string `json:"body_type"` // Body 类型:form-data | x-www-form-urlencoded | raw
Headers map[string]string `json:"headers"` // 请求头
Authorization AuthorizationConfig `json:"authorization"` // 认证配置
BodyData BodyData `json:"body_data"` // Body 数据
ResponseEmailField string `json:"response_email_field"` // 响应邮箱字段路径
}
// TestEmailApiConfigRequest 测试邮箱 API 配置请求
type TestEmailApiConfigRequest struct {
Config EmailApiConfig `json:"config"` // 完整的邮箱配置
TestDingID string `json:"test_ding_id"` // 测试用的钉钉 ID(可选)
}
// ForwardToken 转发 Token 配置
type ForwardToken struct {
ID string `json:"id"` // 前端生成的唯一 ID(用于删除)
TokenHash string `json:"token_hash"` // SHA256(token)
CreatedAt time.Time `json:"created_at"` // 创建时间
TokenSecret string `json:"token_secret"` // HMAC 签名密钥(随机生成,服务端保存)
}
// ForwardConfig 转发集成配置
type ForwardConfig struct {
Enabled bool `json:"enabled"` // 是否启用转发
Tokens []ForwardToken `json:"tokens"` // Token 列表,最多 20 个
}
// DingTalkConfigRequest 钉钉集成配置
type DingTalkConfigRequest struct {
EmailApi EmailApiConfig `json:"email_api"` // 第三方邮箱API配置
EmailApi EmailApiConfig `json:"email_api"` // 第三方邮箱 API 配置
ForwardConfig ForwardConfig `json:"forward_config"` // 转发集成配置
}
@@ -1,5 +1,29 @@
package response
// AppQuotaRankingRow 应用配额排名查询单行(仅用于 service 层 GORM 查询扫描)
type AppQuotaRankingRow struct {
AppID string `gorm:"column:app_id"`
TotalCost float64 `gorm:"column:total_cost"`
MessageCost float64 `gorm:"column:message_cost"`
WorkflowCost float64 `gorm:"column:workflow_cost"`
RecordNum float64 `gorm:"column:record_num"`
}
// AppQuotaRankingCache 应用配额排名缓存结构(List + Total
type AppQuotaRankingCache struct {
List []GetAppQuotaRankingDataRes
Total int64
}
// AiImageQuotaRankingRow AI 图片使用量排名查询单行(仅用于 service 层 GORM 查询扫描)
type AiImageQuotaRankingRow struct {
Address string `gorm:"column:address"`
Path string `gorm:"column:path"`
TotalCost float64 `gorm:"column:total_cost"`
RecordNum int `gorm:"column:record_num"`
Model string `gorm:"column:model"`
}
// GetAccountQuotaRankingDataRes 获取账户配额排名数据的响应结构
type GetAccountQuotaRankingDataRes struct {
Ranking int `json:"ranking"` // 排名
@@ -5,6 +5,14 @@ type ProviderCredentials struct {
APIKey string `json:"api_key"`
Endpoint string `json:"endpoint,omitempty"`
APIVersion string `json:"api_version,omitempty"` // Azure OpenAI API 版本
// AWS Bedrock 直连用:access key + secret + region(不走 APIKey/Endpoint
AWSAccessKeyID string `json:"aws_access_key_id,omitempty"`
AWSSecretAccessKey string `json:"aws_secret_access_key,omitempty"`
AWSSessionToken string `json:"aws_session_token,omitempty"`
AWSRegion string `json:"aws_region,omitempty"`
// Bedrock 可选代理地址(host:port 或 http://host:port),非空时请求经该代理转发到 AWS
BedrockProxyURL string `json:"bedrock_proxy_url,omitempty"`
}
// ModelInfo 模型信息
@@ -40,10 +48,10 @@ type OpenAIModelsListResponse struct {
type TongyiModelsListResponse struct {
Success bool `json:"success"`
Output struct {
Total int `json:"total"`
PageNo int `json:"page_no"`
PageSize int `json:"page_size"`
Models []TongyiModelItem `json:"models"`
Total int `json:"total"`
PageNo int `json:"page_no"`
PageSize int `json:"page_size"`
Models []TongyiModelItem `json:"models"`
} `json:"output"`
}
@@ -55,8 +63,8 @@ type TongyiModelItem struct {
// GeminiModelsListResponse Google Gemini GET /v1beta/models 返回:models[] + nextPageToken
type GeminiModelsListResponse struct {
Models []GeminiModelItem `json:"models"`
NextPageToken string `json:"nextPageToken"`
Models []GeminiModelItem `json:"models"`
NextPageToken string `json:"nextPageToken"`
}
// GeminiModelItem Gemini 模型单项,name 为 "models/gemini-xxx"baseModelId 用于请求
@@ -0,0 +1,31 @@
package response
import "time"
type CheckAccountQuotaRow struct {
TotalQuota float64 `gorm:"column:total_quota"`
UsedQuota float64 `gorm:"column:used_quota"`
}
// ForwardTokenInfo 转发 Token 列表项(不暴露内部 ID)
type ForwardTokenInfo struct {
ID string `json:"id"` // token
Seq int `json:"seq"` // 1..N 序列号(用于删除)
CreatedAt time.Time `json:"created_at"`
}
// ForwardTokensResponse 获取转发 Token 列表响应
type ForwardTokensResponse struct {
Tokens []ForwardTokenInfo `json:"tokens"`
Count int `json:"count"`
Max int `json:"max"`
}
// TestEmailApiConfigResponse 测试邮箱 API 配置响应
type TestEmailApiConfigResponse struct {
StatusCode int `json:"status_code"` // HTTP 状态码
Body interface{} `json:"body"` // 响应 Body(JSON 时为对象,否则为字符串)
EmailFieldPreview string `json:"email_field_preview"` // 邮箱字段解析预览(如 data[0].userName = test@example.com
IsValid bool `json:"is_valid"` // 配置是否有效(能正确提取邮箱)
ErrorMessage string `json:"error_message,omitempty"` // 错误信息(可选)
}
@@ -5,6 +5,7 @@ const SystemIntegrationDingTalk = uint(1) // 钉钉集成
const SystemIntegrationWeiXin = uint(2) // 微信集成
const SystemIntegrationFeiShu = uint(3) // 飞书集成
const SystemIntegrationOAuth2 = uint(4) // OAuth2集成
const BearerLength = 7 // OAuth2集成
// SystemIntegration 系统集成表
type SystemIntegration struct {
+1
View File
@@ -23,3 +23,4 @@ var testApi = api.ApiGroupApp.GaiaApiGroup.TestApi
var batchWorkflowApi = api.ApiGroupApp.GaiaApiGroup.BatchWorkflowApi
var appVersionApi = api.ApiGroupApp.GaiaApiGroup.AppVersionApi
var modelProviderApi = api.ApiGroupApp.GaiaApiGroup.ModelProviderApi
var forwardProxyApi = api.ApiGroupApp.GaiaApiGroup.ForwardProxyApi
+22 -8
View File
@@ -10,16 +10,30 @@ type SystemRouter struct{}
func (s *SystemRouter) InitSystemRouter(Router *gin.RouterGroup) {
systemRouter := Router.Group("gaia/system")
{
systemRouter.GET("dingtalk", systemApi.GetDingTalk) // 获取钉钉系统配置
systemRouter.POST("dingtalk", systemApi.SetDingTalk) // 设置钉钉系统配置
systemRouter.GET("oauth2", systemOAuth2Api.GetOAuth2Config) // 获取OAuth2配置
systemRouter.POST("oauth2", systemOAuth2Api.SetOAuth2Config) // 设置OAuth2配置
systemRouter.GET("dingtalk", systemApi.GetDingTalk) // 获取钉钉系统配置
systemRouter.POST("dingtalk", systemApi.SetDingTalk) // 设置钉钉系统配置
systemRouter.GET("dingtalk/test-auth-url", systemApi.GetDingTalkTestAuthURL) // 测试连接:获取钉钉授权 URL
systemRouter.POST("dingtalk/test-callback", systemApi.DingTalkTestCallback) // 测试连接:回调验证 code
systemRouter.GET("oauth2", systemOAuth2Api.GetOAuth2Config) // 获取 OAuth2 配置
systemRouter.POST("oauth2", systemOAuth2Api.SetOAuth2Config) // 设置 OAuth2 配置
// 邮箱 API 配置测试
systemRouter.POST("dingtalk/test-email-config", systemApi.TestEmailApiConfig) // 测试第三方邮箱 API 配置
// 转发 Token 管理
systemRouter.GET("forward-tokens", systemApi.GetForwardTokens) // 获取转发 Token 列表
systemRouter.POST("forward-tokens", systemApi.CreateForwardToken) // 新增转发 Token
systemRouter.DELETE("forward-tokens/:seq", systemApi.DeleteForwardToken) // 删除转发 Token(按序列号)
}
}
// InitForwardProxyRouter 初始化 GPT 转发代理路由
func (s *SystemRouter) InitForwardProxyRouter(PublicRouter *gin.RouterGroup) {
// 免 JWT 转发入口,通过 forwarding token + ding_id 鉴权
PublicRouter.Any("gaia/forward/proxy/*path", forwardProxyApi.ForwardProxy)
}
// InitModelProviderRouter 初始化模型提供商路由
func (s *SystemRouter) InitModelProviderRouter(Router *gin.RouterGroup) {
// 管理端API(需要JWT认证)
// 管理端 API(需要 JWT 认证)
modelProviderRouter := Router.Group("gaia/model-provider")
{
modelProviderRouter.GET("list", modelProviderApi.GetProviderList) // 获取提供商配置列表
@@ -29,10 +43,10 @@ func (s *SystemRouter) InitModelProviderRouter(Router *gin.RouterGroup) {
modelProviderRouter.GET("logs", modelProviderApi.GetProxyLogs) // 获取代理日志
}
// 第三方API(需要JWT认证)
// 第三方 API(需要 JWT 认证)
gaiaRouter := Router.Group("gaia")
{
gaiaRouter.GET("models", modelProviderApi.GetModels) // 获取开启的模型列表(OpenAI格式)
gaiaRouter.Any("proxy/*path", modelProviderApi.Proxy) // 通用中转API:按路径转发(v1/chat/completions、v1/messages、v1/images/generations、v1/embeddings 等)
gaiaRouter.GET("models", modelProviderApi.GetModels) // 获取开启的模型列表(OpenAI 格式)
gaiaRouter.Any("proxy/*path", modelProviderApi.Proxy) // 通用中转 API:按路径转发(v1/chat/completions、v1/messages、v1/images/generations、v1/embeddings 等)
}
}
+27 -2
View File
@@ -28,6 +28,31 @@ func (s *BatchWorkflowService) CreateBatchWorkflow(
return nil, fmt.Errorf("数据库连接未初始化")
}
// 计算本次上传的有效数据行数(去掉表头和空行)
uploadedDataRows := 0
if len(fileContent) > 1 {
for _, row := range fileContent[1:] {
for _, v := range row {
if strings.TrimSpace(v) != "" {
uploadedDataRows++
break
}
}
}
}
// 检查当前用户 pending 状态队列中是否已有相同文件(文件名 + 行数一致视为重复)
var duplicateCount int64
if err := global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where(
"user_id = ? AND file_name = ? AND installed_id = ? AND total_rows = ? AND status = ?",
userId, fileName, installedID, uploadedDataRows, gaia.BatchWorkflowStatusPending).
Count(&duplicateCount).Error; err != nil {
return nil, fmt.Errorf("检查重复文件失败: %v", err)
}
if duplicateCount > 0 {
return nil, fmt.Errorf("文件重复上传:当前队列中已存在相同文件(文件名:%s,行数:%d),请勿重复提交", fileName, uploadedDataRows)
}
// 创建批量处理记录
keyByte, _ := json.Marshal(keyNameMapping)
batchWorkflow := &gaia.BatchWorkflow{
@@ -363,8 +388,8 @@ func (s *BatchWorkflowService) callDifyAPI(
}
// Extend End: 添加CSRF token支持
// 发送请求
client := &http.Client{}
// 发送请求(设置超时,避免 Dify 卡住时 goroutine 与连接长期占用不释放)
client := &http.Client{Timeout: 5 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return "", err
+399
View File
@@ -0,0 +1,399 @@
package gaia
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
estream "github.com/aws/aws-sdk-go/private/protocol/eventstream"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
gaiaResponse "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
"go.uber.org/zap"
)
// proxyBedrockRequest 直连 AWS Bedrock 原生 API 转发 Anthropic Messages 请求。
//
// 路径转换:v1/messages → model/{modelId}/invoke 或 model/{modelId}/invoke-with-response-stream
// 鉴权:SigV4service=bedrockregion 来自 Dify 凭证 aws_region 字段)
// 请求体改写:去掉 model 字段(Bedrock 走 URL 路径),注入 anthropic_version=bedrock-2023-05-31
// 响应:
// - 非流式:Bedrock 返回 Anthropic Messages JSON(含 usage.input_tokens/output_tokens),原样转发
// - 流式:vnd.amazon.eventstream 二进制帧,每帧 payload 是 {"bytes":"<base64>"}
// 解出后是 Anthropic SSE 事件 JSON,在此重组为标准 SSE 写回客户端
//
// 计费:成功后按 (input_tokens, output_tokens) 调 calcQuotaDelta 扣额。
func (s *ModelProviderService) proxyBedrockRequest(
userID, _ /* path */, method string, _ /* reqHeader */ http.Header, body []byte, writer io.Writer,
creds *gaiaResponse.ProviderCredentials,
) error {
// 1) 校验 AWS 凭证
if creds == nil || creds.AWSAccessKeyID == "" || creds.AWSSecretAccessKey == "" {
return fmt.Errorf("AWS Bedrock 凭证缺失(需要 aws_access_key_id / aws_secret_access_key")
}
region := creds.AWSRegion
if region == "" {
region = "us-east-1"
}
// 2) 解析 body:拿到 modelId 与 stream 标记,并改写为 Bedrock 期望的格式
if len(body) == 0 {
return fmt.Errorf("Bedrock 请求 body 不能为空")
}
var bodyObj map[string]interface{}
if err := json.Unmarshal(body, &bodyObj); err != nil {
return fmt.Errorf("解析 Bedrock 请求 body 失败:%w", err)
}
modelID, _ := bodyObj["model"].(string)
if modelID == "" {
return fmt.Errorf("Bedrock 请求 body 缺少 model 字段")
}
streaming := false
if v, ok := bodyObj["stream"].(bool); ok {
streaming = v
}
// 移除 modelBedrock 不需要);删除 stream(流式由 URL 决定)
delete(bodyObj, "model")
delete(bodyObj, "stream")
delete(bodyObj, "stream_options")
// OpenAI 兼容字段转换:max_completion_tokens → max_tokensBedrock/Anthropic 使用 max_tokens
if _, hasMaxTokens := bodyObj["max_tokens"]; !hasMaxTokens {
if v, ok := bodyObj["max_completion_tokens"]; ok {
bodyObj["max_tokens"] = v
delete(bodyObj, "max_completion_tokens")
}
} else {
delete(bodyObj, "max_completion_tokens")
}
// 注入 Bedrock 必需的 anthropic_version
if _, ok := bodyObj["anthropic_version"]; !ok {
bodyObj["anthropic_version"] = "bedrock-2023-05-31"
}
rewritten, err := json.Marshal(bodyObj)
if err != nil {
return fmt.Errorf("重写 Bedrock 请求 body 失败:%w", err)
}
// 3) 构建 Bedrock URL
// 新一代 Claude 模型(3.5v2、3.7、Sonnet-4、Opus-4 等)要求通过跨区域推理配置文件调用,
// 模型 ID 需加地理前缀(us. / eu. / ap.),否则 Bedrock 返回 "on-demand throughput isn't supported" 错误。
// 若调用方已传入带前缀的 ID(如 us.anthropic.xxx)则直接使用,不重复添加。
invokeModelID := bedrockResolveModelID(modelID, region)
if invokeModelID != modelID {
global.GVA_LOG.Info("Bedrock 模型 ID 已映射为跨区域推理配置文件",
zap.String("original", modelID),
zap.String("resolved", invokeModelID),
zap.String("region", region),
)
}
host := fmt.Sprintf("bedrock-runtime.%s.amazonaws.com", region)
op := "invoke"
if streaming {
op = "invoke-with-response-stream"
}
requestURL := fmt.Sprintf("https://%s/model/%s/%s", host, url.PathEscape(invokeModelID), op)
// 打印请求地址、参数和代理地址
global.GVA_LOG.Info("Bedrock 请求详情",
zap.String("request_url", requestURL),
zap.String("method", method),
zap.ByteString("body", rewritten),
zap.String("proxy_url", creds.BedrockProxyURL),
)
httpReq, err := http.NewRequest(method, requestURL, bytes.NewReader(rewritten))
if err != nil {
return fmt.Errorf("构建 Bedrock 请求失败:%w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
if streaming {
httpReq.Header.Set("Accept", "application/vnd.amazon.eventstream")
httpReq.Header.Set("X-Amzn-Bedrock-Accept", "application/json")
} else {
httpReq.Header.Set("Accept", "application/json")
}
// 4) SigV4 签名(service=bedrock
awsCreds := credentials.NewStaticCredentials(creds.AWSAccessKeyID, creds.AWSSecretAccessKey, creds.AWSSessionToken)
signer := v4.NewSigner(awsCreds)
if _, err = signer.Sign(httpReq, bytes.NewReader(rewritten), "bedrock", region, time.Now()); err != nil {
return fmt.Errorf("Bedrock SigV4 签名失败:%w", err)
}
// 5) 发起请求(若配置了 bedrock_proxy_url 则经 HTTP 代理转发)
startTime := time.Now()
transport := http.DefaultTransport
if creds.BedrockProxyURL != "" {
proxyAddr := creds.BedrockProxyURL
if !strings.HasPrefix(proxyAddr, "http://") && !strings.HasPrefix(proxyAddr, "https://") {
proxyAddr = "http://" + proxyAddr
}
if proxyURL, parseErr := url.Parse(proxyAddr); parseErr == nil {
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
client := &http.Client{Timeout: 5 * time.Minute, Transport: transport}
resp, err := client.Do(httpReq)
if err != nil {
s.logBedrock(userID, modelID, "error", err.Error(), startTime, 0, 0)
return err
}
defer func() { _ = resp.Body.Close() }()
// 6) 写回响应头/状态码(流式改写 Content-Type 为 SSE
if w, ok := writer.(http.ResponseWriter); ok {
for k, v := range resp.Header {
lower := strings.ToLower(k)
if streaming && (lower == "content-type" || lower == "content-length") {
continue
}
for _, vv := range v {
w.Header().Add(k, vv)
}
}
if streaming {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
}
w.WriteHeader(resp.StatusCode)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
raw, _ := io.ReadAll(resp.Body)
_, _ = writer.Write(raw)
s.logBedrock(userID, modelID, "error",
fmt.Sprintf("bedrock %d: %s", resp.StatusCode, string(raw)), startTime, 0, 0)
return nil
}
// 7) 处理响应体
var inputTokens, outputTokens int
if streaming {
inputTokens, outputTokens, err = s.streamBedrockEventStream(resp.Body, writer)
if err != nil {
s.logBedrock(userID, modelID, "error", err.Error(), startTime, inputTokens, outputTokens)
return err
}
} else {
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
if _, err = io.Copy(writer, tee); err != nil {
s.logBedrock(userID, modelID, "error", err.Error(), startTime, 0, 0)
return err
}
inputTokens, outputTokens = parseAnthropicUsage(buf.Bytes())
}
// 8) 记录日志 + 计费扣款
s.logBedrock(userID, modelID, "success", "", startTime, inputTokens, outputTokens)
if inputTokens > 0 || outputTokens > 0 {
pricing, _ := s.fetchModelPricingFromDify(modelID)
delta := calcQuotaDelta(pricing, modelID, inputTokens, outputTokens)
deductAccountQuota(userID, delta)
}
return nil
}
// streamBedrockEventStream 解析 Bedrock 的 vnd.amazon.eventstream 二进制流,
// 把每个事件还原为 Anthropic SSEevent: <type>\ndata: <json>\n\n)写给客户端。
// 返回累计的 input/output token 数(用于计费)。
func (s *ModelProviderService) streamBedrockEventStream(r io.Reader, w io.Writer) (int, int, error) {
flusher, _ := w.(http.Flusher)
dec := estream.NewDecoder(r)
payloadBuf := make([]byte, 0, 32*1024)
var inputTokens, outputTokens int
for {
msg, err := dec.Decode(payloadBuf)
if err != nil {
if err == io.EOF {
return inputTokens, outputTokens, nil
}
return inputTokens, outputTokens, fmt.Errorf("eventstream decode 失败:%w", err)
}
// Bedrock 的事件 payload 形如 {"bytes":"<base64-encoded inner JSON>"}
var wrap struct {
Bytes string `json:"bytes"`
}
var inner []byte
if e := json.Unmarshal(msg.Payload, &wrap); e == nil && wrap.Bytes != "" {
if decoded, e2 := base64.StdEncoding.DecodeString(wrap.Bytes); e2 == nil {
inner = decoded
}
}
if len(inner) == 0 {
// 非包装格式(如错误/ping),直接用原 payload
inner = msg.Payload
}
// 解析事件类型和 usageAnthropic 在 message_start.message.usage 给 input_tokens
// message_delta.usage 给 output_tokens
var ev struct {
Type string `json:"type"`
Message struct {
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
} `json:"message"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
}
_ = json.Unmarshal(inner, &ev)
if ev.Message.Usage.InputTokens > 0 {
inputTokens = ev.Message.Usage.InputTokens
}
if ev.Message.Usage.OutputTokens > 0 {
outputTokens = ev.Message.Usage.OutputTokens
}
if ev.Usage.InputTokens > 0 {
inputTokens = ev.Usage.InputTokens
}
if ev.Usage.OutputTokens > 0 {
outputTokens = ev.Usage.OutputTokens
}
// 重组为 Anthropic SSE 写回
eventName := ev.Type
if eventName == "" {
eventName = "message"
}
sse := "event: " + eventName + "\ndata: " + string(inner) + "\n\n"
if _, err = w.Write([]byte(sse)); err != nil {
return inputTokens, outputTokens, err
}
if flusher != nil {
flusher.Flush()
}
}
}
// parseAnthropicUsage 从非流式 Anthropic Messages 响应 JSON 中提取 usage 字段。
func parseAnthropicUsage(data []byte) (input, output int) {
var obj struct {
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
} `json:"usage"`
}
if json.Unmarshal(data, &obj) == nil {
return obj.Usage.InputTokens, obj.Usage.OutputTokens
}
return 0, 0
}
// bedrockCrossRegionPrefixes 是需要跨区域推理配置文件的模型 ID 前缀列表(anthropic. 开头)。
// 来源:https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
// 规则:凡模型不在旧版 on-demand 列表中,均需加地理前缀才能调用。
var bedrockCrossRegionPrefixes = []string{
// Claude 3.5 v2
"anthropic.claude-3-5-sonnet-20241022-v2",
"anthropic.claude-3-5-haiku-20241022",
// Claude 3.7
"anthropic.claude-3-7-sonnet",
// Claude Sonnet 4 / Opus 4 / Haiku 4(及后续新版本)
"anthropic.claude-sonnet-4",
"anthropic.claude-opus-4",
"anthropic.claude-haiku-4",
}
// bedrockResolveModelID 将用户输入的模型 ID 解析为正确的 Bedrock 调用 ID。
//
// 支持以下输入格式(会自动规范化):
// - 完整带前缀:us.anthropic.claude-sonnet-4-6 → 直接使用
// - 带厂商前缀:anthropic.claude-sonnet-4-6 → 按需加 us./eu./ap.
// - 短横线名称:claude-sonnet-4-6 → 补 anthropic. 再按需加前缀
// - 空格+点号: "claude sonnet 4.6" → 规范化后同上
func bedrockResolveModelID(modelID, region string) string {
// Step 1: 规范化输入
// 小写、空格→横线、版本中的点→横线(如 4.6 → 4-6,但保留 : 用于版本后缀如 v2:0)
normalized := strings.ToLower(strings.TrimSpace(modelID))
normalized = strings.ReplaceAll(normalized, " ", "-")
// 仅将数字之间的 `.` 替换为 `-`(处理 "4.6" → "4-6"),保留 anthropic. 这样的厂商点
normalized = replaceVersionDots(normalized)
// Step 2: 补全 anthropic. 前缀(用户只填了 claude-xxx
if !strings.HasPrefix(normalized, "anthropic.") &&
!strings.HasPrefix(normalized, "us.") &&
!strings.HasPrefix(normalized, "eu.") &&
!strings.HasPrefix(normalized, "ap.") {
normalized = "anthropic." + normalized
}
// Step 3: 已经带地理前缀 → 直接使用
if strings.HasPrefix(normalized, "us.") || strings.HasPrefix(normalized, "eu.") || strings.HasPrefix(normalized, "ap.") {
return normalized
}
// Step 4: 判断是否属于需要跨区域推理配置文件的模型
needsCrossRegion := false
for _, prefix := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(normalized, prefix) {
needsCrossRegion = true
break
}
}
if !needsCrossRegion {
return normalized
}
// Step 5: 根据 region 推导地理前缀
geoPrefix := "us" // 默认
switch {
case strings.HasPrefix(region, "us-"):
geoPrefix = "us"
case strings.HasPrefix(region, "eu-"):
geoPrefix = "eu"
case strings.HasPrefix(region, "ap-"):
geoPrefix = "ap"
}
return geoPrefix + "." + normalized
}
// replaceVersionDots 将版本号中数字之间的 `.` 替换为 `-`(如 4.6 → 4-6),
// 保留厂商命名空间中的点(如 anthropic. 开头不受影响,因为点后紧跟字母)。
func replaceVersionDots(s string) string {
var b strings.Builder
for i := 0; i < len(s); i++ {
if s[i] == '.' && i > 0 && i < len(s)-1 {
prev := s[i-1]
next := s[i+1]
// 仅当点号两侧都是数字时才替换为 -
if prev >= '0' && prev <= '9' && next >= '0' && next <= '9' {
b.WriteByte('-')
continue
}
}
b.WriteByte(s[i])
}
return b.String()
}
// logBedrock 记录代理日志(与 ProxyRequest 中的 ModelProxyLog 行为一致)。
func (s *ModelProviderService) logBedrock(userID, modelID, status, errMsg string, startTime time.Time, in, out int) {
if err := global.GVA_DB.Create(&gaia.ModelProxyLog{
UserId: userID,
ProviderName: gaia.ProviderAWS,
ModelName: modelID,
RequestTokens: in,
ResponseTokens: out,
Status: status,
ErrorMessage: errMsg,
CreatedAt: startTime,
}).Error; err != nil {
global.GVA_LOG.Warn("logBedrock 写日志失败", zap.Error(err))
}
}
+22 -35
View File
@@ -19,7 +19,8 @@ import (
type DashboardService struct{}
// GetAccountQuotaRankingData 分页获取【账号】额度排名列表
func (dashboardService *DashboardService) GetAccountQuotaRankingData(info gaiaReq.GetAccountQuotaRankingDataReq) (list []response.GetAccountQuotaRankingDataRes, total int64, err error) {
func (s *DashboardService) GetAccountQuotaRankingData(info gaiaReq.GetAccountQuotaRankingDataReq) (
list []response.GetAccountQuotaRankingDataRes, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
@@ -79,15 +80,13 @@ func (dashboardService *DashboardService) GetAccountQuotaRankingData(info gaiaRe
}
// GetAppQuotaRankingData 分页获取【应用】配额排名数据
func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.GetAppQuotaRankingDataReq) (list []response.GetAppQuotaRankingDataRes, total int64, err error) {
func (s *DashboardService) GetAppQuotaRankingData(info gaiaReq.GetAppQuotaRankingDataReq) (
list []response.GetAppQuotaRankingDataRes, total int64, err error) {
cacheKey := fmt.Sprintf("app_token_quota_ranking:%d:%d", info.Page, info.PageSize)
var cachedResult struct {
List []response.GetAppQuotaRankingDataRes
Total int64
}
var cachedResult response.AppQuotaRankingCache
if found, err := dashboardService.getCachedResult(cacheKey, &cachedResult); err == nil && found {
if found, err := s.getCachedResult(cacheKey, &cachedResult); err == nil && found {
return cachedResult.List, cachedResult.Total, nil
}
@@ -102,7 +101,7 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
Select("" +
"app_id, " +
"COUNT(id) as message_num, " +
"SUM(CASE WHEN currency = 'RMB' THEN total_price / 7.26 ELSE total_price END) as message_cost").
fmt.Sprintf("SUM(CASE WHEN currency = 'RMB' THEN total_price / %f ELSE total_price END) as message_cost", gaia.RmbToUSDRate)).
Group("app_id")
workflowCosts := global.GVA_DB.Table("public.workflow_node_executions").
@@ -111,7 +110,7 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
"COUNT(id) as workflow_num, " +
"SUM(CASE " +
" WHEN execution_metadata::json->>'currency' = 'RMB' " +
" THEN CAST((execution_metadata::json->>'total_price') AS NUMERIC) / 7.26 " +
fmt.Sprintf(" THEN CAST((execution_metadata::json->>'total_price') AS NUMERIC) / %f ", gaia.RmbToUSDRate) +
" ELSE CAST((execution_metadata::json->>'total_price') AS NUMERIC) " +
"END) AS workflow_cost").
Where("execution_metadata IS NOT NULL AND execution_metadata != '' AND (execution_metadata::json->>'total_price') IS NOT NULL").
@@ -140,13 +139,7 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
}
// 执行查询
var results []struct {
AppID string `gorm:"column:app_id"`
TotalCost float64 `gorm:"column:total_cost"`
MessageCost float64 `gorm:"column:message_cost"`
WorkflowCost float64 `gorm:"column:workflow_cost"`
RecordNum float64 `gorm:"column:record_num"`
}
var results []response.AppQuotaRankingRow
err = query.Find(&results).Error
if err != nil {
@@ -269,12 +262,8 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
}
// 在返回结果之前,缓存结果
result := struct {
List []response.GetAppQuotaRankingDataRes
Total int64
}{list, total}
if err := dashboardService.cacheResult(cacheKey, result, 24*time.Hour); err != nil {
cachePayload := response.AppQuotaRankingCache{List: list, Total: total}
if err := s.cacheResult(cacheKey, cachePayload, 24*time.Hour); err != nil {
global.GVA_LOG.Error("Failed to cache result", zap.Error(err))
}
@@ -282,7 +271,8 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
}
// GetAppTokenQuotaRankingData 分页获取【应用密钥】配额排名数据列表
func (dashboardService *DashboardService) GetAppTokenQuotaRankingData(info gaiaReq.GetAppTokenQuotaRankingDataReq) (list []response.GetAppTokenQuotaRankingDataRes, total int64, err error) {
func (s *DashboardService) GetAppTokenQuotaRankingData(info gaiaReq.GetAppTokenQuotaRankingDataReq) (
list []response.GetAppTokenQuotaRankingDataRes, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
@@ -365,9 +355,11 @@ func (dashboardService *DashboardService) GetAppTokenQuotaRankingData(info gaiaR
}
// GetAppTokenDailyQuotaData 获取每天密钥花费数据列表
func (dashboardService *DashboardService) GetAppTokenDailyQuotaData(info gaiaReq.GetAppTokenDailyQuotaDataReq) (list []response.GetAppTokenDailyQuotaDataRes, err error) {
func (s *DashboardService) GetAppTokenDailyQuotaData(info gaiaReq.GetAppTokenDailyQuotaDataReq) (
list []response.GetAppTokenDailyQuotaDataRes, err error) {
db := global.GVA_DB.Select("DATE(stat_at) as stat_at, SUM(day_used_quota) as day_used_quota").Model(&gaia.ApiTokenMoneyDailyStatExtend{}).Order("stat_at desc").Group("DATE(stat_at)")
db := global.GVA_DB.Select("DATE(stat_at) as stat_at, SUM(day_used_quota) as day_used_quota").Model(
&gaia.ApiTokenMoneyDailyStatExtend{}).Order("stat_at desc").Group("DATE(stat_at)")
var apiTokenMoneyDailyStatExtends []gaia.ApiTokenMoneyDailyStatExtend
if info.AppId != "" {
@@ -397,7 +389,8 @@ func (dashboardService *DashboardService) GetAppTokenDailyQuotaData(info gaiaReq
}
// GetAiImageQuotaRankingData 获取【AI图片】使用量排名数据列表
func (dashboardService *DashboardService) GetAiImageQuotaRankingData(info gaiaReq.GetAiImageQuotaRankingDataReq) (list []response.GetAiImageQuotaRankingRes, err error) {
func (s *DashboardService) GetAiImageQuotaRankingData(info gaiaReq.GetAiImageQuotaRankingDataReq) (
list []response.GetAiImageQuotaRankingRes, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
@@ -427,13 +420,7 @@ func (dashboardService *DashboardService) GetAiImageQuotaRankingData(info gaiaRe
db = db.Limit(limit).Offset(offset)
}
var results []struct {
Address string
Path string
TotalCost float64
RecordNum int
Model string
}
var results []response.AiImageQuotaRankingRow
err = db.Find(&results).Error
if err != nil {
@@ -456,7 +443,7 @@ func (dashboardService *DashboardService) GetAiImageQuotaRankingData(info gaiaRe
return list, nil
}
func (dashboardService *DashboardService) cacheResult(key string, data interface{}, expiration time.Duration) error {
func (s *DashboardService) cacheResult(key string, data interface{}, expiration time.Duration) error {
jsonData, err := json.Marshal(data)
if err != nil {
@@ -465,7 +452,7 @@ func (dashboardService *DashboardService) cacheResult(key string, data interface
return global.GVA_REDIS.Set(context.Background(), key, jsonData, expiration).Err()
}
func (dashboardService *DashboardService) getCachedResult(key string, result interface{}) (bool, error) {
func (s *DashboardService) getCachedResult(key string, result interface{}) (bool, error) {
data, err := global.GVA_REDIS.Get(context.Background(), key).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
+273
View File
@@ -0,0 +1,273 @@
package gaia
import (
"testing"
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
)
// TestBuildURL_NewFormat 测试新格式 Params 自动拼接 URL
func TestBuildURL_NewFormat(t *testing.T) {
config := request.EmailApiConfig{
URL: "https://api.example.com/user",
Params: []request.RequestParam{},
}
// 无参数
got := buildURL("https://api.example.com/user", config, "USER123")
if got != "https://api.example.com/user" {
t.Errorf("无参数时 URL 应保持不变,got: %s", got)
}
// 单个 string 类型参数
config.Params = []request.RequestParam{
{Key: "appKey", ValueType: "string", Value: "mykey"},
}
got = buildURL("https://api.example.com/user", config, "USER123")
if got != "https://api.example.com/user?appKey=mykey" {
t.Errorf("单参数拼接错误,got: %s", got)
}
// URL 已有 ? 时使用 &
got = buildURL("https://api.example.com/user?type=admin", config, "USER123")
if got != "https://api.example.com/user?type=admin&appKey=mykey" {
t.Errorf("已有参数时应使用 & 拼接,got: %s", got)
}
}
// TestBuildURL_DingIDParam 测试钉钉 ID 类型参数自动替换
func TestBuildURL_DingIDParam(t *testing.T) {
config := request.EmailApiConfig{
URL: "https://api.example.com/user",
Params: []request.RequestParam{
{Key: "userId", ValueType: "ding_id"},
},
}
got := buildURL("https://api.example.com/user", config, "USER123")
if got != "https://api.example.com/user?userId=USER123" {
t.Errorf("钉钉 ID 类型参数应自动替换,got: %s", got)
}
}
// TestBuildURL_OldFormat 测试旧格式 RequestParamField 兼容
func TestBuildURL_OldFormat(t *testing.T) {
config := request.EmailApiConfig{
URL: "https://api.example.com/user",
RequestParamField: "userId",
Params: nil, // 旧格式:Params 为 nil
}
got := buildURL("https://api.example.com/user", config, "USER123")
if got != "https://api.example.com/user?userId=USER123" {
t.Errorf("旧格式应使用 RequestParamFieldgot: %s", got)
}
}
// TestResolveParamValue 测试参数值解析
func TestResolveParamValue(t *testing.T) {
tests := []struct {
vt string
value string
dingId string
want string
}{
{"ding_id", "", "USER123", "USER123"},
{"string", "myvalue", "USER123", "myvalue"},
{"string", "prefix_{{ding_id}}_suffix", "USER123", "prefix_USER123_suffix"},
{"string", "prefix_$<{[ding_id]}>_suffix", "USER123", "prefix_USER123_suffix"},
{"int", "42", "USER123", "42"},
{"bool", "true", "USER123", "true"},
}
for _, tt := range tests {
got := resolveParamValue(tt.vt, tt.value, tt.dingId)
if got != tt.want {
t.Errorf("resolveParamValue(%q, %q, %q) = %q, want %q", tt.vt, tt.value, tt.dingId, got, tt.want)
}
}
}
// TestExtractJSONPathAdvanced 测试 JSON 路径提取
func TestExtractJSONPathAdvanced(t *testing.T) {
data := map[string]interface{}{
"code": float64(0),
"data": []interface{}{
map[string]interface{}{
"userName": "test@example.com",
"userId": "USER123",
},
},
"nested": map[string]interface{}{
"email": "nested@example.com",
},
}
tests := []struct {
path string
want string
}{
{"code", "0"},
{"data[0].userName", "test@example.com"},
{"data[0].userId", "USER123"},
{"nested.email", "nested@example.com"},
{"notexist", ""},
{"data[1].userName", ""},
}
for _, tt := range tests {
got := extractJSONPathAdvanced(data, tt.path)
if got != tt.want {
t.Errorf("extractJSONPathAdvanced(data, %q) = %q, want %q", tt.path, got, tt.want)
}
}
}
// TestParseEmailApiConfigFromJSON_NewFormat 测试新格式解析
func TestParseEmailApiConfigFromJSON_NewFormat(t *testing.T) {
jsonStr := `{
"enabled": true,
"url": "https://api.example.com",
"method": "GET",
"params": [
{"key": "userId", "value_type": "ding_id", "value": ""},
{"key": "appKey", "value_type": "string", "value": "mykey"}
],
"response_email_field": "data[0].email"
}`
cfg, err := parseEmailApiConfigFromJSON([]byte(jsonStr))
if err != nil {
t.Fatalf("解析新格式配置失败: %v", err)
}
if !isNewEmailApiConfig(cfg) {
t.Error("应检测为新格式")
}
if len(cfg.Params) != 2 {
t.Errorf("Params 数量错误,got: %d", len(cfg.Params))
}
}
// TestParseEmailApiConfigFromJSON_OldFormat 测试旧格式兼容解析
func TestParseEmailApiConfigFromJSON_OldFormat(t *testing.T) {
jsonStr := `{
"enabled": true,
"url": "https://api.example.com",
"method": "GET",
"request_param_field": "userId",
"response_email_field": "data[0].email",
"body_data": {
"form_data": [{"userId": ""}],
"urlencoded": []
}
}`
cfg, err := parseEmailApiConfigFromJSON([]byte(jsonStr))
if err != nil {
t.Fatalf("解析旧格式配置失败: %v", err)
}
if isNewEmailApiConfig(cfg) {
t.Error("旧格式配置不应检测为新格式")
}
if cfg.RequestParamField != "userId" {
t.Errorf("RequestParamField 应为 userIdgot: %s", cfg.RequestParamField)
}
}
// TestValidateEmailApiConfigFields_NewFormat 测试新格式配置验证
func TestValidateEmailApiConfigFields_NewFormat(t *testing.T) {
cfg := request.EmailApiConfig{
Enabled: true,
URL: "https://api.example.com",
Method: "GET",
Params: []request.RequestParam{},
ResponseEmailField: "data[0].email",
}
if err := validateEmailApiConfigFields(cfg); err != nil {
t.Errorf("有效的新格式配置不应报错: %v", err)
}
}
// TestValidateEmailApiConfigFields_InvalidParamType 测试 Params 不支持 int 类型
func TestValidateEmailApiConfigFields_InvalidParamType(t *testing.T) {
cfg := request.EmailApiConfig{
Enabled: true,
URL: "https://api.example.com",
Method: "GET",
Params: []request.RequestParam{
{Key: "count", ValueType: "int", Value: "10"}, // Params 不支持 int
},
ResponseEmailField: "data[0].email",
}
if err := validateEmailApiConfigFields(cfg); err == nil {
t.Error("Params 不应支持 int 类型,应报错")
}
}
// TestValidateEmailApiConfigFields_InvalidBodyType 测试 Body 不支持未知类型
func TestValidateEmailApiConfigFields_InvalidBodyType(t *testing.T) {
cfg := request.EmailApiConfig{
Enabled: true,
URL: "https://api.example.com",
Method: "POST",
Params: []request.RequestParam{},
BodyData: request.BodyData{
FormData: []request.BodyField{
{Key: "field1", ValueType: "invalid_type", Value: "val"},
},
},
ResponseEmailField: "data[0].email",
}
if err := validateEmailApiConfigFields(cfg); err == nil {
t.Error("Body 不应支持未知类型,应报错")
}
}
// TestBuildBodyFields 测试 Body 字段类型转换
func TestBuildBodyFields(t *testing.T) {
fields := []request.BodyField{
{Key: "userId", ValueType: "ding_id", Value: ""},
{Key: "appKey", ValueType: "string", Value: "mykey"},
{Key: "count", ValueType: "int", Value: "10"},
{Key: "enabled", ValueType: "bool", Value: "true"},
}
form := buildBodyFields(fields, "USER123")
if form.Get("userId") != "USER123" {
t.Errorf("ding_id 类型应替换为实际钉钉 IDgot: %s", form.Get("userId"))
}
if form.Get("appKey") != "mykey" {
t.Errorf("string 类型应直接使用,got: %s", form.Get("appKey"))
}
if form.Get("count") != "10" {
t.Errorf("int 类型值不正确,got: %s", form.Get("count"))
}
}
// TestDingIDMarkerReplacement 测试 Raw 模式钉钉 ID 标记替换
func TestDingIDMarkerReplacement(t *testing.T) {
raw := `{"userId": "$<{[ding_id]}>", "other": "value"}`
dingId := "USER789"
// 使用 resolveParamValue 测试标记替换
replaced := resolveParamValue("string", raw, dingId)
expected := `{"userId": "USER789", "other": "value"}`
if replaced != expected {
t.Errorf("钉钉 ID 标记替换错误\ngot: %s\nwant: %s", replaced, expected)
}
}
// TestDingIDMarkerOldFormat 测试旧格式 {{ding_id}} 占位符替换
func TestDingIDMarkerOldFormat(t *testing.T) {
raw := `{"userId": "{{ding_id}}", "other": "value"}`
replaced := resolveParamValue("string", raw, "USER789")
expected := `{"userId": "USER789", "other": "value"}`
if replaced != expected {
t.Errorf("旧格式占位符替换错误\ngot: %s\nwant: %s", replaced, expected)
}
}
+119 -26
View File
@@ -136,6 +136,7 @@ func (e *SystemIntegratedService) OAuth2CodeLogin(
return nil, fmt.Errorf("无法从 OAuth2 用户信息中获取邮箱或用户唯一标识")
}
fmt.Println("OAuth2CodeLogin", email, username)
sysUser, err := e.findUserByEmailOrPhone(email, userID)
if err != nil {
return nil, err
@@ -148,8 +149,48 @@ func (e *SystemIntegratedService) OAuth2CodeLogin(
return &response.GaiaLoginResult{User: *sysUser, Token: token, RedirectURI: req.RedirectURI, State: req.State}, nil
}
// DingTalkTestCallback 仅用 code 换 token,用于「测试连接」回调,不登录、不写 session
func (e *SystemIntegratedService) DingTalkTestCallback(code string) error {
code = strings.TrimSpace(code)
if code == "" {
return fmt.Errorf("授权码为空")
}
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
if integrate.AppKey == "" || integrate.AppSecret == "" {
return fmt.Errorf("钉钉配置不完整")
}
bodyJSON, _ := json.Marshal(map[string]string{
"clientId": integrate.AppKey,
"clientSecret": integrate.AppSecret,
"code": code,
"grantType": "authorization_code",
})
httpReq, err := http.NewRequest("POST", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken", bytes.NewReader(bodyJSON))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(httpReq)
if err != nil {
return fmt.Errorf("钉钉 token 请求失败: %w", err)
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
global.GVA_LOG.Error("钉钉 token 非 200", zap.Int("status", resp.StatusCode), zap.String("body", string(respBody)))
return fmt.Errorf("钉钉返回错误: %d", resp.StatusCode)
}
var tokenResp map[string]interface{}
if err = json.Unmarshal(respBody, &tokenResp); err != nil || tokenResp["accessToken"] == "" {
return fmt.Errorf("解析钉钉 token 失败")
}
return nil
}
// DingTalkCodeLogin 钉钉 code 换用户并登录(扫码/OAuth2 回调带 code
func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLoginReq) (*response.GaiaLoginResult, error) {
func (e *SystemIntegratedService) DingTalkCodeLogin(
req request.GaiaDingTalkLoginReq) (*response.GaiaLoginResult, error) {
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
if !integrate.Status {
return nil, fmt.Errorf("钉钉登录未启用")
@@ -165,7 +206,8 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
"code": req.AuthCode,
"grantType": "authorization_code",
})
httpReq, err := http.NewRequest("POST", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken", bytes.NewReader(bodyJSON))
httpReq, err := http.NewRequest("POST",
"https://api.dingtalk.com/v1.0/oauth2/userAccessToken", bytes.NewReader(bodyJSON))
if err != nil {
return nil, err
}
@@ -179,7 +221,8 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
defer resp.Body.Close()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
global.GVA_LOG.Error("钉钉 token 非 200", zap.Int("status", resp.StatusCode), zap.String("body", string(respBody)))
global.GVA_LOG.Error("钉钉 token 非 200", zap.Int(
"status", resp.StatusCode), zap.String("body", string(respBody)))
return nil, fmt.Errorf("钉钉返回错误: %d", resp.StatusCode)
}
@@ -205,8 +248,56 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
if err = json.Unmarshal(userBody, &dingUser); err != nil {
return nil, fmt.Errorf("解析钉钉用户信息失败")
}
email := dingUser["email"].(string)
username := dingUser["nick"].(string)
// 提取钉钉 IDuser_id 字段)
dingId := ""
if v, ok := dingUser["unionId"]; ok && v != nil {
dingId, _ = v.(string)
}
if dingId == "" {
if v, ok := dingUser["userId"]; ok && v != nil {
dingId, _ = v.(string)
}
}
// 解析用户名配置
var emailList []string
var configMap request.DingTalkConfigRequest
var emailConfig request.EmailApiConfig
if integrate.Config != "" {
if jsonErr := json.Unmarshal([]byte(integrate.Config), &configMap); jsonErr == nil {
var rawMsg json.RawMessage
if rawBytes, marshalErr := json.Marshal(configMap.EmailApi); marshalErr == nil {
rawMsg = rawBytes
if cfg, parseErr := parseEmailApiConfigFromJSON(rawMsg); parseErr == nil {
emailConfig = cfg
}
}
}
}
// 优先通过用户名 API 获取用户名(新格式)
if emailConfig.Enabled && dingId != "" {
emailList, err = e.callEmailApi(dingId, emailConfig)
if err == nil && len(emailList) > 0 {
fmt.Println("钉钉 code 换用户并登录(扫码/OAuth2 回调带 code", emailList)
sysUser, findErr := e.findUserByEmail(emailList)
if findErr != nil {
return nil, findErr
}
token, _, tokenErr := utils.LoginToken(sysUser)
if tokenErr != nil {
return nil, fmt.Errorf("签发 token 失败")
}
return &response.GaiaLoginResult{User: *sysUser, Token: token, RedirectURI: req.RedirectURI, State: req.State}, nil
}
global.GVA_LOG.Warn("DingTalkCodeLogin: 第三方邮箱 API 获取失败,尝试钉钉直接返回邮箱",
zap.String("ding_id", dingId), zap.Error(err))
}
// 回退:直接从钉钉用户信息获取邮箱
email, _ := dingUser["email"].(string)
username, _ := dingUser["nick"].(string)
if username == "" {
username = email
}
@@ -214,7 +305,8 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
return nil, fmt.Errorf("钉钉未返回邮箱")
}
sysUser, err := e.findUserByEmail(email)
fmt.Println("钉钉 code 换用户并登录第三方邮箱 API 获取失败", email)
sysUser, err := e.findUserByEmail([]string{email})
if err != nil {
return nil, err
}
@@ -231,7 +323,8 @@ func getStringFromMap(m map[string]interface{}, keys ...string) string {
continue
}
if v, ok := m[k]; ok && v != nil {
if s, ok := v.(string); ok {
var s string
if s, ok = v.(string); ok {
return s
}
}
@@ -297,21 +390,21 @@ func getStringByPathOrKeys(m map[string]interface{}, path string, fallbackKeys .
return getStringFromMap(m, fallbackKeys...)
}
// findUserByEmail 按邮箱查找已存在的用户(需在 gaia.accounts 中有对应记录方可签发 JWT)
func (e *SystemIntegratedService) findUserByEmail(email string) (*system.SysUser, error) {
var u system.SysUser
var mailList []string
mailList = append(mailList, email)
parts := strings.Split(email, "@")
defaultMail := os.Getenv(gaia.EmailDomainEnv)
if len(defaultMail) > 0 && len(parts) == 2 {
mailList = append(mailList, parts[0]+"@"+defaultMail)
}
// findUserByEmail 按username查找已存在的用户(需在 gaia.accounts 中有对应记录方可签发 JWT)
func (e *SystemIntegratedService) findUserByEmail(mailList []string) (*system.SysUser, error) {
// 查询关联邮箱
var u system.SysUser
if len(mailList) == 1 {
parts := strings.Split(mailList[0], "@")
defaultMail := os.Getenv(gaia.EmailDomainEnv)
if len(defaultMail) > 1 && len(parts) > 1 && len(parts[0]) > 0 {
mailList = append(mailList, parts[0]+"@"+defaultMail)
}
}
if err := global.GVA_DB.Where("email IN (?)", mailList).Preload(
"Authorities").Preload("Authority").First(&u).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf(fmt.Sprintf("邮箱%s尚未开通账号,请联系管理员", email))
return nil, fmt.Errorf("%s尚未开通账号,请联系管理员", mailList[0])
}
return nil, err
}
@@ -323,20 +416,20 @@ func (e *SystemIntegratedService) findUserByEmail(email string) (*system.SysUser
}
// findUserByEmailOrPhone 按邮箱或用户唯一标识(如手机号)查找用户,优先邮箱
func (e *SystemIntegratedService) findUserByEmailOrPhone(email, userID string) (*system.SysUser, error) {
if email != "" {
u, err := e.findUserByEmail(email)
if err == nil {
func (e *SystemIntegratedService) findUserByEmailOrPhone(
mail, userID string) (u *system.SysUser, err error) {
if mail != "" {
if u, err = e.findUserByEmail([]string{mail}); err == nil {
return u, nil
}
// 仅当“未开通”时再尝试按 userID(phone) 查,其他错误直接返回
if err != nil && !strings.Contains(err.Error(), "尚未开通") {
if !strings.Contains(err.Error(), "尚未开通") {
return nil, err
}
}
if userID != "" {
var u system.SysUser
if err := global.GVA_DB.Where("phone = ?", userID).Preload("Authorities").Preload("Authority").First(&u).Error; err != nil {
if err = global.GVA_DB.Where("phone = ?", userID).Preload(
"Authorities").Preload("Authority").First(&u).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("该用户唯一标识尚未开通后台账号,请联系管理员")
}
@@ -345,7 +438,7 @@ func (e *SystemIntegratedService) findUserByEmailOrPhone(email, userID string) (
if u.Enable != 1 {
return nil, fmt.Errorf("账号已被禁用")
}
return &u, nil
return u, nil
}
return nil, fmt.Errorf("无法从 OAuth2 用户信息中获取邮箱或用户唯一标识")
}
+18 -2
View File
@@ -14,10 +14,10 @@ import (
// GetLoginOptions 获取登录方式选项(供登录页展示钉钉/OAuth2 按钮,不暴露密钥)
func (e *SystemIntegratedService) GetLoginOptions(frontendOrigin string) (res response.LoginOptionsResponse) {
// 非本地的需要加上admin
// 非本地的需要加上 admin(若 Referer 已带 /admin 则不再追加,避免 /admin/admin
integrateDing := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
frontendOrigin = strings.TrimSuffix(frontendOrigin, "/")
if !strings.Contains(frontendOrigin, "localhost") {
if !strings.Contains(frontendOrigin, "localhost") && !strings.HasSuffix(frontendOrigin, "/admin") {
frontendOrigin = frontendOrigin + "/admin"
}
if integrateDing.Status && integrateDing.AppKey != "" {
@@ -72,6 +72,22 @@ func (e *SystemIntegratedService) GetLoginOptions(frontendOrigin string) (res re
return res
}
// GetDingTalkTestAuthURL 返回用于「测试连接」的钉钉授权 URLstate=dingtalk_test,回调后仅验证 code 换 token,不登录)
func (e *SystemIntegratedService) GetDingTalkTestAuthURL(frontendOrigin string) (string, error) {
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
if integrate.AppKey == "" || integrate.AppSecret == "" {
return "", fmt.Errorf("请先配置 AppKey 与 AppSecret")
}
frontendOrigin = strings.TrimSuffix(frontendOrigin, "/")
if !strings.Contains(frontendOrigin, "localhost") && !strings.HasSuffix(frontendOrigin, "/admin") {
frontendOrigin = frontendOrigin + "/admin"
}
callbackURI := frontendOrigin + "/#/loginCallback?provider=dingtalk"
authURL := fmt.Sprintf("https://login.dingtalk.com/oauth2/auth?client_id=%s&response_type=code&scope=openid&redirect_uri=%s&state=dingtalk_test",
integrate.AppKey, url.QueryEscape(callbackURI))
return authURL, nil
}
// getIntegratedConfigRaw 获取集成配置(不脱敏,仅内部使用)
func (e *SystemIntegratedService) getIntegratedConfigRaw(classID uint) (integrate gaia.SystemIntegration) {
if err := global.GVA_DB.Where("classify = ?", classID).First(&integrate).Error; err != nil {
+458 -43
View File
@@ -14,21 +14,231 @@ import (
"encoding/pem"
"errors"
"fmt"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
gaiaRequest "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
gaiaResponse "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
"go.gnd.pw/crypto/eax"
"go.uber.org/zap"
"gorm.io/gorm"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
gaiaRequest "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
gaiaResponse "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
"github.com/flipped-aurora/gin-vue-admin/server/model/system"
"github.com/flipped-aurora/gin-vue-admin/server/utils"
"go.gnd.pw/crypto/eax"
"go.uber.org/zap"
"gorm.io/gorm"
)
// fetchAdminToken 查询一个管理员用户,生成 Dify Console API 兼容的 JWT。
// 结果缓存到 RedisTTL 50 分钟(JWT 有效期内复用,避免频繁生成)。
func (s *ModelProviderService) fetchAdminToken() (token string, err error) {
ctx := context.Background()
// 优先从 Redis 读取缓存
if cached, e := global.GVA_REDIS.Get(ctx, gaia.RedisKeyGaiaAdminConsoleToken).Result(); e == nil && cached != "" {
return cached, nil
}
// 查询一个活跃管理员
var adminUser system.SysUser
if err = global.GVA_DB.Where("authority_id = ? AND enable = ?",
system.AdminAuthorityId, system.UserActive).First(&adminUser).Error; err != nil {
return "", fmt.Errorf("找不到可用的管理员账号:%w", err)
}
token, _, _, err = utils.LoginTokenWithCSRF(&adminUser)
if err != nil {
return "", fmt.Errorf("生成管理员 token 失败:%w", err)
}
// 缓存 50 分钟(JWT 缓冲时间内有效)
global.GVA_REDIS.Set(ctx, gaia.RedisKeyGaiaAdminConsoleToken, token, 50*time.Minute)
return token, nil
}
// fetchModelPricingFromDify 通过 Dify Console API 拉取 LLM 模型定价,结果按 model 名缓存到 RedisTTL 1 小时)。
// Dify Console APIGET /console/api/workspaces/current/models/model-types/llm
// 响应结构:{"data": [{"models": [{"model": "gpt-4o", "fetch_from": "...", "pricing": {"input":"0.005","output":"0.015","unit":"0.001","currency":"USD"}}]}]}
func (s *ModelProviderService) fetchModelPricingFromDify(modelName string) (*gaia.ModelPricing, error) {
const redisTTL = time.Hour
cacheKey := gaia.RedisKeyGaiaModelPricingPrefix + modelName
ctx := context.Background()
// 先查 Redis
if cached, err := global.GVA_REDIS.Get(ctx, cacheKey).Result(); err == nil && cached != "" {
var p gaia.ModelPricing
if json.Unmarshal([]byte(cached), &p) == nil {
return &p, nil
}
}
// 获取管理员 token
token, err := s.fetchAdminToken()
if err != nil {
return nil, err
}
// 调用 Dify Console API
apiURL := strings.TrimSuffix(global.GVA_CONFIG.Gaia.Url, "/") +
"/console/api/workspaces/current/models/model-types/llm"
req, err := http.NewRequest(http.MethodGet, apiURL, nil)
if err != nil {
return nil, fmt.Errorf("构建定价请求失败:%w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("请求 Dify 定价接口失败:%w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取定价响应失败:%w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("dify 定价接口返回 %d%s", resp.StatusCode, string(respBody))
}
// 解析响应,批量缓存所有模型的定价
var apiResp gaia.DifyModelsResponse
if err = json.Unmarshal(respBody, &apiResp); err != nil {
return nil, fmt.Errorf("解析定价响应失败:%w", err)
}
var targetPricing *gaia.ModelPricing
for _, providerData := range apiResp.Data {
for _, m := range providerData.Models {
if m.Pricing == nil {
continue
}
p := gaia.ModelPricing{Currency: m.Pricing.Currency}
_, _ = fmt.Sscanf(m.Pricing.Input, "%f", &p.Input)
_, _ = fmt.Sscanf(m.Pricing.Output, "%f", &p.Output)
_, _ = fmt.Sscanf(m.Pricing.Unit, "%f", &p.Unit)
if p.Unit == 0 {
p.Unit = 0.001 // 默认按千 token 计费
}
// 缓存每个模型的定价
if b, e := json.Marshal(p); e == nil {
global.GVA_REDIS.Set(ctx, gaia.RedisKeyGaiaModelPricingPrefix+m.Model, string(b), redisTTL)
}
if m.Model == modelName {
cp := p
targetPricing = &cp
}
}
}
if targetPricing != nil {
return targetPricing, nil
}
// 未找到该模型的定价,写入空标记避免反复请求(TTL 10 分钟)
global.GVA_REDIS.Set(ctx, cacheKey, "{}", 10*time.Minute)
return nil, nil
}
// rmbToUSD 将人民币金额按固定汇率换算为 USD。
func rmbToUSD(rmb float64) float64 {
return rmb / gaia.RmbToUSDRate
}
// resolvePricing 返回模型定价:优先用从 Dify 拉取的 pricing
// 其次查内置兜底定价表(BuiltinModelPricing),最后返回 nil。
func resolvePricing(pricing *gaia.ModelPricing, modelName string) *gaia.ModelPricing {
if pricing != nil && pricing.Unit > 0 {
return pricing
}
// 内置定价表精确匹配
if p, ok := gaia.BuiltinModelPricing[modelName]; ok {
return &p
}
// 前缀模糊匹配(如 "qwen3.5-plus-xxx" 匹配 "qwen3.5-plus"
lower := strings.ToLower(modelName)
for k, p := range gaia.BuiltinModelPricing {
if strings.HasPrefix(lower, strings.ToLower(k)) {
cp := p
return &cp
}
}
return nil
}
// calcQuotaDelta 根据定价和 token 用量计算本次消耗的配额金额(统一以 USD 计)。
// Dify pricing 字段语义:input/output 为每「unit」个 token 的价格,unit 通常为 0.001(千分之一),
// 即 input=0.0014, unit=0.001 表示每千 token ¥0.0014 × (tokens/1000)。
// 公式:cost = tokens × input × unit(因为 unit=1/1000,等价于 tokens/1000 × input)。
// 若货币为 RMB/CNY,则除以汇率 7.26 换算为 USD,与 account_money_extend.used_quota 存储单位保持一致。
// 若 Dify 未返回定价则查内置兜底表;均未命中时按极小默认值记账,避免多扣。
func calcQuotaDelta(pricing *gaia.ModelPricing, modelName string, promptTokens, completionTokens int) float64 {
p := resolvePricing(pricing, modelName)
if p == nil {
// 兜底:仅做记账占位,不应大量触发
global.GVA_LOG.Warn("calcQuotaDelta 未找到模型定价,使用兜底值",
zap.String("model", modelName),
zap.Int("prompt_tokens", promptTokens),
zap.Int("completion_tokens", completionTokens),
)
return float64(promptTokens+completionTokens) * gaia.DefaultQuotaFallbackUSDPerToken
}
inputCost := float64(promptTokens) * p.Input * p.Unit
outputPrice := p.Output
if outputPrice == 0 {
outputPrice = p.Input
}
outputCost := float64(completionTokens) * outputPrice * p.Unit
total := inputCost + outputCost
// RMB/CNY 定价统一换算为 USD 后再扣费,与 used_quota 存储单位保持一致
if strings.EqualFold(p.Currency, "RMB") || strings.EqualFold(p.Currency, "CNY") {
total = rmbToUSD(total)
}
return total
}
// CheckAccountQuota 检查用户是否还有可用余额(total_quota - used_quota > 0)。
// total_quota = 0 视为"未设置限额",不拦截;total_quota > 0 时才做余额校验。
func (s *ModelProviderService) CheckAccountQuota(userID string) error {
var row gaiaResponse.CheckAccountQuotaRow
err := global.GVA_DB.Table("account_money_extend").
Select("total_quota, used_quota").
Where("account_id = ?::uuid", userID).
First(&row).Error
if err != nil {
// 记录未找到:可能尚未初始化,放行
return nil
}
// total_quota = 0 表示不限额,放行
if row.TotalQuota <= 0 {
return nil
}
if row.UsedQuota >= row.TotalQuota {
return fmt.Errorf("余额不足,已用 %.6f / 总额 %.6f USD,请联系管理员充值", row.UsedQuota, row.TotalQuota)
}
return nil
}
// deductAccountQuota 将消耗配额计入 account_money_extend.used_quota(原子累加)。
func deductAccountQuota(userID string, delta float64) {
if delta <= 0 {
return
}
if err := global.GVA_DB.Exec(
`UPDATE account_money_extend SET used_quota = used_quota + ?, updated_at = NOW() WHERE account_id = ?::uuid`,
delta, userID,
).Error; err != nil {
global.GVA_LOG.Warn("deductAccountQuota 失败",
zap.String("user_id", userID), zap.Float64("delta", delta), zap.Error(err))
}
}
// ModelProviderService 模型提供商服务,负责提供商配置、凭证获取、可用模型拉取及聊天请求代理。
type ModelProviderService struct{}
@@ -186,7 +396,8 @@ func (s *ModelProviderService) getAvailableModelsFromProviderModelCredentials(pr
Distinct("model_name").
Pluck("model_name", &modelNames).Error
if err != nil {
global.GVA_LOG.Warn("从 provider_model_credentials 拉取模型列表失败", zap.String("provider", providerName), zap.Error(err))
global.GVA_LOG.Warn("从 provider_model_credentials 拉取模型列表失败", zap.String(
"provider", providerName), zap.Error(err))
return nil, nil
}
list := make([]gaiaResponse.ModelInfo, 0, len(modelNames))
@@ -207,6 +418,14 @@ func (s *ModelProviderService) GetAvailableModelsFromDify(providerName string) (
if providerName == gaia.ProviderAzure {
return s.getAvailableModelsFromProviderModelCredentials(providerName)
}
// AWS Bedrock 没有统一的 /v1/models 接口,模型由前端手输;直接返回空列表
if providerName == gaia.ProviderAWS || providerName == gaia.ProviderAnthropic {
return nil, nil
}
// DeepSeek 模型由前端手输(避免拉取全量列表),直接返回空列表
if providerName == gaia.ProviderDeepSeek {
return nil, nil
}
creds, err := s.GetDifyProviderCredentials(providerName)
if err != nil || creds.APIKey == "" {
@@ -230,8 +449,6 @@ func (s *ModelProviderService) GetAvailableModelsFromDify(providerName string) (
base = gaia.DefaultAPIBase[gaia.ProviderGoogle]
}
return s.fetchGeminiModels(client, base, creds.APIKey)
case gaia.ProviderAnthropic:
return nil, nil
default:
if creds.Endpoint != "" {
return s.fetchOpenAICompatibleModels(client, creds.Endpoint, creds.APIKey)
@@ -244,7 +461,8 @@ func (s *ModelProviderService) GetAvailableModelsFromDify(providerName string) (
// 兼容两种响应格式:
// 1) OpenAI: { "data": [ { "id": "..." }, ... ] }
// 2) 通义: { "success": true, "output": { "models": [ { "model": "...", "name": "..." }, ... ] } }
func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client, baseURL, apiKey string) ([]gaiaResponse.ModelInfo, error) {
func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client, baseURL, apiKey string) (
[]gaiaResponse.ModelInfo, error) {
url := strings.TrimSuffix(baseURL, "/") + "/v1/models"
req, err := http.NewRequest("GET", url, nil)
if err != nil {
@@ -259,7 +477,8 @@ func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client,
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
global.GVA_LOG.Warn("拉取模型列表接口非 200", zap.String("url", url), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
global.GVA_LOG.Warn("拉取模型列表接口非 200", zap.String("url", url), zap.Int(
"status", resp.StatusCode), zap.String("body", string(body)))
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
}
@@ -298,7 +517,8 @@ func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client,
// fetchGeminiModels 调用 Google Gemini GET /v1beta/models?key=API_KEY,解析 models[],支持分页。
// 认证使用 query 参数 key,响应格式:{ "models": [ { "name": "models/xxx", "baseModelId": "xxx", "displayName": "..." } ], "nextPageToken": "..." }
func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, apiKey string) ([]gaiaResponse.ModelInfo, error) {
func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, apiKey string) (
[]gaiaResponse.ModelInfo, error) {
baseURL = strings.TrimSuffix(baseURL, "/")
all := make([]gaiaResponse.ModelInfo, 0)
pageToken := ""
@@ -320,7 +540,9 @@ func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, a
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
global.GVA_LOG.Warn("拉取 Gemini 模型列表非 200", zap.String("url", baseURL+"/v1beta/models"), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
global.GVA_LOG.Warn("拉取 Gemini 模型列表非 200", zap.String(
"url", baseURL+"/v1beta/models"), zap.Int("status", resp.StatusCode), zap.String(
"body", string(body)))
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
}
@@ -356,7 +578,8 @@ func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, a
// fetchAzureOpenAIModels 调用 Azure OpenAI GET {endpoint}/openai/models?api-version={version},解析 data[]。
// 认证使用 api-key 请求头,响应格式:{ "data": [ { "id": "...", "object": "model" } ] }
func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseURL, apiKey, apiVersion string) ([]gaiaResponse.ModelInfo, error) {
func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseURL, apiKey, apiVersion string) (
[]gaiaResponse.ModelInfo, error) {
baseURL = strings.TrimSuffix(baseURL, "/")
if apiVersion == "" {
apiVersion = "2024-08-01-preview" // 默认 API 版本
@@ -378,7 +601,8 @@ func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseU
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
global.GVA_LOG.Warn("拉取 Azure OpenAI 模型列表非 200", zap.String("url", url), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
global.GVA_LOG.Warn("拉取 Azure OpenAI 模型列表非 200", zap.String("url", url), zap.Int(
"status", resp.StatusCode), zap.String("body", string(body)))
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
}
@@ -416,29 +640,36 @@ func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
var cached string
var firstTenant gaia.Tenants
tenantID := firstTenant.GetSuperAdminTenantId()
cacheKey := fmt.Sprintf("model_provider_credentials:%s", providerName)
cacheKey := gaia.RedisKeyModelProviderCredentialsPrefix + providerName
if cached, err = global.GVA_Dify_REDIS.Get(context.Background(), cacheKey).Result(); err == nil {
if err = json.Unmarshal([]byte(cached), &creds); err == nil {
global.GVA_LOG.Info("GetDifyProviderCredentials 命中缓存",
zap.String("provider", providerName),
zap.String("aws_region", creds.AWSRegion),
zap.String("bedrock_proxy_url", creds.BedrockProxyURL),
)
return creds, nil
}
}
// 尝试方式1: 从 providers + provider_credentials 表查询
var row gaia.ProviderCredential
err = global.GVA_DB.Table("providers").
// 将短名转为 Dify 内部 provider_name 的 LIKE 模式(避免 aws 匹配不到 bedrock_claude
likePattern := s.difyProviderLikePattern(providerName)
err = global.GVA_DB.Debug().Table("providers").
Select("provider_credentials.encrypted_config, providers.tenant_id").
Joins("LEFT JOIN provider_credentials ON providers.credential_id = provider_credentials.id").
Where("providers.tenant_id = ? AND providers.provider_name LIKE ? AND providers.provider_type = ? AND providers.is_valid = ?",
tenantID, fmt.Sprintf("%%%s%%", providerName), gaia.DifyProviderTypeCustom, true).
tenantID, likePattern, gaia.DifyProviderTypeCustom, true).
Order("provider_credentials.updated_at DESC").
First(&row).Error
// 如果方式1 未找到记录,尝试方式2: 从 provider_model_credentials 表查询
if err != nil || row.EncryptedConfig == "" {
var pmcRow gaia.ProviderCredential
if pmcErr := global.GVA_DB.Table("provider_model_credentials").
if pmcErr := global.GVA_DB.Debug().Table("provider_model_credentials").
Select("encrypted_config, tenant_id, provider_name, updated_at").
Where("tenant_id = ? AND provider_name LIKE ?", tenantID, fmt.Sprintf("%%%s%%", providerName)).
Where("tenant_id = ? AND provider_name LIKE ?", tenantID, likePattern).
Order("updated_at DESC"). // 按 updated_at 倒序,取最新的凭证
First(&pmcRow).Error; pmcErr == nil && pmcRow.EncryptedConfig != "" {
row = pmcRow
@@ -453,6 +684,7 @@ func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
// 兼容两种存储:1) 明文 JSON(如 {"openai_api_key":"...", "openai_api_base":"..."});2) Dify RSA+AES-EAX 加密后再 base64
var base, apiVersion string
var configMap map[string]interface{}
fmt.Println("row.EncryptedConfig", row.EncryptedConfig)
if err = json.Unmarshal([]byte(row.EncryptedConfig), &configMap); err == nil {
// 解密函数用于处理加密的值
if config, ok := configMap[gaia.ConfigKeyOpenaiAPIKey]; ok {
@@ -468,6 +700,35 @@ func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
} else if config, ok = configMap[gaia.ConfigKeyAPIKey]; ok {
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
} else if _, hasAWSKey := configMap[gaia.ConfigKeyAWSAccessKeyID]; hasAWSKey {
// AWS Bedrock 凭证:解析 aws_access_key_id / aws_secret_access_key / aws_region
if v, ok2 := configMap[gaia.ConfigKeyAWSAccessKeyID].(string); ok2 && v != "" {
if creds.AWSAccessKeyID, err = s.decryptConfig(v, row.TenantID); err != nil {
return nil, fmt.Errorf("解密 aws_access_key_id 失败: %w", err)
}
}
if v, ok2 := configMap[gaia.ConfigKeyAWSSecretAccessKey].(string); ok2 && v != "" {
if creds.AWSSecretAccessKey, err = s.decryptConfig(v, row.TenantID); err != nil {
return nil, fmt.Errorf("解密 aws_secret_access_key 失败: %w", err)
}
}
if v, ok2 := configMap[gaia.ConfigKeyAWSSessionToken].(string); ok2 && v != "" {
if creds.AWSSessionToken, err = s.decryptConfig(v, row.TenantID); err != nil {
return nil, fmt.Errorf("解密 aws_session_token 失败: %w", err)
}
}
if v, ok2 := configMap[gaia.ConfigKeyAWSRegion].(string); ok2 && v != "" {
creds.AWSRegion = strings.TrimSpace(v)
}
// 可选:HTTP 代理地址(用于从受限地区中转 Bedrock 请求)。
// 支持 "host:port" 或 "http(s)://host:port";若值被加密则先解密,失败时按原文处理。
if v, ok2 := configMap[gaia.ConfigKeyBedrockProxyURL].(string); ok2 && strings.TrimSpace(v) != "" {
proxyVal := strings.TrimSpace(v)
if decrypted, decErr := s.decryptConfig(proxyVal, row.TenantID); decErr == nil && decrypted != "" {
proxyVal = strings.TrimSpace(decrypted)
}
creds.BedrockProxyURL = proxyVal
}
} else {
// 尝试从备选字段中查找
for _, key := range gaia.CredentialKeyFallback {
@@ -490,13 +751,14 @@ func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
return nil, fmt.Errorf("解密凭证失败: %w", err)
}
}
if creds.APIKey == "" {
return nil, fmt.Errorf("未能从配置中提取API Key")
if creds.APIKey == "" && creds.AWSAccessKeyID == "" {
return nil, fmt.Errorf("未能从配置中提取API Key(也未找到 AWS 凭证)")
}
// 缓存凭证(1小时)
var cacheJSON []byte
if cacheJSON, err = json.Marshal(creds); err == nil {
fmt.Println("row.EncryptedConfig", string(cacheJSON))
global.GVA_Dify_REDIS.Set(context.Background(), cacheKey, cacheJSON, time.Hour)
}
@@ -708,7 +970,7 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
defer func() {
// 记录日志
log := gaia.ModelProxyLog{
global.GVA_DB.Create(&gaia.ModelProxyLog{
UserId: userID,
ProviderName: providerName,
ModelName: req.Model,
@@ -717,8 +979,7 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
Status: status,
ErrorMessage: errorMsg,
CreatedAt: startTime,
}
global.GVA_DB.Create(&log)
})
}()
// 处理流式响应
@@ -726,7 +987,7 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if _, err := writer.Write([]byte(line + "\n")); err != nil {
if _, err = writer.Write([]byte(line + "\n")); err != nil {
status = "error"
errorMsg = err.Error()
return err
@@ -736,14 +997,14 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
flusher.Flush()
}
}
if err := scanner.Err(); err != nil {
if err = scanner.Err(); err != nil {
status = "error"
errorMsg = err.Error()
return err
}
} else {
// 非流式响应
if _, err := io.Copy(writer, resp.Body); err != nil {
if _, err = io.Copy(writer, resp.Body); err != nil {
status = "error"
errorMsg = err.Error()
return err
@@ -769,8 +1030,18 @@ func (s *ModelProviderService) getProviderCandidatesByModel(modelName string) []
if strings.HasPrefix(modelLower, "gemini") || strings.Contains(modelLower, "google") {
return []string{gaia.ProviderGoogle}
}
if strings.HasPrefix(modelLower, "anthropic.") {
// anthropic.* 前缀是 AWS Bedrock 专用模型 ID 格式(如 anthropic.claude-sonnet-4-6),
// 仅走 Bedrock 渠道,不回落到 Anthropic 直连(直连不支持此格式且可能因地区受限)。
return []string{gaia.ProviderAWS}
}
if strings.Contains(modelLower, "claude") || strings.Contains(modelLower, "anthropic") {
return []string{gaia.ProviderAnthropic}
// 顺序即优先级:AWS Bedrock 优先(配置成本更高、可覆盖受限地区),未开启再回落到 anthropic 直连
return []string{gaia.ProviderAWS, gaia.ProviderAnthropic}
}
// Kimi / Moonshot 系列经由 tongyi(百炼)渠道转发
if strings.HasPrefix(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
return []string{gaia.ProviderTongyi}
}
// GLM/智谱 可能配置在 tongyi(统一入口)或 zhipuai 下,先试 tongyi
if strings.HasPrefix(modelLower, "glm") || strings.Contains(modelLower, "zhipu") || strings.Contains(modelLower, "chatglm") {
@@ -780,6 +1051,10 @@ func (s *ModelProviderService) getProviderCandidatesByModel(modelName string) []
if strings.HasPrefix(modelLower, "minimax") || strings.Contains(modelLower, "abab") {
return []string{gaia.ProviderTongyi, gaia.ProviderMinimax}
}
// DeepSeek 系列模型走 deepseek 渠道
if strings.HasPrefix(modelLower, "deepseek") {
return []string{gaia.ProviderDeepSeek}
}
return nil
}
@@ -811,9 +1086,18 @@ func (s *ModelProviderService) getProviderByModel(modelName string) (string, err
if strings.HasPrefix(modelLower, "gemini") || strings.Contains(modelLower, "google") {
return gaia.ProviderGoogle, nil
}
if strings.HasPrefix(modelLower, "anthropic.") {
// anthropic.* 前缀是 AWS Bedrock 专用格式,直接归 aws 渠道
return gaia.ProviderAWS, nil
}
if strings.Contains(modelLower, "claude") || strings.Contains(modelLower, "anthropic") {
// 仅按名字推断时默认 anthropic;实际渠道(含 AWS Bedrock)由 resolveProviderByModel 决定
return gaia.ProviderAnthropic, nil
}
// Kimi / Moonshot 默认走 tongyi(百炼)渠道
if strings.HasPrefix(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
return gaia.ProviderTongyi, nil
}
if strings.Contains(modelLower, "azure") {
return gaia.ProviderAzure, nil
}
@@ -825,6 +1109,10 @@ func (s *ModelProviderService) getProviderByModel(modelName string) (string, err
if strings.HasPrefix(modelLower, "minimax") || strings.Contains(modelLower, "abab") {
return gaia.ProviderTongyi, nil
}
// DeepSeek 系列走 deepseek 渠道
if strings.HasPrefix(modelLower, "deepseek") {
return gaia.ProviderDeepSeek, nil
}
return "", fmt.Errorf("无法识别模型 %s 的提供商", modelName)
}
@@ -900,7 +1188,8 @@ func (s *ModelProviderService) ProxyRequest(
// 解析 provider:头 > query 已在 handler 传入;此处从 body 取 model 仅当 body 为 JSON 且含 model 时用于推断
xGaiaProvider := reqHeader.Get("X-Gaia-Provider")
global.GVA_LOG.Info("ProxyRequest 解析 provider", zap.String("path", path), zap.String("X-Gaia-Provider", xGaiaProvider), zap.Int("body_len", len(body)))
global.GVA_LOG.Info("ProxyRequest 解析 provider", zap.String("path", path), zap.String(
"X-Gaia-Provider", xGaiaProvider), zap.Int("body_len", len(body)))
if p := xGaiaProvider; p != "" {
providerName = strings.TrimSpace(strings.ToLower(p))
}
@@ -912,7 +1201,8 @@ func (s *ModelProviderService) ProxyRequest(
// 按“已选模型”解析实际渠道(如 gpt-5-chat 若只在 Azure 下勾选则走 azure
providerName, err = s.resolveProviderByModel(m)
if err != nil {
global.GVA_LOG.Error("ProxyRequest resolveProviderByModel 失败", zap.String("model", m), zap.Error(err))
global.GVA_LOG.Error("ProxyRequest resolveProviderByModel 失败", zap.String(
"model", m), zap.Error(err))
return err
}
global.GVA_LOG.Info("ProxyRequest 解析得到 provider", zap.String("provider", providerName))
@@ -935,11 +1225,29 @@ func (s *ModelProviderService) ProxyRequest(
return err
}
// AWS Bedrock 直连:不走通用 HTTP 转发,改用 SigV4 签名的 Bedrock 原生 API
if providerName == gaia.ProviderAWS {
return s.proxyBedrockRequest(userID, path, method, reqHeader, body, writer, creds)
}
if base = s.getUpstreamBase(providerName, creds); base == "" {
return fmt.Errorf("提供商 %s 无可用上游地址", providerName)
}
// 若 body 是 JSON 且含 stream: true,注入 stream_options.include_usage = true
// 这样上游会在 SSE 末尾的 data 行返回 usage,供后续计费解析使用。
if len(body) > 0 {
var bodyObj map[string]interface{}
if json.Unmarshal(body, &bodyObj) == nil {
if streamVal, ok := bodyObj["stream"].(bool); ok && streamVal {
if _, hasOpt := bodyObj["stream_options"]; !hasOpt {
bodyObj["stream_options"] = map[string]interface{}{"include_usage": true}
if injected, e := json.Marshal(bodyObj); e == nil {
body = injected
}
}
}
}
bodyReader = bytes.NewReader(body)
}
@@ -975,7 +1283,6 @@ func (s *ModelProviderService) ProxyRequest(
requestURL = base + "/" + path
}
fmt.Println("path", requestURL, string(body))
httpReq, err := http.NewRequest(method, requestURL, bodyReader)
if err != nil {
return err
@@ -1017,18 +1324,42 @@ func (s *ModelProviderService) ProxyRequest(
}
}
var logStatus, logError string
var promptTokens, completionTokens int
defer func() {
if logStatus == "" {
logStatus = "success"
}
global.GVA_DB.Create(&gaia.ModelProxyLog{
UserId: userID,
ProviderName: providerName,
ModelName: modelOrPath,
Status: logStatus,
ErrorMessage: logError,
CreatedAt: startTime,
UserId: userID,
ProviderName: providerName,
ModelName: modelOrPath,
RequestTokens: promptTokens,
ResponseTokens: completionTokens,
Status: logStatus,
ErrorMessage: logError,
CreatedAt: startTime,
})
// 计费:仅成功时扣费
if logStatus == "success" {
if promptTokens > 0 || completionTokens > 0 {
// LLM 类型:按 token 计费
pricing, _ := s.fetchModelPricingFromDify(modelOrPath)
delta := calcQuotaDelta(pricing, modelOrPath, promptTokens, completionTokens)
deductAccountQuota(userID, delta)
} else if isImageOrPerRequestPath(path) {
// 图片生成等无 usage 的接口:按请求次数计费,默认单价见 gaia.DefaultImageGenerationPriceUSD
pricing, _ := s.fetchModelPricingFromDify(modelOrPath)
var delta float64
if pricing != nil && pricing.Input > 0 {
// 若定价表有配置,input 字段用作每次请求单价
delta = pricing.Input
} else {
delta = gaia.DefaultImageGenerationPriceUSD
}
deductAccountQuota(userID, delta)
}
}
}()
// 写回状态码与响应头(流式由上游 Content-Type 决定)
@@ -1044,17 +1375,36 @@ func (s *ModelProviderService) ProxyRequest(
_, _ = io.Copy(writer, resp.Body)
return nil
}
// 流式响应时按行刷新,避免缓冲
// extractUsage 从 OpenAI 格式的 JSON 对象中提取 usage 字段
extractUsage := func(data []byte) {
var obj gaia.ModelUsageResponse
if json.Unmarshal(data, &obj) == nil && obj.Usage != nil {
if obj.Usage.PromptTokens > 0 {
promptTokens = obj.Usage.PromptTokens
}
if obj.Usage.CompletionTokens > 0 {
completionTokens = obj.Usage.CompletionTokens
}
}
}
// 流式响应:按行扫描,顺带从最后一条含 usage 的 data 行返回
if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
if flusher, ok := writer.(http.Flusher); ok {
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
fmt.Println("sss", scanner.Text())
if _, err = writer.Write([]byte(scanner.Text() + "\n")); err != nil {
line := scanner.Text()
if _, err = writer.Write([]byte(line + "\n")); err != nil {
logStatus, logError = "error", err.Error()
return err
}
flusher.Flush()
// 解析 SSE data 行中的 usagestream_options.include_usage=true 时上游会附带)
if strings.HasPrefix(line, "data:") && strings.Contains(line, `"usage"`) {
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
extractUsage([]byte(payload))
}
}
if err = scanner.Err(); err != nil {
logStatus, logError = "error", err.Error()
@@ -1063,13 +1413,41 @@ func (s *ModelProviderService) ProxyRequest(
return nil
}
}
_, err = io.Copy(writer, resp.Body)
// 非流式响应:TeeReader 同时转发给客户端并读取 body 用于解析 usage
var buf bytes.Buffer
tee := io.TeeReader(resp.Body, &buf)
_, err = io.Copy(writer, tee)
if err != nil {
logStatus, logError = "error", err.Error()
} else {
extractUsage(buf.Bytes())
}
return err
}
// GetProxyLogs 分页查询代理日志(model_proxy_log 表)。
func (s *ModelProviderService) GetProxyLogs(info gaiaRequest.GetProxyLogsReq) (list []map[string]interface{}, total int64, err error) {
page, pageSize := info.Page, info.PageSize
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 20
}
db := global.GVA_DB.Table("model_proxy_log")
if err = db.Count(&total).Error; err != nil {
err = fmt.Errorf("查询日志总数失败:%w", err)
return
}
offset := (page - 1) * pageSize
if err = db.Order("created_at DESC").Limit(pageSize).Offset(offset).Find(&list).Error; err != nil {
err = fmt.Errorf("查询日志列表失败:%w", err)
return
}
return list, total, nil
}
// isProviderEnabled 检查该提供商是否已启用(未校验具体模型列表,用于通用代理)。
func (s *ModelProviderService) isProviderEnabled(providerName string) bool {
var config gaia.ModelProviderConfig
@@ -1078,3 +1456,40 @@ func (s *ModelProviderService) isProviderEnabled(providerName string) bool {
}
return true
}
// isImageOrPerRequestPath 判断请求路径是否为按次计费的接口(图片生成、语音合成等无 usage 字段的接口)。
func isImageOrPerRequestPath(path string) bool {
perRequestPaths := []string{
"images/generations",
"images/edits",
"images/variations",
"audio/speech",
"audio/transcriptions",
"audio/translations",
}
lpath := strings.ToLower(path)
for _, p := range perRequestPaths {
if strings.Contains(lpath, p) {
return true
}
}
return false
}
// difyProviderLikePattern 将 Gaia 短名(openai/aws/...)转换为 Dify providers 表
// provider_name 字段的 LIKE 搜索模式。
// Dify 内部以 "langgenius/<plugin>/<plugin>" 格式存储提供商名,例如:
// - aws → langgenius/bedrock_claude/bedrock_claude 或 langgenius/bedrock/bedrock
// - openai → langgenius/openai/openai
// - anthropic → langgenius/anthropic/anthropic
//
// 对于 AWSDify 有两种插件包:bedrock_claude 和 bedrock;使用 bedrock 作为公共关键字可同时命中。
func (s *ModelProviderService) difyProviderLikePattern(providerName string) string {
switch providerName {
case gaia.ProviderAWS:
// langgenius/bedrock_claude/bedrock_claude 和 langgenius/bedrock/bedrock 都含 "bedrock"
return "%bedrock%"
default:
return fmt.Sprintf("%%%s%%", providerName)
}
}
+731 -52
View File
@@ -1,6 +1,10 @@
package gaia
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -8,13 +12,16 @@ import (
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/faabiosr/cachego/file"
"github.com/fastwego/dingding"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
gaiaResp "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
"github.com/flipped-aurora/gin-vue-admin/server/utils"
"github.com/google/uuid"
"go.uber.org/zap"
@@ -124,22 +131,52 @@ func (e *SystemIntegratedService) SetIntegratedConfig(
// @param: req gaia.SystemIntegration
// @return: *dingding.Client, error
func (e *SystemIntegratedService) DingTalkConfigAvailable(req gaia.SystemIntegration) (*dingding.Client, error) {
var err error
// 1. 先直接调用钉钉 gettoken 接口,校验 AppKey/AppSecret 是否正确
if strings.TrimSpace(req.AppKey) == "" || strings.TrimSpace(req.AppSecret) == "" {
return nil, errors.New("AppKey 或 AppSecret 不能为空")
}
params := url.Values{}
params.Add("appkey", req.AppKey)
params.Add("appsecret", req.AppSecret)
resp, err := http.Get("https://oapi.dingtalk.com/gettoken?" + params.Encode())
if err != nil {
return nil, fmt.Errorf("请求钉钉 gettoken 失败: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("钉钉 gettoken HTTP 状态异常: %d, body=%s", resp.StatusCode, string(body))
}
var tokenResp struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err = json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("解析钉钉 gettoken 响应失败: %w", err)
}
if tokenResp.ErrCode != 0 {
return nil, fmt.Errorf("钉钉 gettoken 返回错误: errcode=%d, errmsg=%s", tokenResp.ErrCode, tokenResp.ErrMsg)
}
// 2. 校验通过后再构造 client(保持原有行为,供后续使用)
var reqs *http.Request
dingding.ServerUrl = "https://api.dingtalk.com"
// 特殊需要,检查可用性就不设置缓存了
return dingding.NewClient(&dingding.DefaultAccessTokenManager{
client := dingding.NewClient(&dingding.DefaultAccessTokenManager{
Id: uuid.New().String(),
Cache: file.New(os.TempDir()),
Name: "x-acs-dingtalk-access-token",
GetRefreshRequestFunc: func() *http.Request {
params := url.Values{}
params.Add("appkey", req.AppKey)
params.Add("appsecret", req.AppSecret)
reqs, err = http.NewRequest(http.MethodGet, "https://oapi.dingtalk.com/gettoken?"+params.Encode(), nil)
// 这里沿用原来的 token 刷新逻辑
reqs, _ = http.NewRequest(http.MethodGet, "https://oapi.dingtalk.com/gettoken?"+params.Encode(), nil)
return reqs
},
}), err
})
return client, nil
}
// TestConnection 测试连接
@@ -162,6 +199,11 @@ func (e *SystemIntegratedService) TestConnection(integrate gaia.SystemIntegratio
global.GVA_LOG.Warn("第三方邮箱API配置验证失败", zap.Error(err))
// 不阻止保存,只记录警告
}
// 验证转发集成配置
if err := e.ValidateForwardConfig(integrate); err != nil {
global.GVA_LOG.Warn("转发集成配置验证失败", zap.Error(err))
// 不阻止保存,只记录警告
}
return nil
case gaia.SystemIntegrationOAuth2:
// 测试OAuth2连接
@@ -171,79 +213,307 @@ func (e *SystemIntegratedService) TestConnection(integrate gaia.SystemIntegratio
}
}
// ValidateEmailApiConfig 验证第三方邮箱API配置
// @Tags System Integrated
// @Summary 验证第三方邮箱API配置
// @param: integrate gaia.SystemIntegration
// @return: error
func (e *SystemIntegratedService) ValidateEmailApiConfig(integrate gaia.SystemIntegration) error {
// 解析Config字段
if integrate.Config == "" {
return nil // 配置为空不算错误
}
// ParseDingTalkConfig 解析钉钉集成配置,自动处理新旧格式兼容
func (e *SystemIntegratedService) ParseDingTalkConfig(configJSON string) (request.DingTalkConfigRequest, error) {
var configMap request.DingTalkConfigRequest
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
return fmt.Errorf("解析配置失败: %s", err.Error())
if configJSON == "" {
return configMap, nil
}
// 检查是否启用邮箱API
if !configMap.EmailApi.Enabled {
return nil // 未启用不需要验证
// 先解析顶层结构
var raw struct {
EmailApi json.RawMessage `json:"email_api"`
ForwardConfig request.ForwardConfig `json:"forward_config"`
}
if err := json.Unmarshal([]byte(configJSON), &raw); err != nil {
return configMap, fmt.Errorf("解析钉钉配置失败: %s", err.Error())
}
// 验证必填字段
if configMap.EmailApi.URL == "" {
configMap.ForwardConfig = raw.ForwardConfig
if raw.EmailApi != nil {
cfg, err := parseEmailApiConfigFromJSON(raw.EmailApi)
if err != nil {
return configMap, err
}
configMap.EmailApi = cfg
}
return configMap, nil
}
// oldBodyDataCompat 旧格式 BodyData(兼容解析用)
type oldBodyDataCompat struct {
FormData []map[string]string `json:"form_data"`
Urlencoded []map[string]string `json:"urlencoded"`
Raw string `json:"raw"`
}
// isNewEmailApiConfig 检测配置是否使用新格式(通过 params 字段判断)
func isNewEmailApiConfig(config request.EmailApiConfig) bool {
return config.Params != nil
}
// convertOldBodyDataToNew 将旧格式 BodyData[]map[string]string)转换为新格式([]BodyField
func convertOldBodyDataToNew(old oldBodyDataCompat) request.BodyData {
newData := request.BodyData{Raw: old.Raw}
for _, kv := range old.FormData {
for k, v := range kv {
newData.FormData = append(newData.FormData, request.BodyField{
Key: k,
ValueType: request.ValueTypeString,
Value: v,
})
}
}
for _, kv := range old.Urlencoded {
for k, v := range kv {
newData.Urlencoded = append(newData.Urlencoded, request.BodyField{
Key: k,
ValueType: request.ValueTypeString,
Value: v,
})
}
}
return newData
}
// parseEmailApiConfigFromJSON 解析 EmailApiConfig,自动兼容新旧格式
// 新格式:包含 params 字段
// 旧格式:包含 request_param_field 字段,body_data 中使用 []map[string]string
func parseEmailApiConfigFromJSON(raw json.RawMessage) (request.EmailApiConfig, error) {
// 先尝试解析为新格式
var newConfig request.EmailApiConfig
if err := json.Unmarshal(raw, &newConfig); err != nil {
return request.EmailApiConfig{}, fmt.Errorf("解析邮箱配置失败: %s", err.Error())
}
// 如果有 params 字段,说明是新格式
if isNewEmailApiConfig(newConfig) {
return newConfig, nil
}
// 旧格式:尝试解析 body_data 中的 []map[string]string
var oldCompat struct {
request.EmailApiConfig
BodyData oldBodyDataCompat `json:"body_data"`
}
if err := json.Unmarshal(raw, &oldCompat); err == nil {
newConfig.BodyData = convertOldBodyDataToNew(oldCompat.BodyData)
global.GVA_LOG.Info("邮箱配置:检测到旧格式,已在内存中转换为新格式",
zap.String("request_param_field", newConfig.RequestParamField))
}
return newConfig, nil
}
// validateEmailApiConfigFields 验证 EmailApiConfig 字段
func validateEmailApiConfigFields(cfg request.EmailApiConfig) error {
if !cfg.Enabled {
return nil
}
if cfg.URL == "" {
return errors.New("邮箱API URL不能为空")
}
if configMap.EmailApi.Method == "" {
configMap.EmailApi.Method = "GET"
if cfg.Method == "" {
cfg.Method = "GET"
}
if configMap.EmailApi.RequestParamField == "" {
return errors.New("邮箱请求字段不能为空")
}
if configMap.EmailApi.ResponseEmailField == "" {
return errors.New("邮箱信息提取字段不能为空")
}
// 验证Body类型(仅POST/PUT/DELETE需要)
if configMap.EmailApi.Method != "GET" {
bodyType := strings.ToLower(configMap.EmailApi.BodyType)
if bodyType == "" {
configMap.EmailApi.BodyType = "raw" // 默认raw
} else if bodyType != "form-data" && bodyType != "x-www-form-urlencoded" && bodyType != "raw" {
return fmt.Errorf("不支持的Body类型: %s,支持的类型: form-data, x-www-form-urlencoded, raw", bodyType)
// 新格式不强制要求 request_param_field(通过 params 配置 URL 查询参数)
if isNewEmailApiConfig(cfg) {
// Params 只支持 string 和 ding_id 两种类型
for i, p := range cfg.Params {
if err := validateParamValueType(p.ValueType); err != nil {
return fmt.Errorf("第%d个 URL 参数类型无效:%s", i+1, err.Error())
}
}
// Body fields 支持 string、int、bool、ding_id
for i, f := range cfg.BodyData.FormData {
if err := validateValueType(f.ValueType); err != nil {
return fmt.Errorf("form-data 第%d个字段类型无效:%s", i+1, err.Error())
}
if f.ValueType == request.ValueTypeInt && f.Value != "" {
if _, err := strconv.ParseInt(f.Value, 10, 64); err != nil {
return fmt.Errorf("form-data 第%d个字段(%s)的值不是有效整数:%s", i+1, f.Key, f.Value)
}
}
if f.ValueType == request.ValueTypeBool && f.Value != "" {
if _, err := strconv.ParseBool(f.Value); err != nil {
return fmt.Errorf("form-data 第%d个字段(%s)的值不是有效布尔值:%s", i+1, f.Key, f.Value)
}
}
}
for i, f := range cfg.BodyData.Urlencoded {
if err := validateValueType(f.ValueType); err != nil {
return fmt.Errorf("urlencoded 第%d个字段类型无效:%s", i+1, err.Error())
}
}
} else {
// 旧格式兼容:request_param_field 不能为空
if cfg.RequestParamField == "" {
return errors.New("邮箱请求字段不能为空")
}
}
// 验证Authorization配置
authType := strings.ToLower(configMap.EmailApi.Authorization.Type)
if cfg.ResponseEmailField == "" {
return errors.New("邮箱信息提取字段不能为空")
}
if cfg.Method != "GET" {
bodyType := strings.ToLower(cfg.BodyType)
if bodyType != "" && bodyType != "form-data" && bodyType != "x-www-form-urlencoded" && bodyType != "raw" {
return fmt.Errorf("不支持的Body类型: %s,支持的类型: form-data, x-www-form-urlencoded, raw", cfg.BodyType)
}
}
authType := strings.ToLower(cfg.Authorization.Type)
if authType != "" && authType != "none" {
if authType == "bearer" {
if configMap.EmailApi.Authorization.Token == "" {
if cfg.Authorization.Token == "" {
return errors.New("Bearer Token不能为空")
}
} else if authType == "basic" {
if configMap.EmailApi.Authorization.Username == "" || configMap.EmailApi.Authorization.Password == "" {
if cfg.Authorization.Username == "" || cfg.Authorization.Password == "" {
return errors.New("Basic Auth需要填写Username和Password")
}
} else {
return fmt.Errorf("不支持的Authorization类型: %s,支持的类型: none, bearer, basic", authType)
}
}
return nil
}
// validateValueType 验证 Body 字段的 ValueType 是否合法(支持全部四种)
func validateValueType(vt string) error {
switch vt {
case "", request.ValueTypeString, request.ValueTypeInt, request.ValueTypeBool, request.ValueTypeDingID:
return nil
default:
return fmt.Errorf("不支持的值类型: %s,支持的类型: string, int, bool, ding_id", vt)
}
}
// validateParamValueType 验证 URL Params 的 ValueType 是否合法(只支持 string 和 ding_id
func validateParamValueType(vt string) error {
switch vt {
case "", request.ValueTypeString, request.ValueTypeDingID:
return nil
default:
return fmt.Errorf("URL 参数不支持的值类型: %s,支持的类型: string, ding_id", vt)
}
}
// ValidateEmailApiConfig 验证第三方邮箱API配置
// @Tags System Integrated
// @Summary 验证第三方邮箱API配置
// @param: integrate gaia.SystemIntegration
// @return: error
func (e *SystemIntegratedService) ValidateEmailApiConfig(integrate gaia.SystemIntegration) error {
if integrate.Config == "" {
return nil
}
var configMap struct {
EmailApi json.RawMessage `json:"email_api"`
}
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
return fmt.Errorf("解析配置失败: %s", err.Error())
}
if configMap.EmailApi == nil {
return nil
}
cfg, err := parseEmailApiConfigFromJSON(configMap.EmailApi)
if err != nil {
return err
}
if err = validateEmailApiConfigFields(cfg); err != nil {
return err
}
global.GVA_LOG.Info("第三方邮箱API配置验证通过",
zap.String("url", configMap.EmailApi.URL),
zap.String("method", configMap.EmailApi.Method),
zap.String("body_type", configMap.EmailApi.BodyType),
zap.String("auth_type", configMap.EmailApi.Authorization.Type))
zap.String("url", cfg.URL),
zap.String("method", cfg.Method),
zap.String("body_type", cfg.BodyType),
zap.String("auth_type", cfg.Authorization.Type),
zap.Bool("new_format", isNewEmailApiConfig(cfg)))
return nil
}
// TestEmailApiConfig 测试邮箱 API 配置,返回详细的响应结果用于调试
func (e *SystemIntegratedService) TestEmailApiConfig(cfg request.EmailApiConfig, testDingID string) (*gaiaResp.TestEmailApiConfigResponse, error) {
if err := validateEmailApiConfigFields(cfg); err != nil {
return &gaiaResp.TestEmailApiConfigResponse{
IsValid: false,
ErrorMessage: "配置验证失败:" + err.Error(),
}, nil
}
dingId := strings.TrimSpace(testDingID)
if dingId == "" {
return &gaiaResp.TestEmailApiConfigResponse{
IsValid: false,
ErrorMessage: "测试钉钉 ID 不能为空,请先在弹窗中填写一个真实的 ding_id",
}, nil
}
respBody, statusCode, reqErr := e.doEmailApiRequest(dingId, cfg)
result := &gaiaResp.TestEmailApiConfigResponse{
StatusCode: statusCode,
}
// 尝试解析响应 Body 为 JSON
var bodyJSON interface{}
if json.Unmarshal(respBody, &bodyJSON) == nil {
result.Body = bodyJSON
} else {
result.Body = string(respBody)
}
if reqErr != nil {
result.IsValid = false
result.ErrorMessage = reqErr.Error()
return result, nil
}
// 尝试提取邮箱字段
if cfg.ResponseEmailField != "" {
if bodyMap, ok := bodyJSON.(map[string]interface{}); ok {
email := extractJSONPathAdvanced(bodyMap, cfg.ResponseEmailField)
if email != "" {
result.EmailFieldPreview = cfg.ResponseEmailField + " = " + email
result.IsValid = true
} else {
result.IsValid = false
result.ErrorMessage = "未找到邮箱字段:" + cfg.ResponseEmailField
}
} else if bodySlice, ok := bodyJSON.([]interface{}); ok {
email := extractJSONPathAdvanced(bodySlice, cfg.ResponseEmailField)
if email != "" {
result.EmailFieldPreview = cfg.ResponseEmailField + " = " + email
result.IsValid = true
} else {
result.IsValid = false
result.ErrorMessage = "未找到邮箱字段:" + cfg.ResponseEmailField
}
} else {
result.IsValid = statusCode >= 200 && statusCode < 300
}
} else {
result.IsValid = statusCode >= 200 && statusCode < 300
}
global.GVA_LOG.Info("测试邮箱 API 配置",
zap.Int("status_code", statusCode),
zap.String("email_preview", result.EmailFieldPreview),
zap.Bool("is_valid", result.IsValid))
return result, nil
}
// TestOAuth2Connection 测试OAuth2连接
// @Tags System Integrated
// @Summary 测试OAuth2连接
@@ -325,3 +595,412 @@ func (e *SystemIntegratedService) TestOAuth2Connection(integrate gaia.SystemInte
return nil
}
// ValidateForwardConfig 验证转发集成配置
// @Tags System Integrated
// @Summary 验证转发集成配置
// @param: integrate gaia.SystemIntegration
// @return: error
func (e *SystemIntegratedService) ValidateForwardConfig(integrate gaia.SystemIntegration) error {
// 解析 Config 字段
if integrate.Config == "" {
return nil // 配置为空不算错误
}
var configMap request.DingTalkConfigRequest
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
return fmt.Errorf("解析配置失败:%s", err.Error())
}
// 若未配置转发 Token,则认为未使用转发能力,不强制校验
if len(configMap.ForwardConfig.Tokens) == 0 {
return nil
}
// 使用转发能力的前置条件:至少 1 个 Token + 启用并配置「第三方邮箱配置」
if !configMap.EmailApi.Enabled || strings.TrimSpace(configMap.EmailApi.URL) == "" {
return errors.New("使用转发能力前请先启用并配置「第三方邮箱配置」")
}
// 验证 Token 数量
if len(configMap.ForwardConfig.Tokens) > 20 {
return errors.New("转发 Token 最多 20 个")
}
// 验证每个 Token 的必填字段
for i, token := range configMap.ForwardConfig.Tokens {
if token.ID == "" {
return fmt.Errorf("第%d个 Token 的 ID 不能为空", i+1)
}
if token.TokenHash == "" {
return fmt.Errorf("第%d个 Token 的 TokenHash 不能为空", i+1)
}
}
global.GVA_LOG.Info("转发集成配置验证通过",
zap.Int("token_count", len(configMap.ForwardConfig.Tokens)))
return nil
}
// ValidateDingIdApiConfig 验证第三方钉钉 ID 匹配 API 配置
// @Tags System Integrated
// @Summary 验证第三方钉钉 ID 匹配 API 配置
// @param: integrate gaia.SystemIntegration
// @return: error
// extractJSONPath 按点分路径从 JSON 对象中提取字符串值,支持 "data.username" 等多层路径
func extractJSONPath(data map[string]interface{}, path string) string {
parts := strings.SplitN(path, ".", 2)
val, ok := data[parts[0]]
if !ok {
return ""
}
if len(parts) == 1 {
if s, ok := val.(string); ok {
return s
}
return fmt.Sprintf("%v", val)
}
if nested, ok := val.(map[string]interface{}); ok {
return extractJSONPath(nested, parts[1])
}
return ""
}
// resolveParamValue 根据 ValueType 解析参数值,ding_id 类型替换为实际的钉钉 ID
func resolveParamValue(vt, value, dingId string) string {
switch vt {
case request.ValueTypeDingID:
return dingId
default:
// 兼容旧格式的 {{ding_id}} 占位符和新格式的 $<{[ding_id]}> 标记
v := strings.ReplaceAll(value, "{{ding_id}}", dingId)
v = strings.ReplaceAll(v, request.DingIDMarker, dingId)
return v
}
}
// buildBodyFields 将 []BodyField 按类型转换,构建 url.Values(用于 form-data 和 urlencoded
func buildBodyFields(fields []request.BodyField, dingId string) url.Values {
form := url.Values{}
for _, f := range fields {
if f.Key == "" {
continue
}
val := resolveParamValue(f.ValueType, f.Value, dingId)
form.Set(f.Key, val)
}
return form
}
// buildURL 根据新格式 Params 或旧格式 RequestParamField 构建带查询参数的 URL
func buildURL(baseURL string, config request.EmailApiConfig, dingId string) string {
if isNewEmailApiConfig(config) {
// 新格式:遍历 Params 列表自动拼接
params := url.Values{}
for _, p := range config.Params {
if p.Key == "" {
continue
}
params.Set(p.Key, resolveParamValue(p.ValueType, p.Value, dingId))
}
if len(params) == 0 {
return baseURL
}
sep := "?"
if strings.Contains(baseURL, "?") {
sep = "&"
}
return baseURL + sep + params.Encode()
}
// 旧格式:RequestParamField 字段名 + dingId 作为值
if config.RequestParamField != "" {
sep := "?"
if strings.Contains(baseURL, "?") {
sep = "&"
}
return baseURL + sep + url.QueryEscape(config.RequestParamField) + "=" + url.QueryEscape(dingId)
}
return baseURL
}
// callEmailApi 调用第三方邮箱 API,使用 ding_id(用户名) 获取邮箱
func (e *SystemIntegratedService) callEmailApi(
dingId string, config request.EmailApiConfig) (mailList []string, err error) {
// init
respBody, _, err := e.doEmailApiRequest(dingId, config)
if err != nil {
return mailList, err
}
var respJSON map[string]interface{}
if err = json.Unmarshal(respBody, &respJSON); err != nil {
return mailList, fmt.Errorf("解析响应 JSON 失败:%s", err.Error())
}
email := extractJSONPathAdvanced(respJSON, config.ResponseEmailField)
if email == "" {
return mailList, fmt.Errorf("响应中未找到邮箱(路径:%s)", config.ResponseEmailField)
}
//
mailList = append(mailList, email)
parts := strings.Split(email, "@")
defaultMail := os.Getenv(gaia.EmailDomainEnv)
if len(defaultMail) > 1 && len(parts) > 1 && len(parts[0]) > 0 {
mailList = append(mailList, parts[0]+"@"+defaultMail)
}
return mailList, nil
}
// doEmailApiRequest 构建并执行邮箱 API 请求,返回响应体字节和状态码
func (e *SystemIntegratedService) doEmailApiRequest(dingId string, config request.EmailApiConfig) ([]byte, int, error) {
method := strings.ToUpper(config.Method)
if method == "" {
method = "GET"
}
var bodyReader io.Reader
var contentType string
if method != "GET" && method != "HEAD" {
switch strings.ToLower(config.BodyType) {
case "raw":
raw := strings.ReplaceAll(config.BodyData.Raw, "{{ding_id}}", dingId)
raw = strings.ReplaceAll(raw, request.DingIDMarker, dingId)
bodyReader = strings.NewReader(raw)
contentType = "application/json"
case "form-data":
form := buildBodyFields(config.BodyData.FormData, dingId)
// 也处理旧格式的 urlencoded 字段(兼容)
if len(form) == 0 {
form = buildBodyFields(config.BodyData.Urlencoded, dingId)
}
bodyReader = strings.NewReader(form.Encode())
contentType = "multipart/form-data"
case "x-www-form-urlencoded":
form := buildBodyFields(config.BodyData.Urlencoded, dingId)
if len(form) == 0 {
form = buildBodyFields(config.BodyData.FormData, dingId)
}
bodyReader = strings.NewReader(form.Encode())
contentType = "application/x-www-form-urlencoded"
}
}
apiURL := buildURL(config.URL, config, dingId)
req, err := http.NewRequest(method, apiURL, bodyReader)
if err != nil {
return nil, 0, fmt.Errorf("构建请求失败:%s", err.Error())
}
// 设置 Content-Type(如果 Headers 未覆盖)
if contentType != "" {
req.Header.Set("Content-Type", contentType)
}
// 设置 Headers(可覆盖 Content-Type
for k, v := range config.Headers {
req.Header.Set(k, v)
}
// 设置 Authorization
authType := strings.ToLower(config.Authorization.Type)
switch authType {
case "bearer":
req.Header.Set("Authorization", "Bearer "+config.Authorization.Token)
case "basic":
req.SetBasicAuth(config.Authorization.Username, config.Authorization.Password)
}
client := &http.Client{Timeout: 15 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("请求失败:%s", err.Error())
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, resp.StatusCode, fmt.Errorf("读取响应失败:%s", err.Error())
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return respBody, resp.StatusCode, fmt.Errorf("第三方 API 返回错误状态码:%d", resp.StatusCode)
}
return respBody, resp.StatusCode, nil
}
// extractJSONPathAdvanced 支持点分路径和数组索引(如 data[0].userName
func extractJSONPathAdvanced(data interface{}, path string) string {
if path == "" {
return ""
}
parts := splitJSONPath(path)
current := data
for _, part := range parts {
switch v := current.(type) {
case map[string]interface{}:
current = v[part.key]
case []interface{}:
if part.isIndex && part.index >= 0 && part.index < len(v) {
current = v[part.index]
} else {
return ""
}
default:
return ""
}
if current == nil {
return ""
}
}
switch v := current.(type) {
case string:
return v
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case bool:
return strconv.FormatBool(v)
default:
return fmt.Sprintf("%v", v)
}
}
type jsonPathPart struct {
key string
isIndex bool
index int
}
// splitJSONPath 将路径字符串分割为结构化的部分列表,支持 data[0].userName 格式
func splitJSONPath(path string) []jsonPathPart {
var parts []jsonPathPart
// 先按点分割
segments := strings.Split(path, ".")
for _, seg := range segments {
seg = strings.TrimSpace(seg)
if seg == "" {
continue
}
// 检查是否包含数组索引 data[0]
if idx := strings.Index(seg, "["); idx != -1 {
key := seg[:idx]
rest := seg[idx:]
if key != "" {
parts = append(parts, jsonPathPart{key: key})
}
// 解析所有 [N] 部分
for len(rest) > 0 && rest[0] == '[' {
end := strings.Index(rest, "]")
if end == -1 {
break
}
idxStr := rest[1:end]
rest = rest[end+1:]
if n, err := strconv.Atoi(idxStr); err == nil {
parts = append(parts, jsonPathPart{isIndex: true, index: n})
}
}
} else {
parts = append(parts, jsonPathPart{key: seg})
}
}
return parts
}
// ParseForwardToken 从 Bearer Token 中验签并提取 ding_id
// 遍历 tokens 列表,找到签名匹配的条目
func (e *SystemIntegratedService) ParseForwardToken(
rawToken string, tokens []request.ForwardToken) (dingId string, err error) {
parts := strings.SplitN(rawToken, ".", 2)
if len(parts) != 2 {
return "", errors.New("token 格式非法")
}
payloadB64, sigB64 := parts[0], parts[1]
sigBytes, err := base64.RawURLEncoding.DecodeString(sigB64)
if err != nil {
return "", errors.New("签名解码失败")
}
for _, t := range tokens {
if t.TokenSecret == "" {
continue
}
// 直接使用原始字节作为 HMAC 密钥,兼容任意字符串格式的密钥
secret := []byte(t.TokenSecret)
// 验证 HMAC 签名
mac := hmac.New(sha256.New, secret)
mac.Write([]byte(payloadB64))
expected := mac.Sum(nil)
if !hmac.Equal(expected, sigBytes) {
continue // 不匹配,试下一个
}
// 签名验证通过,解析 payload
payloadBytes, err := base64.RawURLEncoding.DecodeString(payloadB64)
if err != nil {
return "", errors.New("payload 解码失败")
}
var payload struct {
DingID string `json:"ding_id"`
}
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
return "", errors.New("payload 解析失败")
}
if payload.DingID == "" {
return "", errors.New("token 中缺少 ding_id")
}
return payload.DingID, nil
}
return "", errors.New("无效的转发 Token")
}
// ResolveAccountByDingId 通过钉钉 ID 解析 gaia account_id
// 解析顺序:Redis 缓存 → AccountDingTalkExtend 本地表 → 第三方 EmailApi(邮箱 API
func (e *SystemIntegratedService) ResolveAccountByDingId(
dingId string, apiConfig request.EmailApiConfig) (string, error) {
// 1. 查 Redis 缓存
ctx := context.Background()
redisKey := gaia.RedisKeyGaiaForwardDingPrefix + dingId
if cached, err := global.GVA_REDIS.Get(ctx, redisKey).Result(); err == nil && cached != "" {
return cached, nil
}
// 2. 查本地 AccountDingTalkExtend 表
var extend gaia.AccountDingTalkExtend
if err := global.GVA_DB.Where("ding_talk = ?", dingId).First(&extend).Error; err == nil {
accountID := extend.ID.String()
global.GVA_REDIS.Set(ctx, redisKey, accountID, 24*time.Hour)
return accountID, nil
}
// 3. 第三方邮箱 API(若配置)
if !apiConfig.Enabled || strings.TrimSpace(apiConfig.URL) == "" {
return "", fmt.Errorf("未找到 ding_id=%s 对应的用户,且未配置第三方邮箱 API", dingId)
}
email, err := e.callEmailApi(dingId, apiConfig)
if err != nil {
return "", fmt.Errorf("调用第三方邮箱 API 失败:%s", err.Error())
}
// 4. 按邮箱查 accounts 表(匹配 email 字段)
var account gaia.Account
if err = global.GVA_DB.Where("email = ?", email).First(&account).Error; err != nil {
return "", fmt.Errorf("用户 %s 不存在(来自第三方邮箱 API)", email)
}
accountID := account.ID.String()
// 5. 写回 AccountDingTalkExtend,方便下次本地命中
global.GVA_DB.Create(&gaia.AccountDingTalkExtend{
ID: account.ID,
DingTalk: dingId,
})
// 6. 写 Redis 缓存
global.GVA_REDIS.Set(ctx, redisKey, accountID, 24*time.Hour)
return accountID, nil
}
+131 -222
View File
@@ -15,13 +15,6 @@ import (
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
)
// UserWorkerAllocation 用户工作器分配信息
type UserWorkerAllocation struct {
UserID uint `json:"user_id"`
Workers int `json:"workers"`
MaxLimit int `json:"max_limit"`
}
// 全局工作池实例
var globalWorkerPool *WorkerPool
@@ -29,10 +22,10 @@ var globalWorkerPool *WorkerPool
type WorkerPool struct {
ctx context.Context
cancel context.CancelFunc
totalWorkers int // 总工作器数量
userWorkers map[uint]*UserWorkerAllocation // 每个用户的工作器分配
userTaskChan map[uint]chan *gaia.BatchWorkflowTask // 每个用户的任务队列
runningWorkers map[uint]int // 每个用户当前运行的worker数量
totalWorkers int // 总工作器数量
userWorkers map[uint]*gaia.UserWorkerAllocation // 每个用户的工作器分配
userTaskChan map[uint]chan *gaia.BatchWorkflowTask // 每个用户的任务队列
runningWorkers map[uint]int // 每个用户当前运行的worker数量
wg sync.WaitGroup
batchService *BatchWorkflowService
running bool
@@ -47,7 +40,7 @@ func NewWorkerPool(totalWorkers int) *WorkerPool {
ctx: ctx,
cancel: cancel,
totalWorkers: totalWorkers,
userWorkers: make(map[uint]*UserWorkerAllocation),
userWorkers: make(map[uint]*gaia.UserWorkerAllocation),
userTaskChan: make(map[uint]chan *gaia.BatchWorkflowTask),
runningWorkers: make(map[uint]int),
batchService: &BatchWorkflowService{},
@@ -101,15 +94,13 @@ func (wp *WorkerPool) calculateUserWorkerAllocation() {
return
}
// 第二个查询:获取这些用户的所有批量工作流的累计错误次数(不限状态
type UserErrorInfo struct {
UserID uint `json:"user_id"`
ErrorCount int `json:"error_count"`
}
var userErrorInfos []UserErrorInfo
// 第二个查询:获取这些用户的累计错误次数和最小 total_rowspending 状态工作流
var userErrorInfos []gaia.UserErrorInfo
if len(activeUserIDs) > 0 {
err = global.GVA_DB.Raw(`
SELECT bw.user_id, COALESCE(SUM(bw.error_count), 0) as error_count
SELECT bw.user_id,
COALESCE(SUM(bw.error_count), 0) AS error_count,
COALESCE(SUM(bw.total_rows), 0) AS total_rows
FROM batch_workflows_extend bw
WHERE bw.user_id IN (?) AND bw.status='pending'
GROUP BY bw.user_id
@@ -121,10 +112,12 @@ func (wp *WorkerPool) calculateUserWorkerAllocation() {
}
}
// 提取活跃用户ID列表错误次数映射
// 提取活跃用户ID列表错误次数映射和 total_rows 映射
userErrorMap := make(map[uint]int)
userTotalRowsMap := make(map[uint]int)
for _, info := range userErrorInfos {
userErrorMap[info.UserID] = info.ErrorCount
userTotalRowsMap[info.UserID] = info.TotalRows
}
userCount := len(activeUserIDs)
@@ -133,7 +126,7 @@ func (wp *WorkerPool) calculateUserWorkerAllocation() {
for _, ch := range wp.userTaskChan {
close(ch)
}
wp.userWorkers = make(map[uint]*UserWorkerAllocation)
wp.userWorkers = make(map[uint]*gaia.UserWorkerAllocation)
wp.userTaskChan = make(map[uint]chan *gaia.BatchWorkflowTask)
wp.runningWorkers = make(map[uint]int)
return
@@ -155,226 +148,142 @@ func (wp *WorkerPool) calculateUserWorkerAllocation() {
}
}
// 检查用户数量是否超过了最大支持数量(每用户最少1个工作器)
maxSupportedUsers := wp.totalWorkers / 1
// 存储新的分配计算结果
newAllocations := make(map[uint]*UserWorkerAllocation)
newAllocations := make(map[uint]*gaia.UserWorkerAllocation)
if userCount <= maxSupportedUsers {
// 用户数量在可支持范围内,采用两阶段分配策略
baseAllocation := wp.totalWorkers / userCount
remainder := wp.totalWorkers % userCount
// 第一阶段:计算每个用户的基础分配和错误惩罚后的实际分配
type UserAllocationInfo struct {
UserID uint
BaseWorkers int
ActualWorkers int
ErrorCount int
PenaltyReduced int
// 计算每个用户的优先级权重(total_rows 越小权重越大)
// 权重:tier1=4, tier2=3, tier3=2, tier4=1
priorityWeightOf := func(totalRows int) int {
switch {
case totalRows <= gaia.PriorityTier1MaxRows:
return 4
case totalRows <= gaia.PriorityTier2MaxRows:
return 3
case totalRows <= gaia.PriorityTier3MaxRows:
return 2
default:
return 1
}
}
var userAllocations []UserAllocationInfo
totalPenaltyReduced := 0
// 计算所有用户的权重总和,用于按比例分配 worker
totalWeight := 0
for _, userID := range activeUserIDs {
totalWeight += priorityWeightOf(userTotalRowsMap[userID])
}
if totalWeight == 0 {
totalWeight = userCount
}
for i, userID := range activeUserIDs {
baseWorkers := baseAllocation
// 处理余数,前几个用户多分配一个
if i < remainder {
baseWorkers++
}
// UserAllocationInfo 单用户分配中间计算结构
type UserAllocationInfo struct {
UserID uint
BaseWorkers int
ActualWorkers int
ErrorCount int
PenaltyReduced int
}
// 确保每个用户至少有1个并发位
if baseWorkers < 1 {
baseWorkers = 1
}
// 应用错误惩罚:根据用户的累计错误次数减少并发位
errorCount := userErrorMap[userID]
actualWorkers := wp.calculateWorkerCountWithErrorPenalty(baseWorkers, errorCount)
penaltyReduced := baseWorkers - actualWorkers
totalPenaltyReduced += penaltyReduced
userAllocations = append(userAllocations, UserAllocationInfo{
UserID: userID,
BaseWorkers: baseWorkers,
ActualWorkers: actualWorkers,
ErrorCount: errorCount,
PenaltyReduced: penaltyReduced,
})
// 按权重比例为每个用户分配基础 worker 数,余数补给权重最高的用户
userAllocations := make([]UserAllocationInfo, 0, userCount)
allocatedBase := 0
for _, userID := range activeUserIDs {
w := priorityWeightOf(userTotalRowsMap[userID])
base := wp.totalWorkers * w / totalWeight
if base < 1 {
base = 1
}
userAllocations = append(userAllocations, UserAllocationInfo{
UserID: userID,
BaseWorkers: base,
ErrorCount: userErrorMap[userID],
})
allocatedBase += base
}
// 将剩余 worker 补给权重最大的用户(已按 activeUserIDs 顺序,这里找最大权重的)
remainder := wp.totalWorkers - allocatedBase
if remainder > 0 {
// 找权重最大的用户索引
maxW, maxIdx := 0, 0
for i, alloc := range userAllocations {
w := priorityWeightOf(userTotalRowsMap[alloc.UserID])
if w > maxW {
maxW, maxIdx = w, i
}
}
userAllocations[maxIdx].BaseWorkers += remainder
}
// 第二阶段:将空出来的并发位重新分配给错误较少的用户
if totalPenaltyReduced > 0 {
// 按错误数量排序,错误少的用户优先获得额外分配
for i := 0; i < len(userAllocations)-1; i++ {
for j := i + 1; j < len(userAllocations); j++ {
if userAllocations[i].ErrorCount > userAllocations[j].ErrorCount {
userAllocations[i], userAllocations[j] = userAllocations[j], userAllocations[i]
// 应用错误惩罚并收集被释放的 worker
totalPenaltyReduced := 0
for i := range userAllocations {
actual := wp.calculateWorkerCountWithErrorPenalty(
userAllocations[i].BaseWorkers, userAllocations[i].ErrorCount)
userAllocations[i].PenaltyReduced = userAllocations[i].BaseWorkers - actual
userAllocations[i].ActualWorkers = actual
totalPenaltyReduced += userAllocations[i].PenaltyReduced
}
// 将惩罚释放的 worker 重新分配给无惩罚用户(按权重优先)
if totalPenaltyReduced > 0 {
// 按权重降序、错误数升序排序,优先分配给高优先级无惩罚用户
for i := 0; i < len(userAllocations)-1; i++ {
for j := i + 1; j < len(userAllocations); j++ {
wi := priorityWeightOf(userTotalRowsMap[userAllocations[i].UserID])
wj := priorityWeightOf(userTotalRowsMap[userAllocations[j].UserID])
if wi < wj || (wi == wj && userAllocations[i].ErrorCount > userAllocations[j].ErrorCount) {
userAllocations[i], userAllocations[j] = userAllocations[j], userAllocations[i]
}
}
}
remainingToDistribute := totalPenaltyReduced
eligibleUsers := 0
for _, a := range userAllocations {
if a.PenaltyReduced == 0 {
eligibleUsers++
}
}
if eligibleUsers > 0 {
for i := 0; i < len(userAllocations) && remainingToDistribute > 0; i++ {
if userAllocations[i].PenaltyReduced == 0 {
extra := remainingToDistribute / eligibleUsers
if extra < 1 {
extra = 1
}
}
}
// 只为没有被惩罚的用户(PenaltyReduced = 0)重新分配空闲的并发位
// 被惩罚的用户不应该获得额外分配
remainingToDistribute := totalPenaltyReduced
eligibleUsers := 0
// 计算有资格获得额外分配的用户数量(没有被惩罚的用户)
for _, allocation := range userAllocations {
if allocation.PenaltyReduced == 0 {
eligibleUsers++
}
}
if eligibleUsers > 0 {
// 只为没有被惩罚的用户分配额外的并发位
for i := 0; i < len(userAllocations) && remainingToDistribute > 0; i++ {
if userAllocations[i].PenaltyReduced == 0 {
// 为没有错误惩罚的用户分配额外的并发位
extraWorkers := remainingToDistribute / eligibleUsers
if extraWorkers < 1 {
extraWorkers = 1
}
if extraWorkers > remainingToDistribute {
extraWorkers = remainingToDistribute
}
userAllocations[i].ActualWorkers += extraWorkers
remainingToDistribute -= extraWorkers
eligibleUsers--
if extra > remainingToDistribute {
extra = remainingToDistribute
}
userAllocations[i].ActualWorkers += extra
remainingToDistribute -= extra
eligibleUsers--
}
}
}
}
// 创建最终分配结果
totalFinalWorkers := 0
for _, allocation := range userAllocations {
newAllocations[allocation.UserID] = &UserWorkerAllocation{
// 写入最终分配,超出总量时截断(降级场景)
allocatedWorkers := 0
for _, allocation := range userAllocations {
workers := allocation.ActualWorkers
remaining := wp.totalWorkers - allocatedWorkers
if workers > remaining {
workers = remaining
}
if workers > 0 {
newAllocations[allocation.UserID] = &gaia.UserWorkerAllocation{
UserID: allocation.UserID,
Workers: allocation.ActualWorkers,
Workers: workers,
MaxLimit: wp.totalWorkers,
}
totalFinalWorkers += allocation.ActualWorkers
allocatedWorkers += workers
}
} else {
// 用户数量超过最大支持数量,采用降级分配策略(两阶段分配)
baseAllocation := wp.totalWorkers / userCount
remainder := wp.totalWorkers % userCount
// 第一阶段:计算每个用户的基础分配和错误惩罚后的实际分配
type UserAllocationInfo struct {
UserID uint
BaseWorkers int
ActualWorkers int
ErrorCount int
PenaltyReduced int
if allocatedWorkers >= wp.totalWorkers {
break
}
var userAllocations []UserAllocationInfo
totalPenaltyReduced := 0
for i, userID := range activeUserIDs {
baseWorkers := baseAllocation
// 处理余数,前几个用户多分配一个
if i < remainder {
baseWorkers++
}
// 确保至少分配1个工作器
if baseWorkers < 1 {
baseWorkers = 1
}
// 应用错误惩罚:根据用户的累计错误次数减少并发位
errorCount := userErrorMap[userID]
actualWorkers := wp.calculateWorkerCountWithErrorPenalty(baseWorkers, errorCount)
penaltyReduced := baseWorkers - actualWorkers
totalPenaltyReduced += penaltyReduced
// 添加详细的错误惩罚计算调试日志
userAllocations = append(userAllocations, UserAllocationInfo{
UserID: userID,
BaseWorkers: baseWorkers,
ActualWorkers: actualWorkers,
ErrorCount: errorCount,
PenaltyReduced: penaltyReduced,
})
}
// 第二阶段:将空出来的并发位重新分配给错误较少的用户
if totalPenaltyReduced > 0 {
// 按错误数量排序,错误少的用户优先获得额外分配
for i := 0; i < len(userAllocations)-1; i++ {
for j := i + 1; j < len(userAllocations); j++ {
if userAllocations[i].ErrorCount > userAllocations[j].ErrorCount {
userAllocations[i], userAllocations[j] = userAllocations[j], userAllocations[i]
}
}
}
// 只为没有被惩罚的用户(PenaltyReduced = 0)重新分配空闲的并发位
// 被惩罚的用户不应该获得额外分配
remainingToDistribute := totalPenaltyReduced
eligibleUsers := 0
// 计算有资格获得额外分配的用户数量(没有被惩罚的用户)
for _, allocation := range userAllocations {
if allocation.PenaltyReduced == 0 {
eligibleUsers++
}
}
if eligibleUsers > 0 {
// 只为没有被惩罚的用户分配额外的并发位
for i := 0; i < len(userAllocations) && remainingToDistribute > 0; i++ {
if userAllocations[i].PenaltyReduced == 0 {
// 为没有错误惩罚的用户分配额外的并发位
extraWorkers := remainingToDistribute / eligibleUsers
if extraWorkers < 1 {
extraWorkers = 1
}
if extraWorkers > remainingToDistribute {
extraWorkers = remainingToDistribute
}
userAllocations[i].ActualWorkers += extraWorkers
remainingToDistribute -= extraWorkers
eligibleUsers--
}
}
}
}
// 创建最终分配结果,确保不超过总工作器数量
allocatedWorkers := 0
for _, allocation := range userAllocations {
workers := allocation.ActualWorkers
// 确保不会超过剩余的工作器数量
remainingWorkers := wp.totalWorkers - allocatedWorkers
if workers > remainingWorkers {
workers = remainingWorkers
}
if workers > 0 {
newAllocations[allocation.UserID] = &UserWorkerAllocation{
UserID: allocation.UserID,
Workers: workers,
MaxLimit: wp.totalWorkers,
}
allocatedWorkers += workers
}
// 如果工作器已经分配完毕,剩余用户分配0个工作器
if allocatedWorkers >= wp.totalWorkers {
break
}
}
global.GVA_LOG.Warn(fmt.Sprintf("降级分配完成 - 总工作器: %d, 用户数: %d, 已分配: %d, 重新分配: %d, 平均每用户: %.1f个",
wp.totalWorkers, userCount, allocatedWorkers, totalPenaltyReduced, float64(allocatedWorkers)/float64(userCount)))
}
if userCount > wp.totalWorkers {
global.GVA_LOG.Warn(fmt.Sprintf("降级分配完成 - 总工作器: %d, 用户数: %d, 已分配: %d, 平均每用户: %.1f个",
wp.totalWorkers, userCount, allocatedWorkers, float64(allocatedWorkers)/float64(userCount)))
}
// 应用新的分配,只更新有变化的用户
+8 -2
View File
@@ -171,8 +171,10 @@ func (casbinService *CasbinService) AddPolicies(db *gorm.DB, rules [][]string) e
func (CasbinService *CasbinService) FreshCasbin() (err error) {
e := CasbinService.Casbin()
err = e.LoadPolicy()
return err
if e == nil {
return errors.New("casbin enforcer is nil, please check database initialization")
}
return e.LoadPolicy()
}
//@author: [piexlmax](https://github.com/piexlmax)
@@ -187,6 +189,10 @@ var (
func (casbinService *CasbinService) Casbin() *casbin.SyncedCachedEnforcer {
once.Do(func() {
if global.GVA_DB == nil {
zap.L().Warn("Casbin initialization skipped: global.GVA_DB is nil")
return
}
a, err := gormadapter.NewAdapterByDB(global.GVA_DB)
if err != nil {
zap.L().Error("适配数据库失败请检查casbin表是否为InnoDB引擎!", zap.Error(err))
+25 -1
View File
@@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"github.com/flipped-aurora/gin-vue-admin/server/global"
modelSystem "github.com/flipped-aurora/gin-vue-admin/server/model/system"
"github.com/flipped-aurora/gin-vue-admin/server/model/system/request"
"go.uber.org/zap"
"gorm.io/gorm"
"sort"
)
@@ -86,6 +88,22 @@ func RegisterInit(order int, i SubInitializer) {
type InitDBService struct{}
// IfInit 判断是否数据库初始化了
func (initDBService *InitDBService) IfInit() (init bool) {
var menuCount, authorityCount int64
if global.GVA_DB != nil {
init = true
var menu modelSystem.SysBaseMenuBtn
var authority modelSystem.SysAuthority
global.GVA_DB.Model(&menu).Count(&menuCount)
global.GVA_DB.Model(&authority).Count(&authorityCount)
if menuCount <= 1 && authorityCount <= 1 {
init = false
}
}
return init
}
// InitDB 创建数据库并初始化 总入口
func (initDBService *InitDBService) InitDB(conf request.InitDB) (err error) {
ctx := context.TODO()
@@ -122,7 +140,9 @@ func (initDBService *InitDBService) InitDB(conf request.InitDB) (err error) {
db := ctx.Value("db").(*gorm.DB)
global.GVA_DB = db
db.Exec("DELETE FROM sys_base_menus")
db.Exec("DELETE FROM sys_authorities")
db.Exec("DELETE FROM sys_user_authority")
if err = initHandler.InitTables(ctx, initializers); err != nil {
return err
}
@@ -133,6 +153,10 @@ func (initDBService *InitDBService) InitDB(conf request.InitDB) (err error) {
if err = initHandler.WriteConfig(ctx); err != nil {
return err
}
// 初始化完成后刷新 Casbin 策略,避免使用旧的或空的策略
if err = CasbinServiceApp.FreshCasbin(); err != nil {
global.GVA_LOG.Warn("refresh casbin policy after InitDB failed", zap.Error(err))
}
initializers = initSlice{}
cache = map[string]*orderedInitializer{}
return nil
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"github.com/spf13/viper"
"go.uber.org/zap"
"path/filepath"
"github.com/flipped-aurora/gin-vue-admin/server/config"
@@ -35,10 +36,10 @@ func (h PgsqlInitHandler) WriteConfig(ctx context.Context) error {
// 改成拿dify的配置,如果不是docker运行,则从dify api的.env文件中获取jwt的加密key
if !global.GVA_CONFIG.System.DockerRun {
var err error
global.GVA_CONFIG.JWT.SigningKey, err = h.GetJwtSigningKeyFormDifyApiEnv()
if err != nil {
return err
if secretKey, err := h.GetJwtSigningKeyFormDifyApiEnv(); err != nil {
global.GVA_LOG.Warn("failed to load JWT signing key from Dify API .env, using existing configuration", zap.Error(err))
} else if secretKey != "" {
global.GVA_CONFIG.JWT.SigningKey = secretKey
}
}
cs := utils.StructToMap(global.GVA_CONFIG)
+9
View File
@@ -201,6 +201,9 @@ func (i *initApi) InitializeData(ctx context.Context) (context.Context, error) {
// Extend Start: system integration
{ApiGroup: "应用集成配置", Method: "GET", Path: "/gaia/system/dingtalk", Description: "获取钉钉系统配置"},
{ApiGroup: "应用集成配置", Method: "POST", Path: "/gaia/system/dingtalk", Description: "设置钉钉系统配置"},
{ApiGroup: "应用集成配置", Method: "POST", Path: "/gaia/system/dingtalk/test-email-config", Description: "测试钉钉邮箱API配置"},
{ApiGroup: "应用集成配置", Method: "GET", Path: "/gaia/system/dingtalk/test-auth-url", Description: "测试连接-获取钉钉授权URL"},
{ApiGroup: "应用集成配置", Method: "POST", Path: "/gaia/system/dingtalk/test-callback", Description: "测试连接-钉钉回调验证"},
// Extend Stop: system integration
// Extend Start: oauth2
@@ -244,6 +247,12 @@ func (i *initApi) InitializeData(ctx context.Context) (context.Context, error) {
{ApiGroup: "模型管理", Method: "PATCH", Path: "/gaia/proxy/*", Description: "中转API(第三方)-PATCH"},
{ApiGroup: "模型管理", Method: "DELETE", Path: "/gaia/proxy/*", Description: "中转API(第三方)-DELETE"},
// Extend Stop: model provider
// Extend Start: 转发集成 (forward tokens)
{ApiGroup: "转发集成", Method: "GET", Path: "/gaia/system/forward-tokens", Description: "获取转发 Token 列表"},
{ApiGroup: "转发集成", Method: "POST", Path: "/gaia/system/forward-tokens", Description: "新增转发 Token"},
{ApiGroup: "转发集成", Method: "DELETE", Path: "/gaia/system/forward-tokens/:id", Description: "删除转发 Token"},
// Extend Stop: 转发集成
}
if err := db.Create(&entities).Error; err != nil {
return ctx, errors.Wrap(err, sysModel.SysApi{}.TableName()+"表数据初始化失败!")
+12
View File
@@ -287,6 +287,9 @@ func (i *initCasbin) InitializeData(ctx context.Context) (context.Context, error
// Extend Start: system integration
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk", V2: "GET"},
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk", V2: "POST"},
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk/test-email-config", V2: "POST"},
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk/test-auth-url", V2: "GET"},
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk/test-callback", V2: "POST"},
// Extend Stop: system integration
// Extend Start: oauth2
@@ -403,6 +406,15 @@ func (i *initCasbin) InitializeData(ctx context.Context) (context.Context, error
{Ptype: "p", V0: "8881", V1: "/gaia/proxy/*", V2: "PATCH"},
{Ptype: "p", V0: "8881", V1: "/gaia/proxy/*", V2: "DELETE"},
// Extend Stop: model provider
// Extend Start: 转发集成 (forward tokens)
{Ptype: "p", V0: "888", V1: "/gaia/system/forward-tokens", V2: "GET"},
{Ptype: "p", V0: "888", V1: "/gaia/system/forward-tokens", V2: "POST"},
{Ptype: "p", V0: "888", V1: "/gaia/system/forward-tokens/:id", V2: "DELETE"},
{Ptype: "p", V0: "8881", V1: "/gaia/system/forward-tokens", V2: "GET"},
{Ptype: "p", V0: "8881", V1: "/gaia/system/forward-tokens", V2: "POST"},
{Ptype: "p", V0: "8881", V1: "/gaia/system/forward-tokens/:id", V2: "DELETE"},
// Extend Stop: 转发集成
}
if err := db.Create(&entities).Error; err != nil {
return ctx, errors.Wrap(err, "Casbin 表 ("+i.InitializerName()+") 数据初始化失败!")
+65 -1
View File
@@ -41,7 +41,7 @@ export const getSystemOAuth2 = () => {
}
// @Tags systrm
// @Summary 修改OAuth2集成配置
// @Summary 修改 OAuth2 集成配置
// @Security ApiKeyAuth
// @Produce application/json
// @Success 200 {string} string "{"success":true,"data":{},"msg":"返回成功"}"
@@ -53,3 +53,67 @@ export const setSystemOAuth2 = (data) => {
data,
})
}
// @Tags systrm
// @Summary 获取转发 Token 列表
// @Security ApiKeyAuth
// @Router /gaia/system/forward-tokens [get]
export const getForwardTokens = () => {
return service({
url: '/gaia/system/forward-tokens',
method: 'get'
})
}
// @Tags systrm
// @Summary 新增转发 Token
// @Security ApiKeyAuth
// @Router /gaia/system/forward-tokens [post]
export const createForwardToken = (data) => {
return service({
url: '/gaia/system/forward-tokens',
method: 'post',
data,
})
}
// @Tags systrm
// @Summary 删除转发 Token
// @Security ApiKeyAuth
// @Router /gaia/system/forward-tokens/:seq [delete]
export const deleteForwardToken = (seq, password) => {
return service({
url: `/gaia/system/forward-tokens/${seq}`,
method: 'delete',
data: { password },
})
}
// @Tags systrm
// @Summary 测试第三方邮箱 API 配置
// @Security ApiKeyAuth
// @Router /gaia/system/dingtalk/test-email-config [post]
export const testEmailApiConfig = (data) => {
return service({
url: '/gaia/system/dingtalk/test-email-config',
method: 'post',
data,
})
}
// 测试连接:获取钉钉授权 URL(打开后扫码完成即视为连接成功)
export const getDingTalkTestAuthUrl = () => {
return service({
url: '/gaia/system/dingtalk/test-auth-url',
method: 'get',
})
}
// 测试连接回调:仅用 code 验证,不登录
export const dingtalkTestCallback = (data) => {
return service({
url: '/gaia/system/dingtalk/test-callback',
method: 'post',
data: { code: data.code },
})
}
+14
View File
@@ -11,6 +11,7 @@ import { useUserStore } from '@/pinia/modules/user'
import { useRouterStore } from '@/pinia/modules/router'
import router from '@/router'
import { gaiaOAuth2Login, dingtalkLogin } from '@/api/user_extend'
import { dingtalkTestCallback } from '@/api/gaia/system'
defineOptions({
name: 'LoginCallback',
@@ -70,6 +71,19 @@ const callback = async () => {
const redirectUri = sessionStorage.getItem('gaia_login_redirect_uri') || ''
const state = sessionStorage.getItem('gaia_login_state') || getQueryParam('state') || ''
// code token postMessage
if (provider === 'dingtalk' && state === 'dingtalk_test') {
try {
const res = await dingtalkTestCallback({ code })
const payload = { type: 'dingtalk_test_result', success: res?.code === 0, message: res?.msg }
if (window.opener) window.opener.postMessage(payload, '*')
} catch (e) {
if (window.opener) window.opener.postMessage({ type: 'dingtalk_test_result', success: false, message: e?.message || '验证失败' }, '*')
}
window.close()
return
}
try {
if (provider === 'dingtalk') {
if (!hasCode) {
File diff suppressed because it is too large Load Diff
@@ -146,7 +146,13 @@ const providerList = ref([])
const providerDisplayNames = {
openai: 'OpenAI',
tongyi: '千问(通义)',
google: 'Google Gemini'
google: 'Google Gemini',
anthropic: 'Anthropic',
aws: 'AWS Bedrock',
azure: 'Azure OpenAI',
zhipuai: '智谱 AI',
minimax: 'MiniMax',
deepseek: 'DeepSeek'
}
const getProviderDisplayName = (providerName) => {
@@ -395,7 +401,7 @@ onMounted(() => {
:deep(.el-select-dropdown) {
.el-select-dropdown__item {
padding: 8px 16px;
&.is-selected {
font-weight: 600;
color: #409eff;
+10 -1
View File
@@ -6,7 +6,6 @@ from pydantic_settings import BaseSettings
class ExtendInfo(BaseSettings):
OAUTH2_CLIENT_ID: Optional[str] = Field(
description="OA client id for OAuth",
default=None,
@@ -74,6 +73,16 @@ class ExtendInfo(BaseSettings):
)
# Extend: 记忆上下文功能
# Extend: 控制台首次安装完成后,向内网 Admin 服务触发 /init/initdb(与 DB_* 同源,无需额外密钥文件)
ADMIN_INITDB_ENABLED: bool = Field(
description="Dify 安装向导完成后是否请求 Admin 初始化业务库表",
default=False,
)
ADMIN_INITDB_URL: str = Field(
description="Admin InitDB 接口完整 URLDocker 默认可为 http://admin-server:8888/admin/api/init/initdb",
default="http://admin-server:8888/admin/api/init/initdb",
)
class ExtendConfig(ExtendInfo):
pass
+4 -3
View File
@@ -170,7 +170,8 @@ class BaseApiKeyListResource(Resource):
@marshal_with(api_key_fields)
def put(self, resource_id):
resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
current_user, current_tenant_id = current_account_with_tenant()
_get_resource(resource_id, current_tenant_id, self.resource_model)
if not current_user.is_admin_or_owner:
raise Forbidden()
@@ -198,7 +199,7 @@ class BaseApiKeyListResource(Resource):
)
if key is None:
flask_restful.abort(404, message="API密钥未找到")
flask_restx.abort(404, message="API密钥未找到")
data = request.get_json()
@@ -227,7 +228,7 @@ class BaseApiKeyListResource(Resource):
merged_data = {**api_token.__dict__, **api_token_money_extend.__dict__}
return merged_data, 200
else:
flask_restful.abort(500, message="更新API密钥时发生错误")
flask_restx.abort(500, message="更新API密钥时发生错误")
# --------------------- 二开部分End - 密钥额度限制 ---------------------
+7 -35
View File
@@ -21,6 +21,7 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class StatisticTimeRangeQuery(BaseModel):
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
account: str | None = Field(default=None, description="Account ID filter")
@field_validator("start", "end", mode="before")
@classmethod
@@ -114,14 +115,8 @@ class DailyConversationStatistic(Resource):
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
converted_created_at = convert_datetime_to_date("created_at")
sql_query = f"""SELECT
{converted_created_at} AS date,
COUNT(DISTINCT conversation_id) AS conversation_count
FROM
messages
WHERE
app_id = :app_id
AND invoke_from != :invoke_from"""
sql_query = f"""SELECT {converted_created_at} AS date, COUNT(DISTINCT conversation_id) AS conversation_count
FROM messages WHERE app_id = :app_id AND invoke_from != :invoke_from"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
assert account.timezone is not None
@@ -131,33 +126,10 @@ WHERE
abort(400, description=str(e))
if args.account is not None and args.account:
sql_query += ""
# stmt = (
# select(
# func.date(
# func.date_trunc("day", text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
# ).label("date"),
# func.count(distinct(Message.conversation_id)).label("conversation_count")
# )
# .select_from(Message)
# .where(
# Message.app_id == app_model.id,
# or_(
# Message.from_account_id == account.id,
# Message.from_end_user_id.in_(
# select(EndUser.id)
# .where(EndUser.external_user_id == account.id)
# .distinct()
# )
# )
# )
# .group_by(
# func.date(
# func.date_trunc("day", text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
# )
# )
# .params(tz=account.timezone) # 绑定参数
# )
sql_query = f"""SELECT {converted_created_at} AS date, COUNT(DISTINCT conversation_id) AS conversation_count
FROM messages WHERE app_id = :app_id AND invoke_from != :invoke_from AND (from_account_id = :user_id OR
from_end_user_id IN (SELECT DISTINCT(id) FROM end_users WHERE external_user_id = :user_id))"""
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER, "user_id": account.id}
if start_datetime_utc:
sql_query += " AND created_at >= :start"
+3
View File
@@ -9,6 +9,7 @@ from libs.helper import EmailStr, extract_remote_ip
from libs.password import valid_password
from models.model import DifySetup, db
from services.account_service import RegisterService, TenantService
from services.admin_initdb_service import trigger_admin_initdb_if_configured
from .error import AlreadySetupError, NotInitValidateError
from .init_validate import get_init_validate_status
@@ -82,6 +83,8 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
language=payload.language,
)
trigger_admin_initdb_if_configured(admin_password=payload.password)
return SetupResponse(result="success")
+4 -1
View File
@@ -48,7 +48,10 @@ def load_user_from_request(request_from_flask_login):
account.current_tenant = tenant
return account
if request.blueprint in {"console", "inner_api"}:
# extend: start fastopenapi路由认证,/console/api/ 路径无blueprint时按console处理
is_console_path = request.blueprint is None and request.path.startswith("/console/api/")
if request.blueprint in {"console", "inner_api"} or is_console_path:
# extend: stop fastopenapi路由认证,/console/api/ 路径无blueprint时按console处理
if not auth_token:
raise Unauthorized("Invalid Authorization token.")
decoded = PassportService().verify(auth_token)
+14
View File
@@ -35,6 +35,7 @@ from models.account import (
)
from models.account_money_extend import AccountMoneyExtend
from models.model import DifySetup
from services.account_service_extend import TenantExtendService
from services.billing_service import BillingService
from services.errors.account import (
AccountAlreadyInTenantError,
@@ -1313,6 +1314,19 @@ class RegisterService:
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
# extend begin: admin 初始化不同步问题 - 在 dify 初始化完成后,自动调用 admin 初始化
# 将 setup 用户加入到第一个工作区(admin 租户),确保后续用户信息同步时权限正确
tenant_extend_service = TenantExtendService
super_admin_id = tenant_extend_service.get_super_admin_id().id
super_admin_tenant_id = tenant_extend_service.get_super_admin_tenant_id().id
if super_admin_id and super_admin_tenant_id:
is_create = TenantExtendService.create_default_tenant_member_if_not_exist(
super_admin_tenant_id, account.id
)
if is_create:
TenantService.switch_tenant(account, super_admin_tenant_id)
# extend end: admin 初始化不同步问题
dify_setup = DifySetup(version=dify_config.project.version)
db.session.add(dify_setup)
db.session.commit()
+105
View File
@@ -0,0 +1,105 @@
"""
Trigger gin-vue-admin InitDB after Dify console setup.
Uses the same DB_* settings as the API container and the admin password from the setup form.
Does not read or write local config files; outbound call only if ADMIN_INITDB_ENABLED is true.
"""
from __future__ import annotations
import logging
from typing import Any, Literal
import httpx
from configs import dify_config
logger = logging.getLogger(__name__)
_ADMIN_ALREADY_INIT_MSG = "已存在数据库配置"
_DEFAULT_TIMEOUT = httpx.Timeout(120.0, connect=10.0)
def _admin_db_type(db_type: str) -> Literal["pgsql", "mysql"]:
if db_type == "postgresql":
return "pgsql"
return "mysql"
def _build_payload(admin_password: str) -> dict[str, Any]:
db_type = _admin_db_type(dify_config.DB_TYPE)
return {
"adminPassword": admin_password,
"dbType": db_type,
"host": dify_config.DB_HOST,
"port": str(dify_config.DB_PORT),
"userName": dify_config.DB_USERNAME,
"password": dify_config.DB_PASSWORD,
"dbName": dify_config.DB_DATABASE,
"dbPath": "",
}
def _is_acceptable_admin_response(body: dict[str, Any]) -> bool:
code = body.get("code")
msg = body.get("msg") or ""
if code == 0:
return True
if msg == _ADMIN_ALREADY_INIT_MSG:
return True
return False
def trigger_admin_initdb_if_configured(*, admin_password: str) -> None:
"""
Best-effort POST to Admin InitDB. Logs warnings on failure; never raises.
"""
if not dify_config.ADMIN_INITDB_ENABLED:
return
if dify_config.DB_TYPE not in ("postgresql", "mysql", "oceanbase", "seekdb"):
logger.warning(
"skip admin initdb: unsupported DB_TYPE for admin bridge: %s",
dify_config.DB_TYPE,
)
return
url = (dify_config.ADMIN_INITDB_URL or "").strip()
if not url:
logger.warning("ADMIN_INITDB_ENABLED is true but ADMIN_INITDB_URL is empty, skip admin initdb")
return
payload = _build_payload(admin_password)
try:
with httpx.Client(timeout=_DEFAULT_TIMEOUT, follow_redirects=True) as client:
response = client.post(
url,
json=payload,
headers={"Content-Type": "application/json", "Accept": "application/json"},
)
except httpx.RequestError as exc:
logger.warning("admin initdb request failed: %s", exc)
return
if response.status_code != 200:
logger.warning(
"admin initdb HTTP %s: %s",
response.status_code,
(response.text or "")[:500],
)
return
try:
body = response.json()
except ValueError:
logger.warning("admin initdb returned non-JSON body: %s", (response.text or "")[:500])
return
if not isinstance(body, dict):
logger.warning("admin initdb returned unexpected JSON: %s", body)
return
if _is_acceptable_admin_response(body):
logger.info("admin initdb finished: %s", body.get("msg"))
return
logger.warning("admin initdb rejected: %s", body)
+4 -1
View File
@@ -693,7 +693,7 @@ ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
ALIBABACLOUD_MYSQL_HNSW_M=6
# relyt configurations, only available when VECTOR_STORE is `relyt`
RELYT_HOST=db
RELYT_HOST=db_postgres
RELYT_PORT=5432
RELYT_USER=postgres
RELYT_PASSWORD=difyai123456
@@ -1283,8 +1283,11 @@ NGINX_SSL_PROTOCOLS=TLSv1.2 TLSv1.3
NGINX_WORKER_PROCESSES=auto
NGINX_CLIENT_MAX_BODY_SIZE=100M
NGINX_KEEPALIVE_TIMEOUT=65
NGINX_CLIENT_HEADER_TIMEOUT=1800s
NGINX_CLIENT_BODY_TIMEOUT=1800s
# Proxy settings
NGINX_PROXY_CONNECT_TIMEOUT=3600s
NGINX_PROXY_READ_TIMEOUT=3600s
NGINX_PROXY_SEND_TIMEOUT=3600s
+3
View File
@@ -419,6 +419,9 @@ services:
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
+26 -12
View File
@@ -557,6 +557,9 @@ x-shared-env: &shared-api-worker-env
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
@@ -717,7 +720,7 @@ services:
# API service
api:
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
restart: always
environment:
# Use the shared environment variables.
@@ -734,6 +737,9 @@ services:
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
FULL_CODE_EXECUTION_ENDPOINT: ${FULL_CODE_EXECUTION_ENDPOINT:-http://sandbox-full:8194}
ALLOW_REGISTER: ${ALLOW_REGISTER:-True}
# 安装向导完成后由内网调用 Admin InitDB(库连接与 DB_* 一致,无需用户改本地配置文件)
ADMIN_INITDB_ENABLED: ${ADMIN_INITDB_ENABLED:-true}
ADMIN_INITDB_URL: ${ADMIN_INITDB_URL:-http://admin-server:8888/init/initdb}
depends_on:
init_permissions:
condition: service_completed_successfully
@@ -761,7 +767,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
restart: always
environment:
# Use the shared environment variables.
@@ -800,7 +806,7 @@ services:
# worker-gaia service
# The Celery worker-gaia for processing the queue.
worker-gaia:
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
restart: always
environment:
# Use the shared environment variables.
@@ -826,7 +832,7 @@ services:
# worker-dataset service
# The Celery worker-dataset for processing the queue.
worker-dataset:
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
restart: always
environment:
# Use the shared environment variables.
@@ -852,7 +858,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
restart: always
environment:
# Use the shared environment variables.
@@ -882,7 +888,7 @@ services:
# Frontend web application.
web:
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-web:1.12.1
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-web:1.12.1.fix.3.2
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@@ -943,8 +949,6 @@ services:
interval: 1s
timeout: 3s
retries: 60
ports:
- 5432:5432
# The mysql database.
db_mysql:
@@ -993,8 +997,6 @@ services:
"CMD-SHELL",
"redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG",
]
ports:
- 6379:6379
# The DifySandbox
sandbox:
@@ -1171,6 +1173,9 @@ services:
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
@@ -1667,7 +1672,7 @@ services:
# Extend - admin-web
admin-web:
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-web:1.12.1
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-web:1.12.1.fix.3.2
restart: always
ports:
- '8081:8081'
@@ -1679,13 +1684,22 @@ services:
# Extend - admin-server
admin-server:
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-server:1.12.1
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-server:1.12.1.fix.3.3
restart: always
environment:
# Use the shared environment variables.
<<: *shared-api-worker-env
# JWT signing key must match API's SECRET_KEY for token compatibility
JWT_SIGNING_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U}
SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U}
EMAIL_DOMAIN: ${EMAIL_DOMAIN:-}
# 与 Dify 使用同一套数据库配置,避免密码不一致导致认证失败
DB_TYPE: ${DB_TYPE:-postgresql}
DB_HOST: ${DB_HOST:-db_postgres}
DB_PORT: ${DB_PORT:-5432}
DB_USERNAME: ${DB_USERNAME:-postgres}
DB_PASSWORD: ${DB_PASSWORD:-difyai123456}
DB_DATABASE: ${DB_DATABASE:-dify}
ports:
- '8888:8888'
depends_on:
-3
View File
@@ -2,9 +2,6 @@ services:
# The postgres database.
db_postgres:
image: postgres:15-alpine
profiles:
- ""
- postgresql
restart: always
env_file:
- ./middleware.env
+6
View File
@@ -554,6 +554,9 @@ x-shared-env: &shared-api-worker-env
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
@@ -1105,6 +1108,9 @@ services:
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
@@ -3,6 +3,9 @@
server {
listen ${NGINX_PORT};
server_name ${NGINX_SERVER_NAME};
keepalive_timeout 1800s;
client_header_timeout ${NGINX_CLIENT_HEADER_TIMEOUT};
client_body_timeout ${NGINX_CLIENT_BODY_TIMEOUT};
# 管理中心反向代理配置
location = /admin {
return 301 /admin/;
+1
View File
@@ -7,5 +7,6 @@ proxy_set_header X-Forwarded-Port $server_port;
proxy_http_version 1.1;
proxy_set_header Connection "";
proxy_buffering off;
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT};
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT};
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT};
@@ -153,4 +153,43 @@ INSERT INTO casbin_rule (ptype, v0, v1, v2) VALUES
INSERT INTO sys_authority_menus (sys_authority_authority_id, sys_base_menu_id) VALUES (888, 42);
-- --------------- 9. API sys_apis (转发集成:转发 Token 管理 3 条) 2026-03-09 18:08:33 ---------------
-- 请按当前库最大 id 调整起始 id,避免冲突。例如 MAX(id)=269 则从 270 起
INSERT INTO sys_apis (id, created_at, updated_at, deleted_at, path, description, api_group, method) VALUES
(270, NOW(), NOW(), NULL, '/gaia/system/forward-tokens', '获取转发 Token 列表', '转发集成', 'GET'),
(271, NOW(), NOW(), NULL, '/gaia/system/forward-tokens', '新增转发 Token', '转发集成', 'POST'),
(272, NOW(), NOW(), NULL, '/gaia/system/forward-tokens/:id', '删除转发 Token', '转发集成', 'DELETE');
-- --------------- 10. Casbin 规则 casbin_rule (转发集成 888/8881) ---------------
INSERT INTO casbin_rule (ptype, v0, v1, v2) VALUES
('p', '888', '/gaia/system/forward-tokens', 'GET'),
('p', '888', '/gaia/system/forward-tokens', 'POST'),
('p', '888', '/gaia/system/forward-tokens/:id', 'DELETE'),
('p', '8881', '/gaia/system/forward-tokens', 'GET'),
('p', '8881', '/gaia/system/forward-tokens', 'POST'),
('p', '8881', '/gaia/system/forward-tokens/:id', 'DELETE');
-- --------------- 11. API sys_apis (钉钉邮箱配置测试 1 条) ---------------
-- 请按当前库最大 id 调整起始 id,避免冲突。例如 MAX(id)=272 则从 273 起
INSERT INTO sys_apis (id, created_at, updated_at, deleted_at, path, description, api_group, method) VALUES
(273, NOW(), NOW(), NULL, '/gaia/system/dingtalk/test-email-config', '测试钉钉邮箱API配置', '应用集成配置', 'POST');
-- --------------- 12. Casbin 规则 casbin_rule (钉钉邮箱配置测试 888) ---------------
INSERT INTO casbin_rule (ptype, v0, v1, v2) VALUES
('p', '888', '/gaia/system/dingtalk/test-email-config', 'POST');
-- --------------- 13. API sys_apis (钉钉测试连接:授权 URL + 回调验证 2 条) ---------------
-- 请按当前库最大 id 调整起始 id,避免冲突。例如 MAX(id)=273 则从 274 起
INSERT INTO sys_apis (id, created_at, updated_at, deleted_at, path, description, api_group, method) VALUES
(274, NOW(), NOW(), NULL, '/gaia/system/dingtalk/test-auth-url', '测试连接-获取钉钉授权URL', '应用集成配置', 'GET'),
(275, NOW(), NOW(), NULL, '/gaia/system/dingtalk/test-callback', '测试连接-钉钉回调验证', '应用集成配置', 'POST');
-- --------------- 14. Casbin 规则 casbin_rule (钉钉测试连接 888) ---------------
INSERT INTO casbin_rule (ptype, v0, v1, v2) VALUES
('p', '888', '/gaia/system/dingtalk/test-auth-url', 'GET'),
('p', '888', '/gaia/system/dingtalk/test-callback', 'POST');
@@ -1,6 +1,6 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import * as React from 'react'
import { useTranslation } from 'react-i18next'
import ParamItem from '.'
@@ -27,7 +27,7 @@ const DayLimitItemExtend: FC<Props> = ({
}) => {
const { t } = useTranslation()
const handleParamChange = (key: string, value: number) => {
let notOutRangeValue = parseFloat(value.toFixed(2))
let notOutRangeValue = Number.parseFloat(value.toFixed(2))
notOutRangeValue = Math.max(VALUE_LIMIT.min, notOutRangeValue)
notOutRangeValue = Math.min(VALUE_LIMIT.max, notOutRangeValue)
onChange(key, notOutRangeValue)
@@ -36,8 +36,8 @@ const DayLimitItemExtend: FC<Props> = ({
<ParamItem
className={className}
id={key}
name={t('extend.apiKeyModal.dayLimitItemName')}
tip={t('extend.apiKeyModal.noLimitTips') as string}
name={t('apiKeyModal.dayLimitItemName', { ns: 'extend' })}
tip={t('apiKeyModal.noLimitTips', { ns: 'extend' }) as string}
{...VALUE_LIMIT}
value={value}
enable={enable}
@@ -27,7 +27,7 @@ const MonthLimitItemExtend: FC<Props> = ({
}) => {
const { t } = useTranslation()
const handleParamChange = (key: string, value: number) => {
let notOutRangeValue = parseFloat(value.toFixed(2))
let notOutRangeValue = Number.parseFloat(value.toFixed(2))
notOutRangeValue = Math.max(VALUE_LIMIT.min, notOutRangeValue)
notOutRangeValue = Math.min(VALUE_LIMIT.max, notOutRangeValue)
onChange(key, notOutRangeValue)
@@ -36,8 +36,8 @@ const MonthLimitItemExtend: FC<Props> = ({
<ParamItem
className={className}
id={key}
name={t('extend.apiKeyModal.monthLimitItemName')}
tip={t('extend.apiKeyModal.noLimitTips') as string}
name={t('apiKeyModal.monthLimitItemName', { ns: 'extend' })}
tip={t('apiKeyModal.noLimitTips', { ns: 'extend' }) as string}
{...VALUE_LIMIT}
value={value}
enable={enable}
@@ -1,5 +1,5 @@
'use client'
import type { CreateApiKeyResponse } from '@/models/app'
import type { ApiKeyItemResponse, ApikeyItemResponseWithQuotaLimitExtend, CreateApiKeyResponse } from '@/models/app'
import { PlusIcon, XMarkIcon } from '@heroicons/react/20/solid'
import { RiDeleteBinLine } from '@remixicon/react'
import {
@@ -17,14 +17,18 @@ import useTimestamp from '@/hooks/use-timestamp'
import {
createApikey as createAppApikey,
delApikey as delAppApikey,
editApikey, // 二开部分 - 密钥额度
} from '@/service/apps'
import {
createApikey as createDatasetApikey,
delApikey as delDatasetApikey,
} from '@/service/datasets'
// 二开部分End - 密钥额度
import { useDatasetApiKeys, useInvalidateDatasetApiKeys } from '@/service/knowledge/use-dataset'
import { useAppApiKeys, useInvalidateAppApiKeys } from '@/service/use-apps'
import SecretKeyGenerateModal from './secret-key-generate'
// 二开部分Start - 密钥额度
import SecretKeyQuotaSetExtendModal from './secret-key-quota-set-modal-extend'
import s from './style.module.css'
type ISecretKeyModalProps = {
@@ -50,6 +54,65 @@ const SecretKeyModal = ({
const { data: datasetApiKeys, isLoading: isDatasetApiKeysLoading } = useDatasetApiKeys({ enabled: !appId && isShow })
const apiKeysList = appId ? appApiKeys : datasetApiKeys
const isApiKeysLoading = appId ? isAppApiKeysLoading : isDatasetApiKeysLoading
// ---------------------- 二开部分Begin - 密钥额度 ----------------------
const [isVisibleExtend, setVisibleExtend] = useState(false)
const [keyItem, setKeyItem] = useState<ApikeyItemResponseWithQuotaLimitExtend>({
created_at: '',
id: '',
last_used_at: '',
token: '',
description: '',
day_limit_quota: -1,
month_limit_quota: -1,
})
// 打开新增密钥额度编辑框
const openSecretKeyQuotaSetModalExtend = async () => {
setVisibleExtend(true)
setKeyItem({
created_at: '',
id: '',
last_used_at: '',
token: '',
description: '',
day_limit_quota: -1,
month_limit_quota: -1,
})
}
// 打开编辑密钥额度编辑框
const openSecretKeyQuotaEditModalExtend = async (api: ApiKeyItemResponse) => {
setVisibleExtend(true)
setKeyItem({
created_at: api.created_at,
id: api.id,
last_used_at: api.last_used_at,
token: api.token,
description: api.description,
day_limit_quota: api.day_limit_quota,
month_limit_quota: api.month_limit_quota,
})
}
// 设置密钥额度数据
const handleSetKeyDataSetQuotas = (newKeyItems: ApikeyItemResponseWithQuotaLimitExtend) => {
setKeyItem(newKeyItems)
}
// ---------------------- 二开部分End - 密钥额度 ----------------------
// 二开部分 Begin - 密钥额度限制编辑
const onEdit = async () => {
const params = {
url: `/apps/${appId}/api-keys`,
body: {
id: keyItem.id,
description: keyItem.description,
day_limit_quota: keyItem.day_limit_quota,
month_limit_quota: keyItem.month_limit_quota,
},
}
const res = await editApikey(params)
setVisibleExtend(false)
setNewKey(res)
}
// 二开部分 Begin - 密钥额度限制编辑
const [delKeyID, setDelKeyId] = useState('')
@@ -101,6 +164,12 @@ const SecretKeyModal = ({
<div className="w-64 shrink-0 px-3">{t('apiKeyModal.secretKey', { ns: 'appApi' })}</div>
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.created', { ns: 'appApi' })}</div>
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.lastUsed', { ns: 'appApi' })}</div>
{/* ---------------------- 二开部分Begin - 密钥额度限制 ---------------------- */}
<div className="w-[100px] shrink-0 px-3">{t('apiKeyModal.descriptionPlaceholder', { ns: 'extend' })}</div>
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.dayLimit', { ns: 'extend' })}</div>
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.monthLimit', { ns: 'extend' })}</div>
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.accumulatedLimit', { ns: 'extend' })}</div>
{/* ---------------------- 二开部分End - 密钥额度限制 ---------------------- */}
<div className="grow px-3"></div>
</div>
<div className="grow overflow-auto">
@@ -109,6 +178,27 @@ const SecretKeyModal = ({
<div className="w-64 shrink-0 truncate px-3 font-mono">{generateToken(api.token)}</div>
<div className="w-[200px] shrink-0 truncate px-3">{formatTime(Number(api.created_at), t('dateTimeFormat', { ns: 'appLog' }) as string)}</div>
<div className="w-[200px] shrink-0 truncate px-3">{api.last_used_at ? formatTime(Number(api.last_used_at), t('dateTimeFormat', { ns: 'appLog' }) as string) : t('never', { ns: 'appApi' })}</div>
{/* ---------------------- 二开部分Begin - 密钥额度限制 ---------------------- */}
<div className="w-[100px] shrink-0 truncate px-3">{api.description}</div>
<div className="w-[200px] shrink-0 truncate px-3">
$
{api.day_used_quota}
{' '}
/
{api.day_limit_quota === -1 ? t('apiKeyModal.noLimit', { ns: 'extend' }) : `$ ${api.day_limit_quota}`}
</div>
<div className="w-[200px] shrink-0 truncate px-3">
$
{api.month_used_quota}
{' '}
/
{api.month_limit_quota === -1 ? t('apiKeyModal.noLimit', { ns: 'extend' }) : `$ ${api.month_limit_quota}`}
</div>
<div className="w-[200px] shrink-0 truncate px-3">
$
{api.accumulated_quota}
</div>
{/* ---------------------- 二开部分End - 密钥额度限制 ---------------------- */}
<div className="flex grow space-x-2 px-3">
<CopyFeedback content={api.token} />
{isCurrentWorkspaceManager && (
@@ -121,6 +211,17 @@ const SecretKeyModal = ({
<RiDeleteBinLine className="h-4 w-4" />
</ActionButton>
)}
{/* // 二开部分 End - 密钥额度限制编辑 */}
{isCurrentWorkspaceManager && (
<div
className={`flex h-6 w-6 shrink-0 cursor-pointer items-center justify-center rounded-lg ${s.editIcon}`}
onClick={() => {
openSecretKeyQuotaEditModalExtend(api).then()
}}
>
</div>
)}
{/* // 二开部分 Begin - 密钥额度限制编辑 */}
</div>
</div>
))}
@@ -147,6 +248,10 @@ const SecretKeyModal = ({
}}
/>
)}
{/* ----------------------二开部分Begin - 密钥额度限制---------------------- */}
<SecretKeyQuotaSetExtendModal className="shrink-0" isShow={isVisibleExtend} onClose={() => setVisibleExtend(false)} newKey={keyItem} onChange={handleSetKeyDataSetQuotas} onCreate={keyItem.id === '' ? onCreate : onEdit} />
{/* ----------------------二开部分End - 密钥额度限制---------------------- */}
</Modal>
)
}
@@ -1,12 +1,12 @@
'use client'
import { useTranslation } from 'react-i18next'
import type { ApikeyItemResponseWithQuotaLimitExtend } from '@/models/app'
import { XMarkIcon } from '@heroicons/react/20/solid'
import s from './style.module.css'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Modal from '@/app/components/base/modal'
import type { ApikeyItemResponseWithQuotaLimitExtend } from '@/models/app'
import DayLimitItemExtend from '@/app/components/base/param-item/day-limit-item-extend'
import MonthLimitItemExtend from '@/app/components/base/param-item/month-limit-item-extend'
import s from './style.module.css'
type ISecretKeyGenerateModalProps = {
isShow: boolean
@@ -53,35 +53,41 @@ const SecretKeyQuotaSetExtendModal = ({
}
return (
<Modal isShow={isShow} onClose={onClose} title={(newKey?.id ? '编辑' : '创建')+ `${t('appApi.apiKeyModal.apiSecretKey')}`}
className={`px-8 ${className}`}>
<XMarkIcon className={`w-6 h-6 absolute cursor-pointer text-gray-500 ${s.close}`} onClick={onClose}/>
<p className='mt-1 text-[13px] text-gray-500 font-normal leading-5'>{t('extend.apiKeyModal.apiSecretKeyTips')}</p>
<div className='my-4'>
<Modal
isShow={isShow}
onClose={onClose}
title={`${newKey?.id ? '编辑' : '创建'}${t('apiKeyModal.apiSecretKey', { ns: 'appApi' })}`}
className={`px-8 ${className}`}
>
<XMarkIcon className={`absolute h-6 w-6 cursor-pointer text-gray-500 ${s.close}`} onClick={onClose} />
<p className="mt-1 text-[13px] font-normal leading-5 text-gray-500">
{t('apiKeyModal.apiSecretKeyTips', { ns: 'extend' })}
</p>
<div className="my-4">
<input
value={newKey?.description ?? ''}
onChange={e => handleParamChangeDesc(e.target.value)}
placeholder={t('extend.apiKeyModal.descriptionPlaceholder') || '密钥用途'}
className='grow h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg border border-transparent outline-none appearance-none caret-primary-600 placeholder:text-gray-400 hover:bg-gray-50 hover:border hover:border-gray-300 focus:bg-gray-50 focus:border focus:border-gray-300 focus:shadow-xs'
placeholder={t('apiKeyModal.descriptionPlaceholder', { ns: 'extend' }) || '密钥用途'}
className="h-10 grow appearance-none rounded-lg border border-transparent bg-gray-100 px-3 text-sm font-normal caret-primary-600 outline-none placeholder:text-gray-400 hover:border hover:border-gray-300 hover:bg-gray-50 focus:border focus:border-gray-300 focus:bg-gray-50 focus:shadow-xs"
/>
</div>
<div className='my-4'>
<div className="my-4">
<DayLimitItemExtend
value={newKey?.day_limit_quota ?? -1}
onChange={handleParamChange}
enable={true}
/>
</div>
<div className='my-4'>
<div className="my-4">
<MonthLimitItemExtend
value={newKey?.month_limit_quota ?? -1}
onChange={handleParamChange}
enable={true}
/>
</div>
<div className='flex justify-end my-4'>
<Button variant='primary' className={`flex-shrink-0 ${s.w64}`} onClick={onCreate}>
{newKey?.id ? t('common.operation.save') : t('common.operation.create')}
<div className="my-4 flex justify-end">
<Button variant="primary" className={`shrink-0 ${s.w64}`} onClick={onCreate}>
{newKey?.id ? t('operation.save', { ns: 'common' }) : t('operation.create', { ns: 'common' })}
</Button>
</div>
@@ -55,3 +55,17 @@
.copyIcon.copied {
background-image: url(./assets/copied.svg);
}
/* 二开部分Begin - 密钥额度限制 */
.editIcon {
background-image: url(./assets/edit.svg);
background-position: center;
background-repeat: no-repeat;
}
.editIcon:hover {
background-image: url(./assets/edit-hover.svg);
background-position: center;
background-repeat: no-repeat;
}
/* 二开部分End - 密钥额度限制 */
@@ -180,12 +180,19 @@ const TextGeneration: FC<IMainProps> = ({
const batchJobsLimit = 5 // 每页5个任务
const [totalBatchJobs, setTotalBatchJobs] = useState(0)
const [isLoadingBatchJobs, setIsLoadingBatchJobs] = useState(false)
const lastRefreshTimeRef = useRef(0) // 记录上次刷新时间,避免频繁刷新
// 从后端获取批量工作流列表
const loadBatchWorkflows = useCallback(async () => {
const loadBatchWorkflows = useCallback(async (force = false) => {
if (!appId || currentTab !== 'batch')
return
// 防止过于频繁的刷新(至少间隔 1 秒)
const now = Date.now()
if (!force && now - lastRefreshTimeRef.current < 1000)
return
lastRefreshTimeRef.current = now
setIsLoadingBatchJobs(true)
try {
const result = await fetchBatchWorkflowListApi(installedAppInfo?.id, currentPage, batchJobsLimit)
@@ -218,25 +225,9 @@ const TextGeneration: FC<IMainProps> = ({
loadBatchWorkflows()
}, [loadBatchWorkflows])
// 自动刷新批量工作流列表(每3秒)
useEffect(() => {
if (currentTab !== 'batch' || batchJobs.length === 0)
return
// 注意:不再需要自动刷新逻辑,因为每个批量任务现在自己管理进度刷新
// 每个 BatchProgress 组件会独立轮询自己的进度(每 3 秒)
// 检查是否有进行中的任务
const hasActiveJobs = batchJobs.some(job =>
job.status === 'pending' || job.status === 'processing',
)
if (!hasActiveJobs)
return
const refreshInterval = setInterval(() => {
loadBatchWorkflows()
}, 3000) // 每3秒刷新一次
return () => clearInterval(refreshInterval)
}, [currentTab, batchJobs, loadBatchWorkflows])
// 计算分页数据 - 现在数据已经是从后端分页获取的,不需要再切片
const paginatedBatchJobs = batchJobs
@@ -472,6 +463,28 @@ const TextGeneration: FC<IMainProps> = ({
loadBatchWorkflows()
console.log('批量任务重试成功,已刷新列表')
}
// 处理单个任务进度更新(只更新列表中的对应项,不刷新整个列表)
const handleJobUpdate = useCallback((jobId: string, updatedData: { status: string, processedRows: number, error?: string }) => {
setBatchJobs(prevJobs => {
// 检查是否真的有变化
const job = prevJobs.find(j => j.id === jobId)
if (!job)
return prevJobs
// 如果没有变化,不更新
if (job.status === updatedData.status && job.processedRows === updatedData.processedRows && job.error === updatedData.error)
return prevJobs
// 有变化才更新
return prevJobs.map(job =>
job.id === jobId
? { ...job, ...updatedData }
: job
)
})
}, [])
// Extend: Stop Batch import
const handleCompleted = (completionRes: string, taskId?: number, isSuccess?: boolean) => {
@@ -662,6 +675,7 @@ const TextGeneration: FC<IMainProps> = ({
jobData={job}
onDownload={() => handleBatchDownload(job.id)}
onRetrySuccess={handleRetrySuccess}
onJobUpdate={(updatedData) => handleJobUpdate(job.id, updatedData)}
/>
))
) : (
@@ -11,7 +11,7 @@ import {
RiRefreshLine,
RiStopLine,
} from '@remixicon/react'
import { resumeBatchApi, retryFailedTasksApi, stopBatchApi } from '@/service/web-extend' // extend: 批量运行工单
import { fetchProgressApi, resumeBatchApi, retryFailedTasksApi, stopBatchApi } from '@/service/web-extend' // extend: 批量运行工单
import type { BatchStatus } from '@/utils/batch-progress-manager' // extend: 批量运行工单
import ActionButton from '@/app/components/base/action-button'
@@ -32,6 +32,7 @@ export type BatchProgressProps = {
}
onDownload: () => void
onRetrySuccess?: () => void
onJobUpdate?: (jobData: { status: string, processedRows: number, error?: string }) => void // 新增:任务更新回调
}
const BatchProgress: FC<BatchProgressProps> = ({
@@ -41,10 +42,62 @@ const BatchProgress: FC<BatchProgressProps> = ({
jobData,
onDownload,
onRetrySuccess,
onJobUpdate,
}) => {
const { t } = useTranslation()
const [isLoading, setIsLoading] = useState(false)
// 本地进度状态,用于独立刷新
const [localProgress, setLocalProgress] = useState({
status: jobData.status,
processedRows: jobData.processedRows,
totalRows: jobData.totalRows,
error: jobData.error,
})
// 自动刷新单个任务的进度(每 3 秒)
useEffect(() => {
// 只在任务进行中时刷新
if (localProgress.status !== 'pending' && localProgress.status !== 'processing')
return
const refreshInterval = setInterval(async () => {
try {
const progress = await fetchProgressApi(batchId)
if (progress) {
const newStatus = progress.status as string
const newProcessedRows = progress.processed_rows as number
const newError = progress.error as string | undefined
// 只有当数据有变化时才更新
if (
newStatus !== localProgress.status
|| newProcessedRows !== localProgress.processedRows
|| newError !== localProgress.error
) {
const updatedProgress = {
status: newStatus,
processedRows: newProcessedRows,
totalRows: progress.total_rows as number,
error: newError,
}
setLocalProgress(updatedProgress)
// 通知父组件数据已更新(用于列表级别的状态同步)
onJobUpdate?.({
status: newStatus,
processedRows: newProcessedRows,
error: newError,
})
}
}
}
catch (error) {
console.error('Failed to fetch progress:', error)
}
}, 3000)
return () => clearInterval(refreshInterval)
}, [batchId, localProgress.status, localProgress.processedRows, localProgress.error, onJobUpdate])
// 停止批量处理
const handleStop = async () => {
@@ -160,8 +213,8 @@ const BatchProgress: FC<BatchProgressProps> = ({
})
// 计算进度
const progress = jobData.totalRows > 0 ? (jobData.processedRows / jobData.totalRows) * 100 : 0
const status = jobData.status as BatchStatus
const progress = localProgress.totalRows > 0 ? (localProgress.processedRows / localProgress.totalRows) * 100 : 0
const status = localProgress.status as BatchStatus
const failed_count = 0 // 从列表API没有这个字段,如果需要可以后续添加
const getBorderColor = (status: BatchStatus) => {
@@ -232,18 +285,18 @@ const BatchProgress: FC<BatchProgressProps> = ({
</div>
{/* 详细进度信息 */}
{jobData.totalRows > 0 && (
{localProgress.totalRows > 0 && (
<div className="mt-2 text-xs text-gray-500">
{t('batchWorkflow.processed', {
processed: jobData.processedRows || 0,
total: jobData.totalRows || 0,
processed: localProgress.processedRows || 0,
total: localProgress.totalRows || 0,
ns: 'extend',
})}
</div>
)}
{/* 错误信息显示 */}
{jobData.error && status === 'failed' && (
{localProgress.error && status === 'failed' && (
<div className="mt-3 rounded-lg border border-red-200 bg-red-50 p-3">
<div className="flex items-start space-x-2">
<RiErrorWarningLine className="h-4 w-4 text-red-500 mt-0.5 flex-shrink-0" />
@@ -252,7 +305,7 @@ const BatchProgress: FC<BatchProgressProps> = ({
{t('batchWorkflow.errorOccurred', { ns: 'extend'} )}
</div>
<div className="text-xs text-red-700 break-words">
{jobData.error}
{localProgress.error}
</div>
</div>
</div>
+2 -2
View File
@@ -97,7 +97,7 @@
"dompurify": "3.3.0",
"echarts": "5.6.0",
"echarts-for-react": "3.0.5",
"elkjs": "0.9.3",
"elkjs": "0.11.0",
"embla-carousel-autoplay": "8.6.0",
"embla-carousel-react": "8.6.0",
"emoji-mart": "5.6.0",
@@ -167,7 +167,7 @@
"devDependencies": {
"@antfu/eslint-config": "7.2.0",
"@chromatic-com/storybook": "5.0.0",
"@eslint-react/eslint-plugin": "2.8.1",
"@eslint-react/eslint-plugin": "2.12.4",
"@mdx-js/loader": "3.1.1",
"@mdx-js/react": "3.1.1",
"@next/bundle-analyzer": "16.1.5",