mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-14 20:41:21 +08:00
Compare commits
49 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cb0233f7b6 | |||
| 6c0c6ba1fe | |||
| 3240e6e4e5 | |||
| 8d823787b7 | |||
| 84fc7bcb43 | |||
| e08b1b079c | |||
| 3f6ef97148 | |||
| 33a34e4181 | |||
| 347e48cef6 | |||
| d2a7ade1b0 | |||
| efc25217dc | |||
| 618a355ec8 | |||
| d474f15673 | |||
| 3bad30bff1 | |||
| ea77171028 | |||
| bb1db4ca99 | |||
| 5618c89721 | |||
| 3a02769e4a | |||
| 0af791f56c | |||
| 0b20a17074 | |||
| d13e083f37 | |||
| 1cc7f4bc7b | |||
| 8341905b21 | |||
| b5af9263f8 | |||
| 950ef2d13e | |||
| 8257113c50 | |||
| fff9543a37 | |||
| 02e568c6d5 | |||
| 2bc3e4dc39 | |||
| d7b77bee2e | |||
| ce82b5f776 | |||
| f328825da7 | |||
| 892a5f9127 | |||
| 1d6e41829a | |||
| 0484655f13 | |||
| 786920c7e3 | |||
| b84e94250f | |||
| 9591795b10 | |||
| 8df2e46658 | |||
| 22d01c3c55 | |||
| d1b32f4310 | |||
| 4b5e2eaf35 | |||
| e283aa4055 | |||
| b5aba401e5 | |||
| bc2edcdde6 | |||
| 5962b9b518 | |||
| 1b447b7b0b | |||
| e0cf5e2e27 | |||
| 2520715b8a |
@@ -209,6 +209,9 @@ api/.vscode
|
||||
.history
|
||||
|
||||
.idea/
|
||||
.claude/
|
||||
.cursor/
|
||||
openspec/
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
# Admin (Go Backend) Agent Guide
|
||||
|
||||
## Rules (must follow)
|
||||
|
||||
### 禁止匿名 struct
|
||||
|
||||
- **禁止在代码中出现匿名 struct**。不得使用 `var x []struct { ... }`、`var x struct { ... }` 或字面量 `struct { A int }{1}` 等匿名结构体。
|
||||
- 所有用于 GORM 查询扫描、缓存结构、API 请求/响应的结构体必须定义为**具名类型**,放在合适的 model 包(如 `model/gaia/request`、`model/gaia/response`)或当前包顶部,便于复用和规范约束。
|
||||
- 示例:用 `[]response.AppQuotaRankingRow` 替代 `[]struct { AppID string; TotalCost float64; ... }`;用 `response.AppQuotaRankingCache` 替代 `struct { List ...; Total int64 }`。
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM golang:alpine as builder
|
||||
FROM golang:alpine AS builder
|
||||
|
||||
RUN mkdir /app
|
||||
WORKDIR /app
|
||||
|
||||
@@ -12,6 +12,7 @@ type ApiGroup struct {
|
||||
BatchWorkflowApi
|
||||
AppVersionApi
|
||||
ModelProviderApi
|
||||
ForwardProxyApi
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
package gaia
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
gaiaModel "github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ForwardProxyApi struct{}
|
||||
|
||||
// ForwardProxy 转发代理入口:免 JWT,通过 forwarding token + ding_id 鉴权并计费
|
||||
// @Tags ForwardProxy
|
||||
// @Summary GPT 转发代理(钉钉入口,无需 JWT)
|
||||
// @Param X-Forward-Token header string false "转发 Token"
|
||||
// @Param X-Ding-Id header string false "钉钉 ID"
|
||||
// @Param forward_token query string false "转发 Token(Header 优先)"
|
||||
// @Param ding_id query string false "钉钉 ID(Header 优先)"
|
||||
// @Param path path string true "上游路径"
|
||||
// @Router /gaia/forward/proxy/{path} [get,post,put,patch,delete]
|
||||
func (f *ForwardProxyApi) ForwardProxy(c *gin.Context) {
|
||||
// 打印请求 Header,便于排查转发问题
|
||||
global.GVA_LOG.Info("ForwardProxy 请求头",
|
||||
zap.Any("headers", c.Request.Header),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
)
|
||||
|
||||
// 1. 读取转发配置
|
||||
integrate := systemIntegratedService.GetIntegratedConfig(gaiaModel.SystemIntegrationDingTalk)
|
||||
configMap, err := systemIntegratedService.ParseDingTalkConfig(integrate.Config)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": "配置解析失败"}})
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 获取并校验 forwarding token(存在有效 Token 即视为开启转发能力)
|
||||
dingId := c.GetHeader("X-Ding-Id")
|
||||
apiKey := c.GetHeader("X-Api-Key")
|
||||
bearer := c.GetHeader("Authorization")
|
||||
token := c.GetHeader("X-Forward-Token")
|
||||
|
||||
if (len(bearer) > gaiaModel.BearerLength || len(apiKey) > gaiaModel.BearerLength) && len(dingId) == 0 {
|
||||
if len(bearer) > gaiaModel.BearerLength {
|
||||
if bearer[:gaiaModel.BearerLength] == "Bearer " {
|
||||
bearer = bearer[gaiaModel.BearerLength:]
|
||||
}
|
||||
} else if len(apiKey) > gaiaModel.BearerLength {
|
||||
if apiKey[:gaiaModel.BearerLength] == "Bearer " {
|
||||
bearer = apiKey[gaiaModel.BearerLength:]
|
||||
} else {
|
||||
bearer = apiKey
|
||||
}
|
||||
}
|
||||
|
||||
if dingId, err = systemIntegratedService.ParseForwardToken(bearer, configMap.ForwardConfig.Tokens); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{"message": "Token 验证失败: " + err.Error()}})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if token == "" {
|
||||
token = c.Query("forward_token")
|
||||
}
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{"message": "缺少转发 Token"}})
|
||||
return
|
||||
}
|
||||
if !validateForwardToken(token, configMap.ForwardConfig.Tokens) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": gin.H{"message": "无效的转发 Token"}})
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 获取 ding_id
|
||||
if dingId == "" {
|
||||
dingId = c.Query("ding_id")
|
||||
}
|
||||
if dingId == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "缺少 ding_id"}})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 解析 account_id
|
||||
accountId, err := systemIntegratedService.ResolveAccountByDingId(dingId, configMap.EmailApi)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Warn("ForwardProxy 用户解析失败", zap.String("ding_id", dingId), zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "无法解析用户:" + err.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
// 6. 复用与 Proxy 相同的转发逻辑(path/body/ProxyRequest)
|
||||
proxyWithAccountId(c, accountId)
|
||||
}
|
||||
|
||||
// validateForwardToken 校验 token 是否在转发 Token 列表中(SHA256 比对)
|
||||
func validateForwardToken(token string, tokens []request.ForwardToken) bool {
|
||||
hash := fmt.Sprintf("%x", sha256.Sum256([]byte(token)))
|
||||
for _, t := range tokens {
|
||||
if t.TokenHash == hash {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -2,13 +2,13 @@ package gaia
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
|
||||
gaiaReq "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/service"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -31,7 +31,7 @@ func (m *ModelProviderApi) GetProviderList(c *gin.Context) {
|
||||
list, err := modelProviderService.GetProviderList()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取提供商配置列表失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
response.FailWithMessage("获取失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
response.OkWithData(list, c)
|
||||
@@ -54,20 +54,20 @@ func (m *ModelProviderApi) UpdateProviderConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage("参数错误: "+err.Error(), c)
|
||||
response.FailWithMessage("参数错误:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := modelProviderService.UpdateProviderConfig(req.ProviderName, req.Enabled, req.Models); err != nil {
|
||||
global.GVA_LOG.Error("更新提供商配置失败", zap.String("provider", req.ProviderName), zap.Error(err))
|
||||
response.FailWithMessage("更新失败: "+err.Error(), c)
|
||||
response.FailWithMessage("更新失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("更新成功", c)
|
||||
}
|
||||
|
||||
// GetModels 获取开启的模型列表(OpenAI格式)
|
||||
// GetModels 获取开启的模型列表(OpenAI 格式,供第三方兼容调用;成功时返回裸 JSON,错误时与项目统一使用 response)。
|
||||
// @Tags ModelProvider
|
||||
// @Summary 获取开启的模型列表
|
||||
// @Security ApiKeyAuth
|
||||
@@ -79,73 +79,73 @@ func (m *ModelProviderApi) GetModels(c *gin.Context) {
|
||||
models, err := modelProviderService.GetEnabledModels()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取模型列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "获取模型列表失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
response.FailWithMessage("获取失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
}
|
||||
|
||||
// Proxy 通用中转 API:将 /gaia/proxy/* 的请求按路径转发到上游(如 /v1/chat/completions、/v1/messages、/v1/images/generations、/v1/embeddings 等)。
|
||||
// 上游 base 优先使用 provider_credentials 的 openai_api_base(如 "https://yunwu.ai"),便于计费区分。
|
||||
// proxyWithAccountId 通用代理逻辑:按路径转发到上游并计费。
|
||||
func proxyWithAccountId(c *gin.Context, accountId string) {
|
||||
path := c.Param("path")
|
||||
if path == "" || path == "/" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "代理路径不能为空"}})
|
||||
return
|
||||
}
|
||||
reqHeader := c.Request.Header.Clone()
|
||||
if q := strings.TrimSpace(c.Query("provider")); q != "" {
|
||||
reqHeader.Set("X-Gaia-Provider", q)
|
||||
}
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "读取请求体失败"}})
|
||||
return
|
||||
}
|
||||
var bodyModel string
|
||||
if len(body) > 0 {
|
||||
var parseObj map[string]interface{}
|
||||
if jsonErr := json.Unmarshal(body, &parseObj); jsonErr == nil {
|
||||
if mv, ok := parseObj["model"].(string); ok {
|
||||
bodyModel = mv
|
||||
}
|
||||
}
|
||||
}
|
||||
global.GVA_LOG.Info("Gaia代理请求入参",
|
||||
zap.String("account_id", accountId),
|
||||
zap.String("path", path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.Int("body_len", len(body)),
|
||||
zap.String("body_model", bodyModel),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
)
|
||||
|
||||
// 余额前置检查:余额耗尽时直接拦截,不继续请求上游
|
||||
if quotaErr := modelProviderService.CheckAccountQuota(accountId); quotaErr != nil {
|
||||
c.JSON(http.StatusPaymentRequired, gin.H{"error": gin.H{"message": quotaErr.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
if err = modelProviderService.ProxyRequest(
|
||||
accountId, path, c.Request.Method, reqHeader, body, c.Writer); err != nil {
|
||||
global.GVA_LOG.Info("Gaia代理请求body",
|
||||
zap.String("body", string(body)),
|
||||
)
|
||||
global.GVA_LOG.Error("代理请求失败", zap.String("account_id", accountId), zap.String("path", path), zap.Error(err))
|
||||
if !c.Writer.Written() {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Proxy 通用中转 API:将 /gaia/proxy/* 的请求按路径转发到上游(需 JWT,account 来自当前登录用户)。
|
||||
// @Tags ModelProvider
|
||||
// @Summary 通用中转API(按路径转发)
|
||||
// @Security ApiKeyAuth
|
||||
// @Param path path string true "上游路径,如 v1/chat/completions、v1/messages"
|
||||
// @Router /gaia/proxy/*path [get,post,put,patch,delete]
|
||||
func (m *ModelProviderApi) Proxy(c *gin.Context) {
|
||||
// init
|
||||
var err error
|
||||
var body []byte
|
||||
path := c.Param("path")
|
||||
userID := utils.GetUserUuid(c).String()
|
||||
if path == "" || path == "/" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "代理路径不能为空"}})
|
||||
return
|
||||
}
|
||||
// 将 query provider 转为请求头,供 service 解析
|
||||
reqHeader := c.Request.Header.Clone()
|
||||
if q := strings.TrimSpace(c.Query("provider")); q != "" {
|
||||
reqHeader.Set("X-Gaia-Provider", q)
|
||||
}
|
||||
|
||||
if body, err = io.ReadAll(c.Request.Body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "读取请求体失败"}})
|
||||
return
|
||||
}
|
||||
|
||||
// 打印传入参数便于排查
|
||||
queryProvider := strings.TrimSpace(c.Query("provider"))
|
||||
var bodyModel string
|
||||
if len(body) > 0 {
|
||||
var parseObj map[string]interface{}
|
||||
if jsonErr := json.Unmarshal(body, &parseObj); jsonErr == nil {
|
||||
if m, ok := parseObj["model"].(string); ok {
|
||||
bodyModel = m
|
||||
}
|
||||
}
|
||||
}
|
||||
global.GVA_LOG.Info("Gaia代理请求入参",
|
||||
zap.String("path", path),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("query_provider", queryProvider),
|
||||
zap.Int("body_len", len(body)),
|
||||
zap.String("body_model", bodyModel),
|
||||
zap.String("body", string(body)),
|
||||
)
|
||||
|
||||
if err = modelProviderService.ProxyRequest(
|
||||
userID, path, c.Request.Method, reqHeader, body, c.Writer); err != nil {
|
||||
global.GVA_LOG.Error("代理请求失败", zap.String("user_id", userID), zap.String(
|
||||
"path", path), zap.Error(err))
|
||||
if !c.Writer.Written() {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
}
|
||||
}
|
||||
accountId := utils.GetUserUuid(c).String()
|
||||
proxyWithAccountId(c, accountId)
|
||||
}
|
||||
|
||||
// GetAvailableModels 获取提供商的可用模型
|
||||
@@ -160,17 +160,15 @@ func (m *ModelProviderApi) Proxy(c *gin.Context) {
|
||||
func (m *ModelProviderApi) GetAvailableModels(c *gin.Context) {
|
||||
providerName := c.Query("provider_name")
|
||||
if providerName == "" {
|
||||
response.FailWithMessage("参数错误: provider_name不能为空", c)
|
||||
response.FailWithMessage("参数错误:provider_name不能为空", c)
|
||||
return
|
||||
}
|
||||
|
||||
models, err := modelProviderService.GetAvailableModelsFromDify(providerName)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取可用模型失败", zap.String("provider", providerName), zap.Error(err))
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
response.FailWithMessage("获取失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(models, c)
|
||||
}
|
||||
|
||||
@@ -186,35 +184,55 @@ func (m *ModelProviderApi) GetAvailableModels(c *gin.Context) {
|
||||
func (m *ModelProviderApi) TestProviderCredentials(c *gin.Context) {
|
||||
providerName := c.Query("provider_name")
|
||||
if providerName == "" {
|
||||
response.FailWithMessage("参数错误: provider_name不能为空", c)
|
||||
response.FailWithMessage("参数错误:provider_name不能为空", c)
|
||||
return
|
||||
}
|
||||
|
||||
creds, err := modelProviderService.GetDifyProviderCredentials(providerName)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取提供商凭证失败", zap.String("provider", providerName), zap.Error(err))
|
||||
response.FailWithMessage("获取凭证失败: "+err.Error(), c)
|
||||
response.FailWithMessage("获取凭证失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 隐藏API Key的大部分内容
|
||||
maskedKey := ""
|
||||
if len(creds.APIKey) > 8 {
|
||||
maskedKey = creds.APIKey[:4] + "****" + creds.APIKey[len(creds.APIKey)-4:]
|
||||
} else {
|
||||
maskedKey = "****"
|
||||
}
|
||||
var result map[string]interface{}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"provider": providerName,
|
||||
"has_api_key": creds.APIKey != "",
|
||||
"api_key": maskedKey,
|
||||
// AWS Bedrock:展示 access key 信息
|
||||
if creds.AWSAccessKeyID != "" {
|
||||
maskedKey := "****"
|
||||
if len(creds.AWSAccessKeyID) > 8 {
|
||||
maskedKey = creds.AWSAccessKeyID[:4] + "****" + creds.AWSAccessKeyID[len(creds.AWSAccessKeyID)-4:]
|
||||
}
|
||||
region := creds.AWSRegion
|
||||
if region == "" {
|
||||
region = "us-east-1(默认)"
|
||||
}
|
||||
result = map[string]interface{}{
|
||||
"provider": providerName,
|
||||
"has_api_key": true,
|
||||
"api_key": maskedKey,
|
||||
"aws_access_key_id": maskedKey,
|
||||
"aws_region": region,
|
||||
"has_session_token": creds.AWSSessionToken != "",
|
||||
}
|
||||
} else {
|
||||
// 隐藏API Key的大部分内容
|
||||
maskedKey := ""
|
||||
if len(creds.APIKey) > 8 {
|
||||
maskedKey = creds.APIKey[:4] + "****" + creds.APIKey[len(creds.APIKey)-4:]
|
||||
} else {
|
||||
maskedKey = "****"
|
||||
}
|
||||
result = map[string]interface{}{
|
||||
"provider": providerName,
|
||||
"has_api_key": creds.APIKey != "",
|
||||
"api_key": maskedKey,
|
||||
}
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// GetProxyLogs 获取代理日志
|
||||
// GetProxyLogs 获取代理日志(分页)
|
||||
// @Tags ModelProvider
|
||||
// @Summary 获取代理日志
|
||||
// @Security ApiKeyAuth
|
||||
@@ -222,54 +240,30 @@ func (m *ModelProviderApi) TestProviderCredentials(c *gin.Context) {
|
||||
// @Produce application/json
|
||||
// @Param page query int false "页码"
|
||||
// @Param page_size query int false "每页数量"
|
||||
// @Success 200 {object} response.Response{data=map[string]interface{},msg=string} "获取成功"
|
||||
// @Success 200 {object} response.Response{data=response.PageResult,msg=string} "获取成功"
|
||||
// @Router /gaia/model-provider/logs [get]
|
||||
func (m *ModelProviderApi) GetProxyLogs(c *gin.Context) {
|
||||
page := c.DefaultQuery("page", "1")
|
||||
pageSize := c.DefaultQuery("page_size", "20")
|
||||
|
||||
var pageInt, pageSizeInt int
|
||||
if _, err := fmt.Sscanf(page, "%d", &pageInt); err != nil {
|
||||
pageInt = 1
|
||||
}
|
||||
if _, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt); err != nil {
|
||||
pageSizeInt = 20
|
||||
}
|
||||
|
||||
if pageInt < 1 {
|
||||
pageInt = 1
|
||||
}
|
||||
if pageSizeInt < 1 || pageSizeInt > 100 {
|
||||
pageSizeInt = 20
|
||||
}
|
||||
|
||||
var logs []map[string]interface{}
|
||||
var total int64
|
||||
|
||||
db := global.GVA_DB.Table("model_proxy_log")
|
||||
|
||||
// 获取总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
global.GVA_LOG.Error("获取日志总数失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
var req gaiaReq.GetProxyLogsReq
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
response.FailWithMessage("参数错误:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (pageInt - 1) * pageSizeInt
|
||||
if err := db.Order("created_at DESC").Limit(pageSizeInt).Offset(
|
||||
offset).Find(&logs).Error; err != nil {
|
||||
global.GVA_LOG.Error("获取日志列表失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
if req.Page < 1 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize < 1 || req.PageSize > 100 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
list, total, err := modelProviderService.GetProxyLogs(req)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取代理日志失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"list": logs,
|
||||
"total": total,
|
||||
"page": pageInt,
|
||||
"page_size": pageSizeInt,
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
response.OkWithDetailed(response.PageResult{
|
||||
List: list,
|
||||
Total: total,
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
}, "获取成功", c)
|
||||
}
|
||||
|
||||
@@ -2,10 +2,26 @@ package gaia
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
gaiaResp "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/system"
|
||||
serviceGaia "github.com/flipped-aurora/gin-vue-admin/server/service/gaia"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SystemApi struct{}
|
||||
@@ -55,3 +71,260 @@ func (systemApi *SystemApi) SetDingTalk(c *gin.Context) {
|
||||
}
|
||||
response.OkWithData("ok", c)
|
||||
}
|
||||
|
||||
// TestEmailApiConfig 测试第三方邮箱 API 配置
|
||||
// @Tags System
|
||||
// @Summary 测试第三方邮箱 API 配置
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body request.TestEmailApiConfigRequest true "测试配置请求"
|
||||
// @Success 200 {object} response.Response{data=gaiaResp.TestEmailApiConfigResponse,msg=string} "测试结果"
|
||||
// @Router /gaia/system/dingtalk/test-email-config [post]
|
||||
func (systemApi *SystemApi) TestEmailApiConfig(c *gin.Context) {
|
||||
var req request.TestEmailApiConfigRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := systemIntegratedService.TestEmailApiConfig(req.Config, req.TestDingID)
|
||||
if err != nil {
|
||||
response.FailWithMessage("测试失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// GetDingTalkTestAuthURL 获取「测试连接」用的钉钉授权 URL,打开后扫码完成即视为连接成功
|
||||
// @Router /gaia/system/dingtalk/test-auth-url [get]
|
||||
func (systemApi *SystemApi) GetDingTalkTestAuthURL(c *gin.Context) {
|
||||
origin := c.GetHeader("Referer")
|
||||
if origin == "" {
|
||||
origin = c.GetHeader("Origin")
|
||||
}
|
||||
if origin != "" {
|
||||
if u, err := url.Parse(origin); err == nil {
|
||||
origin = u.Scheme + "://" + u.Host + strings.TrimSuffix(u.Path, "/")
|
||||
}
|
||||
}
|
||||
if origin == "" {
|
||||
response.FailWithMessage("无法获取前端地址,请从配置页点击「测试连接」", c)
|
||||
return
|
||||
}
|
||||
authURL, err := systemIntegratedService.GetDingTalkTestAuthURL(origin)
|
||||
if err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
response.OkWithData(gin.H{"auth_url": authURL}, c)
|
||||
}
|
||||
|
||||
// DingTalkTestCallback 测试连接回调:仅用 code 换 token 验证,不登录
|
||||
// @Router /gaia/system/dingtalk/test-callback [post]
|
||||
func (systemApi *SystemApi) DingTalkTestCallback(c *gin.Context) {
|
||||
var req struct {
|
||||
Code string `json:"code"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil || strings.TrimSpace(req.Code) == "" {
|
||||
response.FailWithMessage("缺少授权码 code", c)
|
||||
return
|
||||
}
|
||||
if err := systemIntegratedService.DingTalkTestCallback(req.Code); err != nil {
|
||||
response.FailWithMessage("验证失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
response.OkWithMessage("验证成功", c)
|
||||
}
|
||||
|
||||
// GetForwardTokens 获取转发 Token 列表
|
||||
// @Tags System
|
||||
// @Summary 获取转发 Token 列表
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Success 200 {object} response.Response{data=gaiaResp.ForwardTokensResponse,msg=string} "查询成功"
|
||||
// @Router /gaia/system/forward-tokens [get]
|
||||
func (systemApi *SystemApi) GetForwardTokens(c *gin.Context) {
|
||||
integrate := systemIntegratedService.GetIntegratedConfig(gaia.SystemIntegrationDingTalk)
|
||||
|
||||
var configMap request.DingTalkConfigRequest
|
||||
if integrate.Config != "" {
|
||||
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
|
||||
response.FailWithMessage("解析配置失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tokens := make([]gaiaResp.ForwardTokenInfo, 0, len(configMap.ForwardConfig.Tokens))
|
||||
for i, token := range configMap.ForwardConfig.Tokens {
|
||||
tokens = append(tokens, gaiaResp.ForwardTokenInfo{
|
||||
ID: utils.AddAsteriskToString(token.TokenSecret),
|
||||
CreatedAt: token.CreatedAt,
|
||||
Seq: i + 1,
|
||||
})
|
||||
}
|
||||
|
||||
response.OkWithData(gaiaResp.ForwardTokensResponse{Tokens: tokens, Count: len(tokens), Max: 20}, c)
|
||||
}
|
||||
|
||||
// CreateForwardToken 新增转发 Token
|
||||
// @Tags System
|
||||
// @Summary 新增转发 Token
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param token body string true "Token 明文"
|
||||
// @Success 200 {object} response.Response{data=request.ForwardToken,msg=string} "创建成功"
|
||||
// @Router /gaia/system/forward-tokens [post]
|
||||
func (systemApi *SystemApi) CreateForwardToken(c *gin.Context) {
|
||||
var req struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Token == "" {
|
||||
response.FailWithMessage("Token 不能为空", c)
|
||||
return
|
||||
}
|
||||
|
||||
integrate := systemIntegratedService.GetIntegratedConfig(gaia.SystemIntegrationDingTalk)
|
||||
var configMap request.DingTalkConfigRequest
|
||||
if integrate.Config != "" {
|
||||
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
|
||||
response.FailWithMessage("解析配置失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 检查数量限制
|
||||
if len(configMap.ForwardConfig.Tokens) >= 20 {
|
||||
response.FailWithMessage("转发 Token 最多 20 个", c)
|
||||
return
|
||||
}
|
||||
|
||||
// 生成唯一 ID 和哈希
|
||||
tokenID := "tok_" + uuid.New().String()
|
||||
tokenHash := fmt.Sprintf("%x", sha256.Sum256([]byte(req.Token)))
|
||||
// 生成 HMAC 签名密钥(仅创建时回传一次)
|
||||
secretBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(secretBytes); err != nil {
|
||||
response.FailWithMessage("生成 TokenSecret 失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
tokenSecret := base64.RawURLEncoding.EncodeToString(secretBytes)
|
||||
|
||||
newToken := request.ForwardToken{
|
||||
ID: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
CreatedAt: time.Now(),
|
||||
TokenSecret: tokenSecret,
|
||||
}
|
||||
|
||||
// 添加到配置
|
||||
configMap.ForwardConfig.Tokens = append(configMap.ForwardConfig.Tokens, newToken)
|
||||
seq := len(configMap.ForwardConfig.Tokens) // 1..N
|
||||
configJSON, _ := json.Marshal(configMap)
|
||||
integrate.Config = string(configJSON)
|
||||
|
||||
// 保存配置
|
||||
if err := systemIntegratedService.SetIntegratedConfig(integrate, "", false); err != nil {
|
||||
response.FailWithMessage("保存失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 返回明文 token(仅此次展示)
|
||||
response.OkWithData(gin.H{
|
||||
"seq": seq,
|
||||
"token": req.Token,
|
||||
"token_secret": tokenSecret,
|
||||
"created_at": newToken.CreatedAt,
|
||||
}, c)
|
||||
}
|
||||
|
||||
// DeleteForwardToken 删除转发 Token
|
||||
// @Tags System
|
||||
// @Summary 删除转发 Token
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param seq path int true "Token 序列号(从列表获取,1..N)"
|
||||
// @Param password body string true "当前用户密码"
|
||||
// @Success 200 {object} response.Response{msg=string} "删除成功"
|
||||
// @Router /gaia/system/forward-tokens/:seq [delete]
|
||||
func (systemApi *SystemApi) DeleteForwardToken(c *gin.Context) {
|
||||
seqStr := c.Param("seq")
|
||||
if seqStr == "" {
|
||||
response.FailWithMessage("Token 序列号不能为空", c)
|
||||
return
|
||||
}
|
||||
seq, err := strconv.Atoi(seqStr)
|
||||
if err != nil || seq <= 0 {
|
||||
response.FailWithMessage("Token 序列号非法", c)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证当前用户密码(使用 Dify account 密码体系)
|
||||
userID := utils.GetUserUuid(c).String()
|
||||
var user system.SysUser
|
||||
if err := global.GVA_DB.Select("email").Where(
|
||||
"uuid = ?", userID).First(&user).Error; err != nil {
|
||||
response.FailWithMessage("查询用户失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
account, err := user.GetAccount()
|
||||
if err != nil {
|
||||
response.FailWithMessage("查询账号失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
var pwd serviceGaia.PasswdEncode
|
||||
if ok, pwdErr := pwd.ComparePassword(
|
||||
req.Password, account.Password, account.PasswordSalt); pwdErr != nil || !ok {
|
||||
response.FailWithMessage("密码错误", c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取配置
|
||||
integrate := systemIntegratedService.GetIntegratedConfig(gaia.SystemIntegrationDingTalk)
|
||||
var configMap request.DingTalkConfigRequest
|
||||
if integrate.Config != "" {
|
||||
if err = json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
|
||||
response.FailWithMessage("解析配置失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 查找并删除 token
|
||||
if seq > len(configMap.ForwardConfig.Tokens) {
|
||||
response.FailWithMessage("Token 不存在", c)
|
||||
return
|
||||
}
|
||||
idx := seq - 1
|
||||
newTokens := make([]request.ForwardToken, 0, len(configMap.ForwardConfig.Tokens)-1)
|
||||
newTokens = append(newTokens, configMap.ForwardConfig.Tokens[:idx]...)
|
||||
newTokens = append(newTokens, configMap.ForwardConfig.Tokens[idx+1:]...)
|
||||
|
||||
// 更新配置
|
||||
configMap.ForwardConfig.Tokens = newTokens
|
||||
configJSON, _ := json.Marshal(configMap)
|
||||
integrate.Config = string(configJSON)
|
||||
|
||||
if err = systemIntegratedService.SetIntegratedConfig(integrate, "", false); err != nil {
|
||||
response.FailWithMessage("保存失败:"+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("删除成功", c)
|
||||
}
|
||||
|
||||
@@ -5,9 +5,8 @@ import (
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/system/request"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/service/system"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type DBApi struct{}
|
||||
@@ -20,7 +19,8 @@ type DBApi struct{}
|
||||
// @Success 200 {object} response.Response{data=string} "初始化用户数据库"
|
||||
// @Router /init/initdb [post]
|
||||
func (i *DBApi) InitDB(c *gin.Context) {
|
||||
if global.GVA_DB != nil {
|
||||
|
||||
if initDBService.IfInit() {
|
||||
global.GVA_LOG.Error("已存在数据库配置!")
|
||||
response.FailWithMessage("已存在数据库配置", c)
|
||||
return
|
||||
@@ -51,12 +51,11 @@ func (i *DBApi) InitDB(c *gin.Context) {
|
||||
// @Success 200 {object} response.Response{data=map[string]interface{},msg=string} "初始化用户数据库"
|
||||
// @Router /init/checkdb [post]
|
||||
func (i *DBApi) CheckDB(c *gin.Context) {
|
||||
var (
|
||||
message = "前往初始化数据库"
|
||||
needInit = true
|
||||
)
|
||||
// init
|
||||
var needInit = true
|
||||
var message = "前往初始化数据库"
|
||||
|
||||
if global.GVA_DB != nil {
|
||||
if initDBService.IfInit() {
|
||||
message = "数据库无需初始化"
|
||||
needInit = false
|
||||
}
|
||||
|
||||
@@ -113,10 +113,10 @@ pgsql:
|
||||
prefix: ""
|
||||
port: "5432"
|
||||
config: sslmode=disable TimeZone=Asia/Shanghai
|
||||
db-name:
|
||||
username:
|
||||
password:
|
||||
path:
|
||||
db-name: dify
|
||||
username: postgres
|
||||
password: difyai123456
|
||||
path: db_postgres
|
||||
engine: ""
|
||||
log-mode: error
|
||||
max-idle-conns: 10
|
||||
|
||||
@@ -5,5 +5,5 @@ type Gaia struct {
|
||||
LoginMaxErrorLimit int `mapstructure:"login_max_error_limit" json:"login_max_error_limit" yaml:"login_max_error_limit"`
|
||||
SuperAdminAccountId string `mapstructure:"SUPER_ADMIN_ACCOUNT_ID" json:"SUPER_ADMIN_ACCOUNT_ID" yaml:"SUPER_ADMIN_ACCOUNT_ID"` // 超级管理员账号
|
||||
SuperAdminTenantId string `mapstructure:"SUPER_ADMIN_TENANT_ID" json:"SUPER_ADMIN_TENANT_ID" yaml:"SUPER_ADMIN_TENANT_ID"` // 系统默认工作区
|
||||
StoragePath string `mapstructure:"storage-path" json:"storage-path" yaml:"storage-path"` // Dify storage 目录路径,用于读取私钥
|
||||
StoragePath string `mapstructure:"storage-path" json:"storage-path" yaml:"storage-path"` // Dify storage 目录路径,用于读取私钥
|
||||
}
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
)
|
||||
|
||||
// overrideDBFromEnv 从环境变量覆盖数据库配置
|
||||
// 优先级:环境变量 > 配置文件
|
||||
func overrideDBFromEnv() {
|
||||
// 数据库类型
|
||||
if dbType := os.Getenv("DB_TYPE"); dbType != "" {
|
||||
switch dbType {
|
||||
case "mysql":
|
||||
global.GVA_CONFIG.System.DbType = "mysql"
|
||||
overrideMysqlFromEnv()
|
||||
case "postgresql", "postgres":
|
||||
global.GVA_CONFIG.System.DbType = "pgsql"
|
||||
overridePgsqlFromEnv()
|
||||
default:
|
||||
global.GVA_CONFIG.System.DbType = "pgsql"
|
||||
overridePgsqlFromEnv()
|
||||
}
|
||||
fmt.Printf("Database type overridden from DB_TYPE environment variable: %s\n", dbType)
|
||||
}
|
||||
}
|
||||
|
||||
// overrideMysqlFromEnv 从环境变量覆盖 MySQL 配置
|
||||
func overrideMysqlFromEnv() {
|
||||
cfg := &global.GVA_CONFIG.Mysql
|
||||
if host := os.Getenv("DB_HOST"); host != "" {
|
||||
cfg.Path = host
|
||||
}
|
||||
if port := os.Getenv("DB_PORT"); port != "" {
|
||||
cfg.Port = port
|
||||
} else {
|
||||
cfg.Port = "3306"
|
||||
}
|
||||
if username := os.Getenv("DB_USERNAME"); username != "" {
|
||||
cfg.Username = username
|
||||
}
|
||||
if password := os.Getenv("DB_PASSWORD"); password != "" {
|
||||
cfg.Password = password
|
||||
}
|
||||
if dbname := os.Getenv("DB_DATABASE"); dbname != "" {
|
||||
cfg.Dbname = dbname
|
||||
}
|
||||
if config := os.Getenv("DB_CONFIG"); config != "" {
|
||||
cfg.Config = config
|
||||
}
|
||||
}
|
||||
|
||||
// overridePgsqlFromEnv 从环境变量覆盖 PostgreSQL 配置
|
||||
func overridePgsqlFromEnv() {
|
||||
cfg := &global.GVA_CONFIG.Pgsql
|
||||
if host := os.Getenv("DB_HOST"); host != "" {
|
||||
cfg.Path = host
|
||||
}
|
||||
if port := os.Getenv("DB_PORT"); port != "" {
|
||||
cfg.Port = port
|
||||
} else {
|
||||
cfg.Port = "5432"
|
||||
}
|
||||
if username := os.Getenv("DB_USERNAME"); username != "" {
|
||||
cfg.Username = username
|
||||
}
|
||||
if password := os.Getenv("DB_PASSWORD"); password != "" {
|
||||
cfg.Password = password
|
||||
}
|
||||
if dbname := os.Getenv("DB_DATABASE"); dbname != "" {
|
||||
cfg.Dbname = dbname
|
||||
}
|
||||
if config := os.Getenv("DB_CONFIG"); config != "" {
|
||||
cfg.Config = config
|
||||
} else {
|
||||
cfg.Config = "sslmode=disable TimeZone=Asia/Shanghai"
|
||||
}
|
||||
}
|
||||
|
||||
// overrideRedisFromEnv 从环境变量覆盖 Redis 配置
|
||||
func overrideRedisFromEnv() {
|
||||
// 覆盖主 Redis 配置
|
||||
if host := os.Getenv("REDIS_HOST"); host != "" {
|
||||
port := os.Getenv("REDIS_PORT")
|
||||
if port == "" {
|
||||
port = "6379"
|
||||
}
|
||||
global.GVA_CONFIG.Redis.Addr = host + ":" + port
|
||||
}
|
||||
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
|
||||
global.GVA_CONFIG.Redis.Password = password
|
||||
}
|
||||
if db := os.Getenv("REDIS_DB"); db != "" {
|
||||
if dbNum, err := strconv.Atoi(db); err == nil {
|
||||
global.GVA_CONFIG.Redis.DB = dbNum
|
||||
}
|
||||
}
|
||||
|
||||
// 覆盖 Dify Redis 配置(与主 Redis 相同)
|
||||
if host := os.Getenv("REDIS_HOST"); host != "" {
|
||||
port := os.Getenv("REDIS_PORT")
|
||||
if port == "" {
|
||||
port = "6379"
|
||||
}
|
||||
global.GVA_CONFIG.DifyRedis.Addr = host + ":" + port
|
||||
}
|
||||
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
|
||||
global.GVA_CONFIG.DifyRedis.Password = password
|
||||
}
|
||||
if db := os.Getenv("REDIS_DB"); db != "" {
|
||||
if dbNum, err := strconv.Atoi(db); err == nil {
|
||||
global.GVA_CONFIG.DifyRedis.DB = dbNum
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Redis configuration overridden from environment variables: %s\n", global.GVA_CONFIG.Redis.Addr)
|
||||
}
|
||||
|
||||
// overrideAllFromEnv 从环境变量覆盖所有配置
|
||||
func overrideAllFromEnv() {
|
||||
overrideDBFromEnv()
|
||||
overrideRedisFromEnv()
|
||||
}
|
||||
@@ -4,10 +4,11 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/initialize"
|
||||
"github.com/fvbock/endless"
|
||||
"github.com/gin-gonic/gin"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
func initServer(address string, router *gin.Engine) server {
|
||||
@@ -15,5 +16,9 @@ func initServer(address string, router *gin.Engine) server {
|
||||
s.ReadHeaderTimeout = 10 * time.Minute
|
||||
s.WriteTimeout = 10 * time.Minute
|
||||
s.MaxHeaderBytes = 1 << 20
|
||||
// 优雅关闭:在收到 SIGTERM/SIGINT 时先停止工作池,再关闭 HTTP 服务,避免 goroutine 与连接未释放
|
||||
stopPool := func() { initialize.StopWorkerPool() }
|
||||
s.SignalHooks[endless.PRE_SIGNAL][syscall.SIGTERM] = append(s.SignalHooks[endless.PRE_SIGNAL][syscall.SIGTERM], stopPool)
|
||||
s.SignalHooks[endless.PRE_SIGNAL][syscall.SIGINT] = append(s.SignalHooks[endless.PRE_SIGNAL][syscall.SIGINT], stopPool)
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func overrideJWTSigningKeyFromEnv() {
|
||||
}
|
||||
|
||||
// Viper //
|
||||
// 优先级: 命令行 > 环境变量 > 默认值
|
||||
// 优先级:命令行 > 环境变量 > 默认值
|
||||
// Author [SliverHorn](https://github.com/SliverHorn)
|
||||
func Viper(path ...string) *viper.Viper {
|
||||
var config string
|
||||
@@ -45,17 +45,17 @@ func Viper(path ...string) *viper.Viper {
|
||||
case gin.TestMode:
|
||||
config = internal.ConfigTestFile
|
||||
}
|
||||
fmt.Printf("您正在使用gin模式的%s环境名称,config的路径为%s\n", gin.Mode(), config)
|
||||
} else { // internal.ConfigEnv 常量存储的环境变量不为空 将值赋值于config
|
||||
fmt.Printf("您正在使用 gin 模式的%s环境名称,config 的路径为%s\n", gin.Mode(), config)
|
||||
} else { // internal.ConfigEnv 常量存储的环境变量不为空 将值赋值于 config
|
||||
config = configEnv
|
||||
fmt.Printf("您正在使用%s环境变量,config的路径为%s\n", internal.ConfigEnv, config)
|
||||
fmt.Printf("您正在使用%s环境变量,config 的路径为%s\n", internal.ConfigEnv, config)
|
||||
}
|
||||
} else { // 命令行参数不为空 将值赋值于config
|
||||
fmt.Printf("您正在使用命令行的-c参数传递的值,config的路径为%s\n", config)
|
||||
} else { // 命令行参数不为空 将值赋值于 config
|
||||
fmt.Printf("您正在使用命令行的 -c 参数传递的值,config 的路径为%s\n", config)
|
||||
}
|
||||
} else { // 函数传递的可变参数的第一个值赋值于config
|
||||
} else { // 函数传递的可变参数的第一个值赋值于 config
|
||||
config = path[0]
|
||||
fmt.Printf("您正在使用func Viper()传递的值,config的路径为%s\n", config)
|
||||
fmt.Printf("您正在使用 func Viper() 传递的值,config 的路径为%s\n", config)
|
||||
}
|
||||
|
||||
v := viper.New()
|
||||
@@ -82,7 +82,11 @@ func Viper(path ...string) *viper.Viper {
|
||||
// Extend: Override JWT signing key from environment variable after initial load
|
||||
overrideJWTSigningKeyFromEnv()
|
||||
|
||||
// root 适配性 根据root位置去找到对应迁移位置,保证root路径有效
|
||||
// Extend: Override database and redis configuration from environment variables
|
||||
// This allows admin-server to use the same configuration as docker-compose
|
||||
overrideAllFromEnv()
|
||||
|
||||
// root 适配性 根据 root 位置去找到对应迁移位置,保证 root 路径有效
|
||||
global.GVA_CONFIG.AutoCode.Root, _ = filepath.Abs("..")
|
||||
|
||||
return v
|
||||
|
||||
@@ -21,11 +21,12 @@ func newWithSeconds() *cron.Cron {
|
||||
|
||||
func Corn() {
|
||||
var lock bool
|
||||
initDBService := system.InitDBService{}
|
||||
c := newWithSeconds()
|
||||
// 每分钟同步一次用户列表
|
||||
if _, err := c.AddFunc("0 */1 * * * *", func() {
|
||||
if global.GVA_DB == nil {
|
||||
global.GVA_LOG.Info("【定时任务-每1分钟执行1次】同步用户列表任务,数据库没有初始化,暂未开始同步")
|
||||
if global.GVA_DB == nil || !initDBService.IfInit() {
|
||||
global.GVA_LOG.Info("【定时任务-每1分钟执行1次】同步用户列表任务,数据库没有初始化或尚未完成初始化,暂未开始同步")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -81,8 +81,58 @@ func Gorm() *gorm.DB {
|
||||
}
|
||||
|
||||
func RegisterTables(db *gorm.DB) {
|
||||
var err error
|
||||
var count int64
|
||||
var menu system.SysBaseMenuBtn
|
||||
var authority system.SysAuthority
|
||||
if err = global.GVA_DB.Model(&menu).Count(&count).Error; count == 0 {
|
||||
if err = global.GVA_DB.Model(&authority).Count(&count).Error; count == 1 {
|
||||
return
|
||||
}
|
||||
}
|
||||
// auto
|
||||
err = db.AutoMigrate(
|
||||
system.SysApi{},
|
||||
system.SysIgnoreApi{},
|
||||
system.SysUser{},
|
||||
system.SysBaseMenu{},
|
||||
system.JwtBlacklist{},
|
||||
system.SysAuthority{},
|
||||
system.SysDictionary{},
|
||||
system.SysOperationRecord{},
|
||||
system.SysAutoCodeHistory{},
|
||||
system.SysDictionaryDetail{},
|
||||
system.SysBaseMenuParameter{},
|
||||
system.SysBaseMenuBtn{},
|
||||
system.SysAuthorityBtn{},
|
||||
system.SysAutoCodePackage{},
|
||||
system.SysExportTemplate{},
|
||||
system.Condition{},
|
||||
system.JoinTemplate{},
|
||||
system.SysParams{},
|
||||
|
||||
err := db.AutoMigrate(tables)
|
||||
example.ExaFile{},
|
||||
example.ExaCustomer{},
|
||||
example.ExaFileChunk{},
|
||||
example.ExaFileUploadAndDownload{},
|
||||
|
||||
adapter.CasbinRule{},
|
||||
|
||||
// Extend gaia model
|
||||
gaia.AccountDingTalkExtend{},
|
||||
gaia.AppRequestTestBatch{},
|
||||
gaia.AppRequestTest{},
|
||||
gaia.SystemIntegration{}, // Extend System Integration
|
||||
gaia.ForwardingExtend{}, // Extend Forwarding Extend
|
||||
gaia.BatchWorkflow{}, // Extend Batch Workflow
|
||||
gaia.BatchWorkflowTask{}, // Extend Batch Workflow Task
|
||||
gaia.AppVersionConfig{}, // 应用版本全局配置(Token)
|
||||
gaia.AppVersionRelease{}, // 应用版本发布
|
||||
gaia.AppVersionDownload{}, // 应用版本各平台安装包
|
||||
gaia.ModelProviderConfig{}, // 模型提供商配置
|
||||
gaia.ModelProxyLog{}, // 模型中转请求日志
|
||||
system.SysUserGlobalCode{}, // Extend Global Code
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("register table failed", zap.Error(err))
|
||||
|
||||
@@ -22,6 +22,7 @@ func initBizRouter(routers ...*gin.RouterGroup) {
|
||||
gaiaRouter.InitSystemRouter(privateGroup)
|
||||
gaiaRouter.InitWorkflowRouter(privateGroup)
|
||||
gaiaRouter.InitAppVersionRouter(publicGroup, privateGroup)
|
||||
gaiaRouter.InitModelProviderRouter(privateGroup) // 模型提供商路由
|
||||
gaiaRouter.InitModelProviderRouter(privateGroup) // 模型提供商路由
|
||||
gaiaRouter.InitForwardProxyRouter(publicGroup) // GPT 转发代理(免 JWT)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,11 @@ func CasbinHandler() gin.HandlerFunc {
|
||||
// 获取用户的角色
|
||||
sub := strconv.Itoa(int(waitUse.AuthorityId))
|
||||
e := casbinService.Casbin() // 判断策略中是否存在
|
||||
if e == nil {
|
||||
global.GVA_LOG.Warn("Casbin enforcer is nil, skipping permission check")
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
success, _ := e.Enforce(sub, obj, act)
|
||||
if !success {
|
||||
response.FailWithDetailed(gin.H{}, "权限不足", c)
|
||||
|
||||
@@ -32,10 +32,32 @@ const (
|
||||
|
||||
// 批量工作流配置常量
|
||||
const (
|
||||
MaxTaskRetryCount = 3 // 最大任务重试次数
|
||||
ErrorPenaltyThreshold = 50 // 错误惩罚阈值(每50个错误减少1个并发位)
|
||||
MaxTaskRetryCount = 3 // 最大任务重试次数
|
||||
ErrorPenaltyThreshold = 50 // 错误惩罚阈值(每50个错误减少1个并发位)
|
||||
)
|
||||
|
||||
// 优先级分档阈值(按 total_rows 划分)
|
||||
const (
|
||||
PriorityTier1MaxRows = 300 // 第一优先级:≤300 行
|
||||
PriorityTier2MaxRows = 800 // 第二优先级:≤800 行
|
||||
PriorityTier3MaxRows = 3000 // 第三优先级:≤3000 行
|
||||
// 第四优先级:>3000 行
|
||||
)
|
||||
|
||||
// UserWorkerAllocation 用户工作器分配信息
|
||||
type UserWorkerAllocation struct {
|
||||
UserID uint `json:"user_id"`
|
||||
Workers int `json:"workers"`
|
||||
MaxLimit int `json:"max_limit"`
|
||||
}
|
||||
|
||||
// UserErrorInfo 用户错误信息(用于工作器分配计算)
|
||||
type UserErrorInfo struct {
|
||||
UserID uint `json:"user_id"`
|
||||
ErrorCount int `json:"error_count"`
|
||||
TotalRows int `json:"total_rows"`
|
||||
}
|
||||
|
||||
// BatchWorkflow 批量工作流处理
|
||||
type BatchWorkflow struct {
|
||||
ID string `json:"id" gorm:"primaryKey;comment:批量处理ID"`
|
||||
|
||||
@@ -2,6 +2,49 @@ package gaia
|
||||
|
||||
import "time"
|
||||
|
||||
// ModelPricing 从 Dify Console API 拉取的模型定价信息(对应 pricing 字段)
|
||||
type ModelPricing struct {
|
||||
Input float64 `json:"input"` // 每 unit 的输入单价
|
||||
Output float64 `json:"output"` // 每 unit 的输出单价(0 表示与 Input 相同或不区分)
|
||||
Unit float64 `json:"unit"` // 计费单位(通常 0.001,即每千 token)
|
||||
Currency string `json:"currency"` // 货币(USD / RMB)
|
||||
}
|
||||
|
||||
// ModelUsage OpenAI 格式响应中的 usage 字段(非流式及流式末尾行)
|
||||
type ModelUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
}
|
||||
|
||||
// ModelUsageResponse OpenAI 格式响应体(仅用于提取 usage 字段)
|
||||
type ModelUsageResponse struct {
|
||||
Usage *ModelUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// DifyModelPricingRaw Dify Console API 返回的原始定价字段(值为字符串形式的数字)
|
||||
type DifyModelPricingRaw struct {
|
||||
Input string `json:"input"`
|
||||
Output string `json:"output"`
|
||||
Unit string `json:"unit"`
|
||||
Currency string `json:"currency"`
|
||||
}
|
||||
|
||||
// DifyModelItem Dify Console API 返回的单个模型信息
|
||||
type DifyModelItem struct {
|
||||
Model string `json:"model"`
|
||||
Pricing *DifyModelPricingRaw `json:"pricing"`
|
||||
}
|
||||
|
||||
// DifyProviderModels Dify Console API 返回的单个 provider 下的模型列表
|
||||
type DifyProviderModels struct {
|
||||
Models []DifyModelItem `json:"models"`
|
||||
}
|
||||
|
||||
// DifyModelsResponse Dify Console API GET /models/model-types/llm 的响应结构
|
||||
type DifyModelsResponse struct {
|
||||
Data []DifyProviderModels `json:"data"`
|
||||
}
|
||||
|
||||
// ModelProviderConfig 模型提供商配置表
|
||||
type ModelProviderConfig struct {
|
||||
Id uint `json:"id" form:"id" gorm:"primarykey;column:id;comment:id;"`
|
||||
|
||||
@@ -9,6 +9,8 @@ const (
|
||||
ProviderAzure = "azure"
|
||||
ProviderZhipuai = "zhipuai"
|
||||
ProviderMinimax = "minimax"
|
||||
ProviderAWS = "aws" // AWS Bedrock 渠道(用于转发 Claude 等 Anthropic 模型)
|
||||
ProviderDeepSeek = "deepseek" // DeepSeek 渠道
|
||||
)
|
||||
|
||||
// DifyProviderTypeCustom Dify providers 表 provider_type 枚举
|
||||
@@ -16,23 +18,31 @@ const DifyProviderTypeCustom = "custom"
|
||||
|
||||
// 凭证配置中的 key 名
|
||||
const (
|
||||
ConfigKeyOpenaiAPIKey = "openai_api_key"
|
||||
ConfigKeyOpenaiAPIBase = "openai_api_base"
|
||||
ConfigKeyOpenaiAPIVersion = "openai_api_version"
|
||||
ConfigKeyDashScopeAPIKey = "dashscope_api_key"
|
||||
ConfigKeyAPIKey = "api_key"
|
||||
ConfigKeyOpenaiAPIKey = "openai_api_key"
|
||||
ConfigKeyOpenaiAPIBase = "openai_api_base"
|
||||
ConfigKeyOpenaiAPIVersion = "openai_api_version"
|
||||
ConfigKeyDashScopeAPIKey = "dashscope_api_key"
|
||||
ConfigKeyAPIKey = "api_key"
|
||||
|
||||
// AWS Bedrock 凭证字段(Dify bedrock provider 的 encrypted_config 中使用)
|
||||
ConfigKeyAWSAccessKeyID = "aws_access_key_id"
|
||||
ConfigKeyAWSSecretAccessKey = "aws_secret_access_key"
|
||||
ConfigKeyAWSSessionToken = "aws_session_token"
|
||||
ConfigKeyAWSRegion = "aws_region"
|
||||
ConfigKeyBedrockProxyURL = "bedrock_proxy_url" // 可选:HTTP 代理地址,格式 host:port 或 http://host:port
|
||||
)
|
||||
|
||||
// SupportedProviders 列表展示的提供商顺序
|
||||
var SupportedProviders = []string{ProviderOpenai, ProviderTongyi, ProviderGoogle, ProviderAnthropic, ProviderAzure, ProviderZhipuai, ProviderMinimax}
|
||||
var SupportedProviders = []string{ProviderOpenai, ProviderTongyi, ProviderGoogle, ProviderAnthropic, ProviderAWS, ProviderAzure, ProviderZhipuai, ProviderMinimax, ProviderDeepSeek}
|
||||
|
||||
// DefaultChatCompletionsEndpoints 各提供商聊天接口默认完整 URL(兼容旧 ProxyChat)
|
||||
var DefaultChatCompletionsEndpoints = map[string]string{
|
||||
ProviderOpenai: "https://api.openai.com/v1/chat/completions",
|
||||
ProviderTongyi: "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
|
||||
ProviderGoogle: "https://generativelanguage.googleapis.com/v1beta/chat/completions",
|
||||
ProviderZhipuai: "https://open.bigmodel.cn/api/paas/v4/chat/completions",
|
||||
ProviderMinimax: "https://api.minimax.chat/v1/text/chatcompletion_v2",
|
||||
ProviderOpenai: "https://api.openai.com/v1/chat/completions",
|
||||
ProviderTongyi: "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
|
||||
ProviderGoogle: "https://generativelanguage.googleapis.com/v1beta/chat/completions",
|
||||
ProviderZhipuai: "https://open.bigmodel.cn/api/paas/v4/chat/completions",
|
||||
ProviderMinimax: "https://api.minimax.chat/v1/text/chatcompletion_v2",
|
||||
ProviderDeepSeek: "https://api.deepseek.com/v1/chat/completions",
|
||||
// Azure 需要动态构建 URL,不使用默认值
|
||||
}
|
||||
|
||||
@@ -42,10 +52,90 @@ var DefaultAPIBase = map[string]string{
|
||||
ProviderTongyi: "https://dashscope.aliyuncs.com/compatible-mode",
|
||||
ProviderGoogle: "https://generativelanguage.googleapis.com",
|
||||
ProviderAnthropic: "https://api.anthropic.com",
|
||||
ProviderAWS: "https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
ProviderZhipuai: "https://open.bigmodel.cn",
|
||||
ProviderMinimax: "https://api.minimax.chat",
|
||||
ProviderDeepSeek: "https://api.deepseek.com",
|
||||
// Azure 的 base URL 来自 openai_api_base 配置,不设置默认值
|
||||
}
|
||||
|
||||
// CredentialKeyFallback 未知提供商时依次尝试的配置 key
|
||||
var CredentialKeyFallback = []string{ConfigKeyOpenaiAPIKey, ConfigKeyAPIKey, ConfigKeyDashScopeAPIKey}
|
||||
|
||||
// RmbToUSDRate 人民币兑美元汇率
|
||||
const RmbToUSDRate = 7.26
|
||||
|
||||
// DefaultImageGenerationPriceUSD 图片生成等按次计费接口的默认单价(USD),无 usage 时使用
|
||||
const DefaultImageGenerationPriceUSD = 0.04
|
||||
|
||||
// DefaultQuotaFallbackUSDPerToken 未命中定价时的兜底单价:每 token 的 USD 金额(仅做记账占位,约 $0.001/千 token)
|
||||
const DefaultQuotaFallbackUSDPerToken = 0.000001
|
||||
|
||||
// Gaia 相关 Redis Key(GVA_REDIS / GVA_Dify_REDIS)
|
||||
const (
|
||||
RedisKeyGaiaAdminConsoleToken = "gaia:admin_console_token"
|
||||
RedisKeyGaiaModelPricingPrefix = "gaia:model_pricing:"
|
||||
RedisKeyGaiaForwardDingPrefix = "gaia:forward:ding:"
|
||||
RedisKeyModelProviderCredentialsPrefix = "model_provider_credentials:"
|
||||
)
|
||||
|
||||
// BuiltinModelPricing 内置兜底定价表(当 Dify Console API 未返回该模型定价时使用)。
|
||||
// 价格单位:每千 token(与 ModelPricing.Unit=0.001 对应),货币为各模型实际结算货币。
|
||||
// 通义/百炼模型官方定价(人民币,参考 https://help.aliyun.com/document_detail/2586379.html):
|
||||
// - 输入/输出价格均为「每百万 token」,换算为每千 token 时除以 1000。
|
||||
var BuiltinModelPricing = map[string]ModelPricing{
|
||||
// ──── 通义千问 Qwen3 系列(RMB / 百万 token,128K 档) ────
|
||||
"qwen3-235b-a22b": {Input: 0.4 / 1000, Output: 1.6 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen3-30b-a3b": {Input: 0.11 / 1000, Output: 0.44 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen3-32b": {Input: 0.8 / 1000, Output: 3.2 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen3-14b": {Input: 0.3 / 1000, Output: 1.2 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen3-8b": {Input: 0.1 / 1000, Output: 0.4 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen3-4b": {Input: 0.04 / 1000, Output: 0.16 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen3-1.7b": {Input: 0.02 / 1000, Output: 0.08 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen3-0.6b": {Input: 0.01 / 1000, Output: 0.04 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
|
||||
// ──── 通义千问 Qwen3.5 系列(RMB / 百万 token,128K 档) ────
|
||||
"qwen3.5-plus": {Input: 0.8 / 1000, Output: 4.8 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen3.5-turbo": {Input: 0.3 / 1000, Output: 1.2 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
|
||||
// ──── 通义千问 Qwen2.5 系列(RMB / 百万 token) ────
|
||||
"qwen2.5-72b-instruct": {Input: 4.0 / 1000, Output: 12.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen2.5-32b-instruct": {Input: 3.5 / 1000, Output: 7.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen2.5-14b-instruct": {Input: 2.0 / 1000, Output: 6.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen2.5-7b-instruct": {Input: 1.0 / 1000, Output: 2.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen2.5-3b-instruct": {Input: 0.3 / 1000, Output: 0.6 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen-plus": {Input: 0.8 / 1000, Output: 2.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen-turbo": {Input: 0.3 / 1000, Output: 0.6 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen-max": {Input: 40.0 / 1000, Output: 120.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"qwen-long": {Input: 0.5 / 1000, Output: 2.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
|
||||
// ──── 月之暗面 Kimi 系列(RMB / 百万 token) ────
|
||||
// Kimi 走 tongyi(百炼)渠道转发,命名沿用 kimi 前缀;前缀匹配会让 kimi2-k2.6-xxx 也命中 kimi2-k2.6
|
||||
"kimi2-k2.6": {Input: 4.0 / 1000, Output: 16.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"moonshot-v1-8k": {Input: 12.0 / 1000, Output: 12.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"moonshot-v1-32k": {Input: 24.0 / 1000, Output: 24.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
"moonshot-v1-128k": {Input: 60.0 / 1000, Output: 60.0 / 1000, Unit: 0.001, Currency: "RMB"},
|
||||
|
||||
// ──── Anthropic Claude 系列(USD / 百万 token) ────
|
||||
// Claude 4.6 / 4.7 系列(Sonnet 与 Opus);anthropic 直连与 AWS Bedrock 走同一份定价
|
||||
"claude-sonnet-4-6": {Input: 3.0 / 1000, Output: 15.0 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
"claude-sonnet-4-7": {Input: 3.0 / 1000, Output: 15.0 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
"claude-opus-4-6": {Input: 15.0 / 1000, Output: 75.0 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
"claude-opus-4-7": {Input: 15.0 / 1000, Output: 75.0 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
// AWS Bedrock 上常用的模型 ID 形式(带 anthropic. 前缀与 -v1:0 后缀),单独列出避免前缀匹配漂移
|
||||
"anthropic.claude-sonnet-4-6-v1:0": {Input: 3.0 / 1000, Output: 15.0 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
"anthropic.claude-sonnet-4-7-v1:0": {Input: 3.0 / 1000, Output: 15.0 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
"anthropic.claude-opus-4-6-v1:0": {Input: 15.0 / 1000, Output: 75.0 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
"anthropic.claude-opus-4-7-v1:0": {Input: 15.0 / 1000, Output: 75.0 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
|
||||
// ──── OpenAI 图片生成(按次计费,Input 字段表示「每次请求的 USD 单价」) ────
|
||||
// 命中分支见 service/gaia/model_provider.go 中的 isImageOrPerRequestPath 与 ProxyRequest 计费逻辑
|
||||
"gpt-image-1": {Input: 0.04, Currency: "USD"},
|
||||
"gpt-image-2": {Input: 0.05, Currency: "USD"},
|
||||
|
||||
// ──── DeepSeek 系列(USD / 百万 token) ────
|
||||
// deepseek-v4-pro:旗舰推理模型,定价参考官方 https://platform.deepseek.com/api-docs/pricing
|
||||
"deepseek-v4-pro": {Input: 2.19 / 1000, Output: 8.19 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
// deepseek-v4-flash:高速轻量模型
|
||||
"deepseek-v4-flash": {Input: 0.27 / 1000, Output: 1.10 / 1000, Unit: 0.001, Currency: "USD"},
|
||||
}
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
package request
|
||||
|
||||
// GetProxyLogsReq 代理日志分页请求
|
||||
type GetProxyLogsReq struct {
|
||||
Page int `form:"page"` // 页码,从 1 开始
|
||||
PageSize int `form:"page_size"` // 每页条数,最大 100
|
||||
}
|
||||
|
||||
// ChatRequest 聊天请求(OpenAI 兼容)
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package request
|
||||
|
||||
import "time"
|
||||
|
||||
// SystemOAuth2Error OAuth2 错误返回
|
||||
type SystemOAuth2Error struct {
|
||||
Code int `json:"code" gorm:"comment:分类"` // 错误代码
|
||||
@@ -8,56 +10,102 @@ type SystemOAuth2Error struct {
|
||||
|
||||
// SystemOAuth2Request OAuth2 集成配置
|
||||
type SystemOAuth2Request struct {
|
||||
Classify uint `json:"classify" gorm:"comment:分类"` // 分类
|
||||
Status bool `json:"status" gorm:"comment:状态"` // 状态
|
||||
ServerURL string `json:"server_url" gorm:"comment:服务器地址"` // OAuth2 服务器地址
|
||||
AuthorizeURL string `json:"authorize_url" gorm:"comment:申请认证的URL"` // 申请认证的URL
|
||||
TokenURL string `json:"token_url" gorm:"comment:获取Token的URL"` // 获取Token的URL
|
||||
UserinfoURL string `json:"userinfo_url" gorm:"comment:获取用户信息URL"` // 获取用户信息的URL
|
||||
LogoutURL string `json:"logout_url" gorm:"comment:退出登录回调URL"` // 退出登录回调URL
|
||||
DiscoveryURL string `json:"discovery_url" gorm:"comment:OIDC发现配置URL"` // OIDC 发现配置URL
|
||||
AppID string `json:"app_id" gorm:"comment:Client ID"` // Client ID
|
||||
AppSecret string `json:"app_secret" gorm:"comment:Client Secret"` // Client Secret
|
||||
UserNameField string `json:"user_name_field" gorm:"comment:用户名字段"` // 用户名字段
|
||||
UserEmailField string `json:"user_email_field" gorm:"comment:邮箱字段"` // 邮箱字段
|
||||
UserIDField string `json:"user_id_field" gorm:"comment:用户唯一标识字段"` // 用户唯一标识字段
|
||||
Scope string `json:"scope" gorm:"comment:授权范围scope"` // 授权范围
|
||||
TokenAuthMethod string `json:"token_auth_method" gorm:"comment:令牌端点认证方式"` // client_secret_post|client_secret_basic
|
||||
RedirectUri string `json:"redirect_uri" gorm:"comment:测试用回调地址"` // 测试用回调地址
|
||||
Test bool `json:"test" gorm:"default:0;comment:是否测试链接联通性"` // 是否测试链接联通性
|
||||
Code string `json:"code" gorm:"default:0;comment:code代码"` // code代码
|
||||
Classify uint `json:"classify" gorm:"comment:分类"` // 分类
|
||||
Status bool `json:"status" gorm:"comment:状态"` // 状态
|
||||
ServerURL string `json:"server_url" gorm:"comment:服务器地址"` // OAuth2 服务器地址
|
||||
AuthorizeURL string `json:"authorize_url" gorm:"comment:申请认证的 URL"` // 申请认证的 URL
|
||||
TokenURL string `json:"token_url" gorm:"comment:获取 Token 的 URL"` // 获取 Token 的 URL
|
||||
UserinfoURL string `json:"userinfo_url" gorm:"comment:获取用户信息 URL"` // 获取用户信息的 URL
|
||||
LogoutURL string `json:"logout_url" gorm:"comment:退出登录回调 URL"` // 退出登录回调 URL
|
||||
DiscoveryURL string `json:"discovery_url" gorm:"comment:OIDC 发现配置 URL"` // OIDC 发现配置 URL
|
||||
AppID string `json:"app_id" gorm:"comment:Client ID"` // Client ID
|
||||
AppSecret string `json:"app_secret" gorm:"comment:Client Secret"` // Client Secret
|
||||
UserNameField string `json:"user_name_field" gorm:"comment:用户名字段"` // 用户名字段
|
||||
UserEmailField string `json:"user_email_field" gorm:"comment:邮箱字段"` // 邮箱字段
|
||||
UserIDField string `json:"user_id_field" gorm:"comment:用户唯一标识字段"` // 用户唯一标识字段
|
||||
Scope string `json:"scope" gorm:"comment:授权范围 scope"` // 授权范围
|
||||
TokenAuthMethod string `json:"token_auth_method" gorm:"comment:令牌端点认证方式"` // client_secret_post|client_secret_basic
|
||||
RedirectUri string `json:"redirect_uri" gorm:"comment:测试用回调地址"` // 测试用回调地址
|
||||
Test bool `json:"test" gorm:"default:0;comment:是否测试链接联通性"` // 是否测试链接联通性
|
||||
Code string `json:"code" gorm:"default:0;comment:code 代码"` // code 代码
|
||||
}
|
||||
|
||||
// ValueType 参数/字段值类型
|
||||
const (
|
||||
ValueTypeString = "string" // 字符串类型
|
||||
ValueTypeInt = "int" // 整数类型
|
||||
ValueTypeBool = "bool" // 布尔类型
|
||||
ValueTypeDingID = "ding_id" // 钉钉 ID 类型(运行时自动替换)
|
||||
)
|
||||
|
||||
// DingIDMarker Raw 模式下钉钉 ID 占位符
|
||||
const DingIDMarker = "$<{[ding_id]}>"
|
||||
|
||||
// AuthorizationConfig 认证配置
|
||||
type AuthorizationConfig struct {
|
||||
Type string `json:"type"` // none | bearer | basic
|
||||
Token string `json:"token"` // Bearer Token
|
||||
Username string `json:"username"` // Basic Auth用户名
|
||||
Password string `json:"password"` // Basic Auth密码
|
||||
Username string `json:"username"` // Basic Auth 用户名
|
||||
Password string `json:"password"` // Basic Auth 密码
|
||||
}
|
||||
|
||||
// BodyData Body数据配置
|
||||
// RequestParam URL 查询参数配置
|
||||
type RequestParam struct {
|
||||
Key string `json:"key"` // 参数名
|
||||
ValueType string `json:"value_type"` // string | int | bool | ding_id
|
||||
Value string `json:"value"` // 参数值(ding_id 类型时运行时自动替换)
|
||||
}
|
||||
|
||||
// BodyField Body 字段配置(支持类型化)
|
||||
type BodyField struct {
|
||||
Key string `json:"key"` // 字段名
|
||||
ValueType string `json:"value_type"` // string | int | bool | ding_id
|
||||
Value string `json:"value"` // 字段值
|
||||
}
|
||||
|
||||
// BodyData Body 数据配置
|
||||
type BodyData struct {
|
||||
FormData []map[string]string `json:"form_data"` // form-data格式数据
|
||||
Urlencoded []map[string]string `json:"urlencoded"` // x-www-form-urlencoded格式数据
|
||||
Raw string `json:"raw"` // raw JSON字符串
|
||||
FormData []BodyField `json:"form_data"` // form-data 格式数据(新格式)
|
||||
Urlencoded []BodyField `json:"urlencoded"` // x-www-form-urlencoded 格式数据(新格式)
|
||||
Raw string `json:"raw"` // raw JSON 字符串
|
||||
}
|
||||
|
||||
// EmailApiConfig 第三方邮箱API配置
|
||||
// EmailApiConfig 第三方邮箱 API 配置
|
||||
type EmailApiConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用
|
||||
URL string `json:"url"` // API地址
|
||||
Method string `json:"method"` // HTTP方法
|
||||
RequestParamField string `json:"request_param_field"` // 请求参数字段名
|
||||
BodyType string `json:"body_type"` // Body类型: form-data | x-www-form-urlencoded | raw
|
||||
Headers map[string]string `json:"headers"` // 请求头
|
||||
Authorization AuthorizationConfig `json:"authorization"` // 认证配置
|
||||
BodyData BodyData `json:"body_data"` // Body数据
|
||||
ResponseEmailField string `json:"response_email_field"` // 响应邮箱字段路径
|
||||
Enabled bool `json:"enabled"` // 是否启用
|
||||
URL string `json:"url"` // API 地址
|
||||
Method string `json:"method"` // HTTP 方法
|
||||
RequestParamField string `json:"request_param_field"` // 请求参数字段名(旧格式兼容)
|
||||
Params []RequestParam `json:"params"` // URL 查询参数列表(新格式)
|
||||
BodyType string `json:"body_type"` // Body 类型:form-data | x-www-form-urlencoded | raw
|
||||
Headers map[string]string `json:"headers"` // 请求头
|
||||
Authorization AuthorizationConfig `json:"authorization"` // 认证配置
|
||||
BodyData BodyData `json:"body_data"` // Body 数据
|
||||
ResponseEmailField string `json:"response_email_field"` // 响应邮箱字段路径
|
||||
}
|
||||
|
||||
// TestEmailApiConfigRequest 测试邮箱 API 配置请求
|
||||
type TestEmailApiConfigRequest struct {
|
||||
Config EmailApiConfig `json:"config"` // 完整的邮箱配置
|
||||
TestDingID string `json:"test_ding_id"` // 测试用的钉钉 ID(可选)
|
||||
}
|
||||
|
||||
// ForwardToken 转发 Token 配置
|
||||
type ForwardToken struct {
|
||||
ID string `json:"id"` // 前端生成的唯一 ID(用于删除)
|
||||
TokenHash string `json:"token_hash"` // SHA256(token)
|
||||
CreatedAt time.Time `json:"created_at"` // 创建时间
|
||||
TokenSecret string `json:"token_secret"` // HMAC 签名密钥(随机生成,服务端保存)
|
||||
}
|
||||
|
||||
// ForwardConfig 转发集成配置
|
||||
type ForwardConfig struct {
|
||||
Enabled bool `json:"enabled"` // 是否启用转发
|
||||
Tokens []ForwardToken `json:"tokens"` // Token 列表,最多 20 个
|
||||
}
|
||||
|
||||
// DingTalkConfigRequest 钉钉集成配置
|
||||
type DingTalkConfigRequest struct {
|
||||
EmailApi EmailApiConfig `json:"email_api"` // 第三方邮箱API配置
|
||||
EmailApi EmailApiConfig `json:"email_api"` // 第三方邮箱 API 配置
|
||||
ForwardConfig ForwardConfig `json:"forward_config"` // 转发集成配置
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,29 @@
|
||||
package response
|
||||
|
||||
// AppQuotaRankingRow 应用配额排名查询单行(仅用于 service 层 GORM 查询扫描)
|
||||
type AppQuotaRankingRow struct {
|
||||
AppID string `gorm:"column:app_id"`
|
||||
TotalCost float64 `gorm:"column:total_cost"`
|
||||
MessageCost float64 `gorm:"column:message_cost"`
|
||||
WorkflowCost float64 `gorm:"column:workflow_cost"`
|
||||
RecordNum float64 `gorm:"column:record_num"`
|
||||
}
|
||||
|
||||
// AppQuotaRankingCache 应用配额排名缓存结构(List + Total)
|
||||
type AppQuotaRankingCache struct {
|
||||
List []GetAppQuotaRankingDataRes
|
||||
Total int64
|
||||
}
|
||||
|
||||
// AiImageQuotaRankingRow AI 图片使用量排名查询单行(仅用于 service 层 GORM 查询扫描)
|
||||
type AiImageQuotaRankingRow struct {
|
||||
Address string `gorm:"column:address"`
|
||||
Path string `gorm:"column:path"`
|
||||
TotalCost float64 `gorm:"column:total_cost"`
|
||||
RecordNum int `gorm:"column:record_num"`
|
||||
Model string `gorm:"column:model"`
|
||||
}
|
||||
|
||||
// GetAccountQuotaRankingDataRes 获取账户配额排名数据的响应结构
|
||||
type GetAccountQuotaRankingDataRes struct {
|
||||
Ranking int `json:"ranking"` // 排名
|
||||
|
||||
@@ -5,6 +5,14 @@ type ProviderCredentials struct {
|
||||
APIKey string `json:"api_key"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
APIVersion string `json:"api_version,omitempty"` // Azure OpenAI API 版本
|
||||
|
||||
// AWS Bedrock 直连用:access key + secret + region(不走 APIKey/Endpoint)
|
||||
AWSAccessKeyID string `json:"aws_access_key_id,omitempty"`
|
||||
AWSSecretAccessKey string `json:"aws_secret_access_key,omitempty"`
|
||||
AWSSessionToken string `json:"aws_session_token,omitempty"`
|
||||
AWSRegion string `json:"aws_region,omitempty"`
|
||||
// Bedrock 可选代理地址(host:port 或 http://host:port),非空时请求经该代理转发到 AWS
|
||||
BedrockProxyURL string `json:"bedrock_proxy_url,omitempty"`
|
||||
}
|
||||
|
||||
// ModelInfo 模型信息
|
||||
@@ -40,10 +48,10 @@ type OpenAIModelsListResponse struct {
|
||||
type TongyiModelsListResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Output struct {
|
||||
Total int `json:"total"`
|
||||
PageNo int `json:"page_no"`
|
||||
PageSize int `json:"page_size"`
|
||||
Models []TongyiModelItem `json:"models"`
|
||||
Total int `json:"total"`
|
||||
PageNo int `json:"page_no"`
|
||||
PageSize int `json:"page_size"`
|
||||
Models []TongyiModelItem `json:"models"`
|
||||
} `json:"output"`
|
||||
}
|
||||
|
||||
@@ -55,8 +63,8 @@ type TongyiModelItem struct {
|
||||
|
||||
// GeminiModelsListResponse Google Gemini GET /v1beta/models 返回:models[] + nextPageToken
|
||||
type GeminiModelsListResponse struct {
|
||||
Models []GeminiModelItem `json:"models"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
Models []GeminiModelItem `json:"models"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
// GeminiModelItem Gemini 模型单项,name 为 "models/gemini-xxx",baseModelId 用于请求
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
package response
|
||||
|
||||
import "time"
|
||||
|
||||
type CheckAccountQuotaRow struct {
|
||||
TotalQuota float64 `gorm:"column:total_quota"`
|
||||
UsedQuota float64 `gorm:"column:used_quota"`
|
||||
}
|
||||
|
||||
// ForwardTokenInfo 转发 Token 列表项(不暴露内部 ID)
|
||||
type ForwardTokenInfo struct {
|
||||
ID string `json:"id"` // token
|
||||
Seq int `json:"seq"` // 1..N 序列号(用于删除)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// ForwardTokensResponse 获取转发 Token 列表响应
|
||||
type ForwardTokensResponse struct {
|
||||
Tokens []ForwardTokenInfo `json:"tokens"`
|
||||
Count int `json:"count"`
|
||||
Max int `json:"max"`
|
||||
}
|
||||
|
||||
// TestEmailApiConfigResponse 测试邮箱 API 配置响应
|
||||
type TestEmailApiConfigResponse struct {
|
||||
StatusCode int `json:"status_code"` // HTTP 状态码
|
||||
Body interface{} `json:"body"` // 响应 Body(JSON 时为对象,否则为字符串)
|
||||
EmailFieldPreview string `json:"email_field_preview"` // 邮箱字段解析预览(如 data[0].userName = test@example.com)
|
||||
IsValid bool `json:"is_valid"` // 配置是否有效(能正确提取邮箱)
|
||||
ErrorMessage string `json:"error_message,omitempty"` // 错误信息(可选)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ const SystemIntegrationDingTalk = uint(1) // 钉钉集成
|
||||
const SystemIntegrationWeiXin = uint(2) // 微信集成
|
||||
const SystemIntegrationFeiShu = uint(3) // 飞书集成
|
||||
const SystemIntegrationOAuth2 = uint(4) // OAuth2集成
|
||||
const BearerLength = 7 // OAuth2集成
|
||||
|
||||
// SystemIntegration 系统集成表
|
||||
type SystemIntegration struct {
|
||||
|
||||
@@ -23,3 +23,4 @@ var testApi = api.ApiGroupApp.GaiaApiGroup.TestApi
|
||||
var batchWorkflowApi = api.ApiGroupApp.GaiaApiGroup.BatchWorkflowApi
|
||||
var appVersionApi = api.ApiGroupApp.GaiaApiGroup.AppVersionApi
|
||||
var modelProviderApi = api.ApiGroupApp.GaiaApiGroup.ModelProviderApi
|
||||
var forwardProxyApi = api.ApiGroupApp.GaiaApiGroup.ForwardProxyApi
|
||||
|
||||
@@ -10,16 +10,30 @@ type SystemRouter struct{}
|
||||
func (s *SystemRouter) InitSystemRouter(Router *gin.RouterGroup) {
|
||||
systemRouter := Router.Group("gaia/system")
|
||||
{
|
||||
systemRouter.GET("dingtalk", systemApi.GetDingTalk) // 获取钉钉系统配置
|
||||
systemRouter.POST("dingtalk", systemApi.SetDingTalk) // 设置钉钉系统配置
|
||||
systemRouter.GET("oauth2", systemOAuth2Api.GetOAuth2Config) // 获取OAuth2配置
|
||||
systemRouter.POST("oauth2", systemOAuth2Api.SetOAuth2Config) // 设置OAuth2配置
|
||||
systemRouter.GET("dingtalk", systemApi.GetDingTalk) // 获取钉钉系统配置
|
||||
systemRouter.POST("dingtalk", systemApi.SetDingTalk) // 设置钉钉系统配置
|
||||
systemRouter.GET("dingtalk/test-auth-url", systemApi.GetDingTalkTestAuthURL) // 测试连接:获取钉钉授权 URL
|
||||
systemRouter.POST("dingtalk/test-callback", systemApi.DingTalkTestCallback) // 测试连接:回调验证 code
|
||||
systemRouter.GET("oauth2", systemOAuth2Api.GetOAuth2Config) // 获取 OAuth2 配置
|
||||
systemRouter.POST("oauth2", systemOAuth2Api.SetOAuth2Config) // 设置 OAuth2 配置
|
||||
// 邮箱 API 配置测试
|
||||
systemRouter.POST("dingtalk/test-email-config", systemApi.TestEmailApiConfig) // 测试第三方邮箱 API 配置
|
||||
// 转发 Token 管理
|
||||
systemRouter.GET("forward-tokens", systemApi.GetForwardTokens) // 获取转发 Token 列表
|
||||
systemRouter.POST("forward-tokens", systemApi.CreateForwardToken) // 新增转发 Token
|
||||
systemRouter.DELETE("forward-tokens/:seq", systemApi.DeleteForwardToken) // 删除转发 Token(按序列号)
|
||||
}
|
||||
}
|
||||
|
||||
// InitForwardProxyRouter 初始化 GPT 转发代理路由
|
||||
func (s *SystemRouter) InitForwardProxyRouter(PublicRouter *gin.RouterGroup) {
|
||||
// 免 JWT 转发入口,通过 forwarding token + ding_id 鉴权
|
||||
PublicRouter.Any("gaia/forward/proxy/*path", forwardProxyApi.ForwardProxy)
|
||||
}
|
||||
|
||||
// InitModelProviderRouter 初始化模型提供商路由
|
||||
func (s *SystemRouter) InitModelProviderRouter(Router *gin.RouterGroup) {
|
||||
// 管理端API(需要JWT认证)
|
||||
// 管理端 API(需要 JWT 认证)
|
||||
modelProviderRouter := Router.Group("gaia/model-provider")
|
||||
{
|
||||
modelProviderRouter.GET("list", modelProviderApi.GetProviderList) // 获取提供商配置列表
|
||||
@@ -29,10 +43,10 @@ func (s *SystemRouter) InitModelProviderRouter(Router *gin.RouterGroup) {
|
||||
modelProviderRouter.GET("logs", modelProviderApi.GetProxyLogs) // 获取代理日志
|
||||
}
|
||||
|
||||
// 第三方API(需要JWT认证)
|
||||
// 第三方 API(需要 JWT 认证)
|
||||
gaiaRouter := Router.Group("gaia")
|
||||
{
|
||||
gaiaRouter.GET("models", modelProviderApi.GetModels) // 获取开启的模型列表(OpenAI格式)
|
||||
gaiaRouter.Any("proxy/*path", modelProviderApi.Proxy) // 通用中转API:按路径转发(v1/chat/completions、v1/messages、v1/images/generations、v1/embeddings 等)
|
||||
gaiaRouter.GET("models", modelProviderApi.GetModels) // 获取开启的模型列表(OpenAI 格式)
|
||||
gaiaRouter.Any("proxy/*path", modelProviderApi.Proxy) // 通用中转 API:按路径转发(v1/chat/completions、v1/messages、v1/images/generations、v1/embeddings 等)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,31 @@ func (s *BatchWorkflowService) CreateBatchWorkflow(
|
||||
return nil, fmt.Errorf("数据库连接未初始化")
|
||||
}
|
||||
|
||||
// 计算本次上传的有效数据行数(去掉表头和空行)
|
||||
uploadedDataRows := 0
|
||||
if len(fileContent) > 1 {
|
||||
for _, row := range fileContent[1:] {
|
||||
for _, v := range row {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
uploadedDataRows++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查当前用户 pending 状态队列中是否已有相同文件(文件名 + 行数一致视为重复)
|
||||
var duplicateCount int64
|
||||
if err := global.GVA_DB.Model(&gaia.BatchWorkflow{}).Where(
|
||||
"user_id = ? AND file_name = ? AND installed_id = ? AND total_rows = ? AND status = ?",
|
||||
userId, fileName, installedID, uploadedDataRows, gaia.BatchWorkflowStatusPending).
|
||||
Count(&duplicateCount).Error; err != nil {
|
||||
return nil, fmt.Errorf("检查重复文件失败: %v", err)
|
||||
}
|
||||
if duplicateCount > 0 {
|
||||
return nil, fmt.Errorf("文件重复上传:当前队列中已存在相同文件(文件名:%s,行数:%d),请勿重复提交", fileName, uploadedDataRows)
|
||||
}
|
||||
|
||||
// 创建批量处理记录
|
||||
keyByte, _ := json.Marshal(keyNameMapping)
|
||||
batchWorkflow := &gaia.BatchWorkflow{
|
||||
@@ -363,8 +388,8 @@ func (s *BatchWorkflowService) callDifyAPI(
|
||||
}
|
||||
// Extend End: 添加CSRF token支持
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{}
|
||||
// 发送请求(设置超时,避免 Dify 卡住时 goroutine 与连接长期占用不释放)
|
||||
client := &http.Client{Timeout: 5 * time.Minute}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -0,0 +1,399 @@
|
||||
package gaia
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
|
||||
estream "github.com/aws/aws-sdk-go/private/protocol/eventstream"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
gaiaResponse "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// proxyBedrockRequest 直连 AWS Bedrock 原生 API 转发 Anthropic Messages 请求。
|
||||
//
|
||||
// 路径转换:v1/messages → model/{modelId}/invoke 或 model/{modelId}/invoke-with-response-stream
|
||||
// 鉴权:SigV4(service=bedrock,region 来自 Dify 凭证 aws_region 字段)
|
||||
// 请求体改写:去掉 model 字段(Bedrock 走 URL 路径),注入 anthropic_version=bedrock-2023-05-31
|
||||
// 响应:
|
||||
// - 非流式:Bedrock 返回 Anthropic Messages JSON(含 usage.input_tokens/output_tokens),原样转发
|
||||
// - 流式:vnd.amazon.eventstream 二进制帧,每帧 payload 是 {"bytes":"<base64>"},
|
||||
// 解出后是 Anthropic SSE 事件 JSON,在此重组为标准 SSE 写回客户端
|
||||
//
|
||||
// 计费:成功后按 (input_tokens, output_tokens) 调 calcQuotaDelta 扣额。
|
||||
func (s *ModelProviderService) proxyBedrockRequest(
|
||||
userID, _ /* path */, method string, _ /* reqHeader */ http.Header, body []byte, writer io.Writer,
|
||||
creds *gaiaResponse.ProviderCredentials,
|
||||
) error {
|
||||
// 1) 校验 AWS 凭证
|
||||
if creds == nil || creds.AWSAccessKeyID == "" || creds.AWSSecretAccessKey == "" {
|
||||
return fmt.Errorf("AWS Bedrock 凭证缺失(需要 aws_access_key_id / aws_secret_access_key)")
|
||||
}
|
||||
region := creds.AWSRegion
|
||||
if region == "" {
|
||||
region = "us-east-1"
|
||||
}
|
||||
|
||||
// 2) 解析 body:拿到 modelId 与 stream 标记,并改写为 Bedrock 期望的格式
|
||||
if len(body) == 0 {
|
||||
return fmt.Errorf("Bedrock 请求 body 不能为空")
|
||||
}
|
||||
var bodyObj map[string]interface{}
|
||||
if err := json.Unmarshal(body, &bodyObj); err != nil {
|
||||
return fmt.Errorf("解析 Bedrock 请求 body 失败:%w", err)
|
||||
}
|
||||
modelID, _ := bodyObj["model"].(string)
|
||||
if modelID == "" {
|
||||
return fmt.Errorf("Bedrock 请求 body 缺少 model 字段")
|
||||
}
|
||||
streaming := false
|
||||
if v, ok := bodyObj["stream"].(bool); ok {
|
||||
streaming = v
|
||||
}
|
||||
// 移除 model(Bedrock 不需要);删除 stream(流式由 URL 决定)
|
||||
delete(bodyObj, "model")
|
||||
delete(bodyObj, "stream")
|
||||
delete(bodyObj, "stream_options")
|
||||
// OpenAI 兼容字段转换:max_completion_tokens → max_tokens(Bedrock/Anthropic 使用 max_tokens)
|
||||
if _, hasMaxTokens := bodyObj["max_tokens"]; !hasMaxTokens {
|
||||
if v, ok := bodyObj["max_completion_tokens"]; ok {
|
||||
bodyObj["max_tokens"] = v
|
||||
delete(bodyObj, "max_completion_tokens")
|
||||
}
|
||||
} else {
|
||||
delete(bodyObj, "max_completion_tokens")
|
||||
}
|
||||
// 注入 Bedrock 必需的 anthropic_version
|
||||
if _, ok := bodyObj["anthropic_version"]; !ok {
|
||||
bodyObj["anthropic_version"] = "bedrock-2023-05-31"
|
||||
}
|
||||
rewritten, err := json.Marshal(bodyObj)
|
||||
if err != nil {
|
||||
return fmt.Errorf("重写 Bedrock 请求 body 失败:%w", err)
|
||||
}
|
||||
|
||||
// 3) 构建 Bedrock URL
|
||||
// 新一代 Claude 模型(3.5v2、3.7、Sonnet-4、Opus-4 等)要求通过跨区域推理配置文件调用,
|
||||
// 模型 ID 需加地理前缀(us. / eu. / ap.),否则 Bedrock 返回 "on-demand throughput isn't supported" 错误。
|
||||
// 若调用方已传入带前缀的 ID(如 us.anthropic.xxx)则直接使用,不重复添加。
|
||||
invokeModelID := bedrockResolveModelID(modelID, region)
|
||||
if invokeModelID != modelID {
|
||||
global.GVA_LOG.Info("Bedrock 模型 ID 已映射为跨区域推理配置文件",
|
||||
zap.String("original", modelID),
|
||||
zap.String("resolved", invokeModelID),
|
||||
zap.String("region", region),
|
||||
)
|
||||
}
|
||||
host := fmt.Sprintf("bedrock-runtime.%s.amazonaws.com", region)
|
||||
op := "invoke"
|
||||
if streaming {
|
||||
op = "invoke-with-response-stream"
|
||||
}
|
||||
requestURL := fmt.Sprintf("https://%s/model/%s/%s", host, url.PathEscape(invokeModelID), op)
|
||||
|
||||
// 打印请求地址、参数和代理地址
|
||||
global.GVA_LOG.Info("Bedrock 请求详情",
|
||||
zap.String("request_url", requestURL),
|
||||
zap.String("method", method),
|
||||
zap.ByteString("body", rewritten),
|
||||
zap.String("proxy_url", creds.BedrockProxyURL),
|
||||
)
|
||||
|
||||
httpReq, err := http.NewRequest(method, requestURL, bytes.NewReader(rewritten))
|
||||
if err != nil {
|
||||
return fmt.Errorf("构建 Bedrock 请求失败:%w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if streaming {
|
||||
httpReq.Header.Set("Accept", "application/vnd.amazon.eventstream")
|
||||
httpReq.Header.Set("X-Amzn-Bedrock-Accept", "application/json")
|
||||
} else {
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
// 4) SigV4 签名(service=bedrock)
|
||||
awsCreds := credentials.NewStaticCredentials(creds.AWSAccessKeyID, creds.AWSSecretAccessKey, creds.AWSSessionToken)
|
||||
signer := v4.NewSigner(awsCreds)
|
||||
if _, err = signer.Sign(httpReq, bytes.NewReader(rewritten), "bedrock", region, time.Now()); err != nil {
|
||||
return fmt.Errorf("Bedrock SigV4 签名失败:%w", err)
|
||||
}
|
||||
|
||||
// 5) 发起请求(若配置了 bedrock_proxy_url 则经 HTTP 代理转发)
|
||||
startTime := time.Now()
|
||||
transport := http.DefaultTransport
|
||||
if creds.BedrockProxyURL != "" {
|
||||
proxyAddr := creds.BedrockProxyURL
|
||||
if !strings.HasPrefix(proxyAddr, "http://") && !strings.HasPrefix(proxyAddr, "https://") {
|
||||
proxyAddr = "http://" + proxyAddr
|
||||
}
|
||||
if proxyURL, parseErr := url.Parse(proxyAddr); parseErr == nil {
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
}
|
||||
client := &http.Client{Timeout: 5 * time.Minute, Transport: transport}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
s.logBedrock(userID, modelID, "error", err.Error(), startTime, 0, 0)
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 6) 写回响应头/状态码(流式改写 Content-Type 为 SSE)
|
||||
if w, ok := writer.(http.ResponseWriter); ok {
|
||||
for k, v := range resp.Header {
|
||||
lower := strings.ToLower(k)
|
||||
if streaming && (lower == "content-type" || lower == "content-length") {
|
||||
continue
|
||||
}
|
||||
for _, vv := range v {
|
||||
w.Header().Add(k, vv)
|
||||
}
|
||||
}
|
||||
if streaming {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
_, _ = writer.Write(raw)
|
||||
s.logBedrock(userID, modelID, "error",
|
||||
fmt.Sprintf("bedrock %d: %s", resp.StatusCode, string(raw)), startTime, 0, 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 7) 处理响应体
|
||||
var inputTokens, outputTokens int
|
||||
if streaming {
|
||||
inputTokens, outputTokens, err = s.streamBedrockEventStream(resp.Body, writer)
|
||||
if err != nil {
|
||||
s.logBedrock(userID, modelID, "error", err.Error(), startTime, inputTokens, outputTokens)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
var buf bytes.Buffer
|
||||
tee := io.TeeReader(resp.Body, &buf)
|
||||
if _, err = io.Copy(writer, tee); err != nil {
|
||||
s.logBedrock(userID, modelID, "error", err.Error(), startTime, 0, 0)
|
||||
return err
|
||||
}
|
||||
inputTokens, outputTokens = parseAnthropicUsage(buf.Bytes())
|
||||
}
|
||||
|
||||
// 8) 记录日志 + 计费扣款
|
||||
s.logBedrock(userID, modelID, "success", "", startTime, inputTokens, outputTokens)
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
pricing, _ := s.fetchModelPricingFromDify(modelID)
|
||||
delta := calcQuotaDelta(pricing, modelID, inputTokens, outputTokens)
|
||||
deductAccountQuota(userID, delta)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// streamBedrockEventStream 解析 Bedrock 的 vnd.amazon.eventstream 二进制流,
|
||||
// 把每个事件还原为 Anthropic SSE(event: <type>\ndata: <json>\n\n)写给客户端。
|
||||
// 返回累计的 input/output token 数(用于计费)。
|
||||
func (s *ModelProviderService) streamBedrockEventStream(r io.Reader, w io.Writer) (int, int, error) {
|
||||
flusher, _ := w.(http.Flusher)
|
||||
dec := estream.NewDecoder(r)
|
||||
payloadBuf := make([]byte, 0, 32*1024)
|
||||
|
||||
var inputTokens, outputTokens int
|
||||
for {
|
||||
msg, err := dec.Decode(payloadBuf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return inputTokens, outputTokens, nil
|
||||
}
|
||||
return inputTokens, outputTokens, fmt.Errorf("eventstream decode 失败:%w", err)
|
||||
}
|
||||
|
||||
// Bedrock 的事件 payload 形如 {"bytes":"<base64-encoded inner JSON>"}
|
||||
var wrap struct {
|
||||
Bytes string `json:"bytes"`
|
||||
}
|
||||
var inner []byte
|
||||
if e := json.Unmarshal(msg.Payload, &wrap); e == nil && wrap.Bytes != "" {
|
||||
if decoded, e2 := base64.StdEncoding.DecodeString(wrap.Bytes); e2 == nil {
|
||||
inner = decoded
|
||||
}
|
||||
}
|
||||
if len(inner) == 0 {
|
||||
// 非包装格式(如错误/ping),直接用原 payload
|
||||
inner = msg.Payload
|
||||
}
|
||||
|
||||
// 解析事件类型和 usage(Anthropic 在 message_start.message.usage 给 input_tokens,
|
||||
// message_delta.usage 给 output_tokens)
|
||||
var ev struct {
|
||||
Type string `json:"type"`
|
||||
Message struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
} `json:"message"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
_ = json.Unmarshal(inner, &ev)
|
||||
if ev.Message.Usage.InputTokens > 0 {
|
||||
inputTokens = ev.Message.Usage.InputTokens
|
||||
}
|
||||
if ev.Message.Usage.OutputTokens > 0 {
|
||||
outputTokens = ev.Message.Usage.OutputTokens
|
||||
}
|
||||
if ev.Usage.InputTokens > 0 {
|
||||
inputTokens = ev.Usage.InputTokens
|
||||
}
|
||||
if ev.Usage.OutputTokens > 0 {
|
||||
outputTokens = ev.Usage.OutputTokens
|
||||
}
|
||||
|
||||
// 重组为 Anthropic SSE 写回
|
||||
eventName := ev.Type
|
||||
if eventName == "" {
|
||||
eventName = "message"
|
||||
}
|
||||
sse := "event: " + eventName + "\ndata: " + string(inner) + "\n\n"
|
||||
if _, err = w.Write([]byte(sse)); err != nil {
|
||||
return inputTokens, outputTokens, err
|
||||
}
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseAnthropicUsage 从非流式 Anthropic Messages 响应 JSON 中提取 usage 字段。
|
||||
func parseAnthropicUsage(data []byte) (input, output int) {
|
||||
var obj struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
if json.Unmarshal(data, &obj) == nil {
|
||||
return obj.Usage.InputTokens, obj.Usage.OutputTokens
|
||||
}
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
// bedrockCrossRegionPrefixes 是需要跨区域推理配置文件的模型 ID 前缀列表(anthropic. 开头)。
|
||||
// 来源:https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
|
||||
// 规则:凡模型不在旧版 on-demand 列表中,均需加地理前缀才能调用。
|
||||
var bedrockCrossRegionPrefixes = []string{
|
||||
// Claude 3.5 v2
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2",
|
||||
"anthropic.claude-3-5-haiku-20241022",
|
||||
// Claude 3.7
|
||||
"anthropic.claude-3-7-sonnet",
|
||||
// Claude Sonnet 4 / Opus 4 / Haiku 4(及后续新版本)
|
||||
"anthropic.claude-sonnet-4",
|
||||
"anthropic.claude-opus-4",
|
||||
"anthropic.claude-haiku-4",
|
||||
}
|
||||
|
||||
// bedrockResolveModelID 将用户输入的模型 ID 解析为正确的 Bedrock 调用 ID。
|
||||
//
|
||||
// 支持以下输入格式(会自动规范化):
|
||||
// - 完整带前缀:us.anthropic.claude-sonnet-4-6 → 直接使用
|
||||
// - 带厂商前缀:anthropic.claude-sonnet-4-6 → 按需加 us./eu./ap.
|
||||
// - 短横线名称:claude-sonnet-4-6 → 补 anthropic. 再按需加前缀
|
||||
// - 空格+点号: "claude sonnet 4.6" → 规范化后同上
|
||||
func bedrockResolveModelID(modelID, region string) string {
|
||||
// Step 1: 规范化输入
|
||||
// 小写、空格→横线、版本中的点→横线(如 4.6 → 4-6,但保留 : 用于版本后缀如 v2:0)
|
||||
normalized := strings.ToLower(strings.TrimSpace(modelID))
|
||||
normalized = strings.ReplaceAll(normalized, " ", "-")
|
||||
// 仅将数字之间的 `.` 替换为 `-`(处理 "4.6" → "4-6"),保留 anthropic. 这样的厂商点
|
||||
normalized = replaceVersionDots(normalized)
|
||||
|
||||
// Step 2: 补全 anthropic. 前缀(用户只填了 claude-xxx)
|
||||
if !strings.HasPrefix(normalized, "anthropic.") &&
|
||||
!strings.HasPrefix(normalized, "us.") &&
|
||||
!strings.HasPrefix(normalized, "eu.") &&
|
||||
!strings.HasPrefix(normalized, "ap.") {
|
||||
normalized = "anthropic." + normalized
|
||||
}
|
||||
|
||||
// Step 3: 已经带地理前缀 → 直接使用
|
||||
if strings.HasPrefix(normalized, "us.") || strings.HasPrefix(normalized, "eu.") || strings.HasPrefix(normalized, "ap.") {
|
||||
return normalized
|
||||
}
|
||||
|
||||
// Step 4: 判断是否属于需要跨区域推理配置文件的模型
|
||||
needsCrossRegion := false
|
||||
for _, prefix := range bedrockCrossRegionPrefixes {
|
||||
if strings.HasPrefix(normalized, prefix) {
|
||||
needsCrossRegion = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !needsCrossRegion {
|
||||
return normalized
|
||||
}
|
||||
|
||||
// Step 5: 根据 region 推导地理前缀
|
||||
geoPrefix := "us" // 默认
|
||||
switch {
|
||||
case strings.HasPrefix(region, "us-"):
|
||||
geoPrefix = "us"
|
||||
case strings.HasPrefix(region, "eu-"):
|
||||
geoPrefix = "eu"
|
||||
case strings.HasPrefix(region, "ap-"):
|
||||
geoPrefix = "ap"
|
||||
}
|
||||
return geoPrefix + "." + normalized
|
||||
}
|
||||
|
||||
// replaceVersionDots 将版本号中数字之间的 `.` 替换为 `-`(如 4.6 → 4-6),
|
||||
// 保留厂商命名空间中的点(如 anthropic. 开头不受影响,因为点后紧跟字母)。
|
||||
func replaceVersionDots(s string) string {
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] == '.' && i > 0 && i < len(s)-1 {
|
||||
prev := s[i-1]
|
||||
next := s[i+1]
|
||||
// 仅当点号两侧都是数字时才替换为 -
|
||||
if prev >= '0' && prev <= '9' && next >= '0' && next <= '9' {
|
||||
b.WriteByte('-')
|
||||
continue
|
||||
}
|
||||
}
|
||||
b.WriteByte(s[i])
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// logBedrock 记录代理日志(与 ProxyRequest 中的 ModelProxyLog 行为一致)。
|
||||
func (s *ModelProviderService) logBedrock(userID, modelID, status, errMsg string, startTime time.Time, in, out int) {
|
||||
if err := global.GVA_DB.Create(&gaia.ModelProxyLog{
|
||||
UserId: userID,
|
||||
ProviderName: gaia.ProviderAWS,
|
||||
ModelName: modelID,
|
||||
RequestTokens: in,
|
||||
ResponseTokens: out,
|
||||
Status: status,
|
||||
ErrorMessage: errMsg,
|
||||
CreatedAt: startTime,
|
||||
}).Error; err != nil {
|
||||
global.GVA_LOG.Warn("logBedrock 写日志失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,8 @@ import (
|
||||
type DashboardService struct{}
|
||||
|
||||
// GetAccountQuotaRankingData 分页获取【账号】额度排名列表
|
||||
func (dashboardService *DashboardService) GetAccountQuotaRankingData(info gaiaReq.GetAccountQuotaRankingDataReq) (list []response.GetAccountQuotaRankingDataRes, total int64, err error) {
|
||||
func (s *DashboardService) GetAccountQuotaRankingData(info gaiaReq.GetAccountQuotaRankingDataReq) (
|
||||
list []response.GetAccountQuotaRankingDataRes, total int64, err error) {
|
||||
limit := info.PageSize
|
||||
offset := info.PageSize * (info.Page - 1)
|
||||
|
||||
@@ -79,15 +80,13 @@ func (dashboardService *DashboardService) GetAccountQuotaRankingData(info gaiaRe
|
||||
}
|
||||
|
||||
// GetAppQuotaRankingData 分页获取【应用】配额排名数据
|
||||
func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.GetAppQuotaRankingDataReq) (list []response.GetAppQuotaRankingDataRes, total int64, err error) {
|
||||
func (s *DashboardService) GetAppQuotaRankingData(info gaiaReq.GetAppQuotaRankingDataReq) (
|
||||
list []response.GetAppQuotaRankingDataRes, total int64, err error) {
|
||||
|
||||
cacheKey := fmt.Sprintf("app_token_quota_ranking:%d:%d", info.Page, info.PageSize)
|
||||
var cachedResult struct {
|
||||
List []response.GetAppQuotaRankingDataRes
|
||||
Total int64
|
||||
}
|
||||
var cachedResult response.AppQuotaRankingCache
|
||||
|
||||
if found, err := dashboardService.getCachedResult(cacheKey, &cachedResult); err == nil && found {
|
||||
if found, err := s.getCachedResult(cacheKey, &cachedResult); err == nil && found {
|
||||
return cachedResult.List, cachedResult.Total, nil
|
||||
}
|
||||
|
||||
@@ -102,7 +101,7 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
|
||||
Select("" +
|
||||
"app_id, " +
|
||||
"COUNT(id) as message_num, " +
|
||||
"SUM(CASE WHEN currency = 'RMB' THEN total_price / 7.26 ELSE total_price END) as message_cost").
|
||||
fmt.Sprintf("SUM(CASE WHEN currency = 'RMB' THEN total_price / %f ELSE total_price END) as message_cost", gaia.RmbToUSDRate)).
|
||||
Group("app_id")
|
||||
|
||||
workflowCosts := global.GVA_DB.Table("public.workflow_node_executions").
|
||||
@@ -111,7 +110,7 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
|
||||
"COUNT(id) as workflow_num, " +
|
||||
"SUM(CASE " +
|
||||
" WHEN execution_metadata::json->>'currency' = 'RMB' " +
|
||||
" THEN CAST((execution_metadata::json->>'total_price') AS NUMERIC) / 7.26 " +
|
||||
fmt.Sprintf(" THEN CAST((execution_metadata::json->>'total_price') AS NUMERIC) / %f ", gaia.RmbToUSDRate) +
|
||||
" ELSE CAST((execution_metadata::json->>'total_price') AS NUMERIC) " +
|
||||
"END) AS workflow_cost").
|
||||
Where("execution_metadata IS NOT NULL AND execution_metadata != '' AND (execution_metadata::json->>'total_price') IS NOT NULL").
|
||||
@@ -140,13 +139,7 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
var results []struct {
|
||||
AppID string `gorm:"column:app_id"`
|
||||
TotalCost float64 `gorm:"column:total_cost"`
|
||||
MessageCost float64 `gorm:"column:message_cost"`
|
||||
WorkflowCost float64 `gorm:"column:workflow_cost"`
|
||||
RecordNum float64 `gorm:"column:record_num"`
|
||||
}
|
||||
var results []response.AppQuotaRankingRow
|
||||
|
||||
err = query.Find(&results).Error
|
||||
if err != nil {
|
||||
@@ -269,12 +262,8 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
|
||||
}
|
||||
|
||||
// 在返回结果之前,缓存结果
|
||||
result := struct {
|
||||
List []response.GetAppQuotaRankingDataRes
|
||||
Total int64
|
||||
}{list, total}
|
||||
|
||||
if err := dashboardService.cacheResult(cacheKey, result, 24*time.Hour); err != nil {
|
||||
cachePayload := response.AppQuotaRankingCache{List: list, Total: total}
|
||||
if err := s.cacheResult(cacheKey, cachePayload, 24*time.Hour); err != nil {
|
||||
global.GVA_LOG.Error("Failed to cache result", zap.Error(err))
|
||||
}
|
||||
|
||||
@@ -282,7 +271,8 @@ func (dashboardService *DashboardService) GetAppQuotaRankingData(info gaiaReq.Ge
|
||||
}
|
||||
|
||||
// GetAppTokenQuotaRankingData 分页获取【应用密钥】配额排名数据列表
|
||||
func (dashboardService *DashboardService) GetAppTokenQuotaRankingData(info gaiaReq.GetAppTokenQuotaRankingDataReq) (list []response.GetAppTokenQuotaRankingDataRes, total int64, err error) {
|
||||
func (s *DashboardService) GetAppTokenQuotaRankingData(info gaiaReq.GetAppTokenQuotaRankingDataReq) (
|
||||
list []response.GetAppTokenQuotaRankingDataRes, total int64, err error) {
|
||||
limit := info.PageSize
|
||||
offset := info.PageSize * (info.Page - 1)
|
||||
|
||||
@@ -365,9 +355,11 @@ func (dashboardService *DashboardService) GetAppTokenQuotaRankingData(info gaiaR
|
||||
}
|
||||
|
||||
// GetAppTokenDailyQuotaData 获取每天密钥花费数据列表
|
||||
func (dashboardService *DashboardService) GetAppTokenDailyQuotaData(info gaiaReq.GetAppTokenDailyQuotaDataReq) (list []response.GetAppTokenDailyQuotaDataRes, err error) {
|
||||
func (s *DashboardService) GetAppTokenDailyQuotaData(info gaiaReq.GetAppTokenDailyQuotaDataReq) (
|
||||
list []response.GetAppTokenDailyQuotaDataRes, err error) {
|
||||
|
||||
db := global.GVA_DB.Select("DATE(stat_at) as stat_at, SUM(day_used_quota) as day_used_quota").Model(&gaia.ApiTokenMoneyDailyStatExtend{}).Order("stat_at desc").Group("DATE(stat_at)")
|
||||
db := global.GVA_DB.Select("DATE(stat_at) as stat_at, SUM(day_used_quota) as day_used_quota").Model(
|
||||
&gaia.ApiTokenMoneyDailyStatExtend{}).Order("stat_at desc").Group("DATE(stat_at)")
|
||||
var apiTokenMoneyDailyStatExtends []gaia.ApiTokenMoneyDailyStatExtend
|
||||
|
||||
if info.AppId != "" {
|
||||
@@ -397,7 +389,8 @@ func (dashboardService *DashboardService) GetAppTokenDailyQuotaData(info gaiaReq
|
||||
}
|
||||
|
||||
// GetAiImageQuotaRankingData 获取【AI图片】使用量排名数据列表
|
||||
func (dashboardService *DashboardService) GetAiImageQuotaRankingData(info gaiaReq.GetAiImageQuotaRankingDataReq) (list []response.GetAiImageQuotaRankingRes, err error) {
|
||||
func (s *DashboardService) GetAiImageQuotaRankingData(info gaiaReq.GetAiImageQuotaRankingDataReq) (
|
||||
list []response.GetAiImageQuotaRankingRes, err error) {
|
||||
limit := info.PageSize
|
||||
offset := info.PageSize * (info.Page - 1)
|
||||
|
||||
@@ -427,13 +420,7 @@ func (dashboardService *DashboardService) GetAiImageQuotaRankingData(info gaiaRe
|
||||
db = db.Limit(limit).Offset(offset)
|
||||
}
|
||||
|
||||
var results []struct {
|
||||
Address string
|
||||
Path string
|
||||
TotalCost float64
|
||||
RecordNum int
|
||||
Model string
|
||||
}
|
||||
var results []response.AiImageQuotaRankingRow
|
||||
|
||||
err = db.Find(&results).Error
|
||||
if err != nil {
|
||||
@@ -456,7 +443,7 @@ func (dashboardService *DashboardService) GetAiImageQuotaRankingData(info gaiaRe
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (dashboardService *DashboardService) cacheResult(key string, data interface{}, expiration time.Duration) error {
|
||||
func (s *DashboardService) cacheResult(key string, data interface{}, expiration time.Duration) error {
|
||||
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
@@ -465,7 +452,7 @@ func (dashboardService *DashboardService) cacheResult(key string, data interface
|
||||
return global.GVA_REDIS.Set(context.Background(), key, jsonData, expiration).Err()
|
||||
}
|
||||
|
||||
func (dashboardService *DashboardService) getCachedResult(key string, result interface{}) (bool, error) {
|
||||
func (s *DashboardService) getCachedResult(key string, result interface{}) (bool, error) {
|
||||
data, err := global.GVA_REDIS.Get(context.Background(), key).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
|
||||
@@ -0,0 +1,273 @@
|
||||
package gaia
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
)
|
||||
|
||||
// TestBuildURL_NewFormat 测试新格式 Params 自动拼接 URL
|
||||
func TestBuildURL_NewFormat(t *testing.T) {
|
||||
config := request.EmailApiConfig{
|
||||
URL: "https://api.example.com/user",
|
||||
Params: []request.RequestParam{},
|
||||
}
|
||||
|
||||
// 无参数
|
||||
got := buildURL("https://api.example.com/user", config, "USER123")
|
||||
if got != "https://api.example.com/user" {
|
||||
t.Errorf("无参数时 URL 应保持不变,got: %s", got)
|
||||
}
|
||||
|
||||
// 单个 string 类型参数
|
||||
config.Params = []request.RequestParam{
|
||||
{Key: "appKey", ValueType: "string", Value: "mykey"},
|
||||
}
|
||||
got = buildURL("https://api.example.com/user", config, "USER123")
|
||||
if got != "https://api.example.com/user?appKey=mykey" {
|
||||
t.Errorf("单参数拼接错误,got: %s", got)
|
||||
}
|
||||
|
||||
// URL 已有 ? 时使用 &
|
||||
got = buildURL("https://api.example.com/user?type=admin", config, "USER123")
|
||||
if got != "https://api.example.com/user?type=admin&appKey=mykey" {
|
||||
t.Errorf("已有参数时应使用 & 拼接,got: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildURL_DingIDParam 测试钉钉 ID 类型参数自动替换
|
||||
func TestBuildURL_DingIDParam(t *testing.T) {
|
||||
config := request.EmailApiConfig{
|
||||
URL: "https://api.example.com/user",
|
||||
Params: []request.RequestParam{
|
||||
{Key: "userId", ValueType: "ding_id"},
|
||||
},
|
||||
}
|
||||
|
||||
got := buildURL("https://api.example.com/user", config, "USER123")
|
||||
if got != "https://api.example.com/user?userId=USER123" {
|
||||
t.Errorf("钉钉 ID 类型参数应自动替换,got: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildURL_OldFormat 测试旧格式 RequestParamField 兼容
|
||||
func TestBuildURL_OldFormat(t *testing.T) {
|
||||
config := request.EmailApiConfig{
|
||||
URL: "https://api.example.com/user",
|
||||
RequestParamField: "userId",
|
||||
Params: nil, // 旧格式:Params 为 nil
|
||||
}
|
||||
|
||||
got := buildURL("https://api.example.com/user", config, "USER123")
|
||||
if got != "https://api.example.com/user?userId=USER123" {
|
||||
t.Errorf("旧格式应使用 RequestParamField,got: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveParamValue 测试参数值解析
|
||||
func TestResolveParamValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
vt string
|
||||
value string
|
||||
dingId string
|
||||
want string
|
||||
}{
|
||||
{"ding_id", "", "USER123", "USER123"},
|
||||
{"string", "myvalue", "USER123", "myvalue"},
|
||||
{"string", "prefix_{{ding_id}}_suffix", "USER123", "prefix_USER123_suffix"},
|
||||
{"string", "prefix_$<{[ding_id]}>_suffix", "USER123", "prefix_USER123_suffix"},
|
||||
{"int", "42", "USER123", "42"},
|
||||
{"bool", "true", "USER123", "true"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := resolveParamValue(tt.vt, tt.value, tt.dingId)
|
||||
if got != tt.want {
|
||||
t.Errorf("resolveParamValue(%q, %q, %q) = %q, want %q", tt.vt, tt.value, tt.dingId, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractJSONPathAdvanced 测试 JSON 路径提取
|
||||
func TestExtractJSONPathAdvanced(t *testing.T) {
|
||||
data := map[string]interface{}{
|
||||
"code": float64(0),
|
||||
"data": []interface{}{
|
||||
map[string]interface{}{
|
||||
"userName": "test@example.com",
|
||||
"userId": "USER123",
|
||||
},
|
||||
},
|
||||
"nested": map[string]interface{}{
|
||||
"email": "nested@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
want string
|
||||
}{
|
||||
{"code", "0"},
|
||||
{"data[0].userName", "test@example.com"},
|
||||
{"data[0].userId", "USER123"},
|
||||
{"nested.email", "nested@example.com"},
|
||||
{"notexist", ""},
|
||||
{"data[1].userName", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := extractJSONPathAdvanced(data, tt.path)
|
||||
if got != tt.want {
|
||||
t.Errorf("extractJSONPathAdvanced(data, %q) = %q, want %q", tt.path, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseEmailApiConfigFromJSON_NewFormat 测试新格式解析
|
||||
func TestParseEmailApiConfigFromJSON_NewFormat(t *testing.T) {
|
||||
jsonStr := `{
|
||||
"enabled": true,
|
||||
"url": "https://api.example.com",
|
||||
"method": "GET",
|
||||
"params": [
|
||||
{"key": "userId", "value_type": "ding_id", "value": ""},
|
||||
{"key": "appKey", "value_type": "string", "value": "mykey"}
|
||||
],
|
||||
"response_email_field": "data[0].email"
|
||||
}`
|
||||
|
||||
cfg, err := parseEmailApiConfigFromJSON([]byte(jsonStr))
|
||||
if err != nil {
|
||||
t.Fatalf("解析新格式配置失败: %v", err)
|
||||
}
|
||||
if !isNewEmailApiConfig(cfg) {
|
||||
t.Error("应检测为新格式")
|
||||
}
|
||||
if len(cfg.Params) != 2 {
|
||||
t.Errorf("Params 数量错误,got: %d", len(cfg.Params))
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseEmailApiConfigFromJSON_OldFormat 测试旧格式兼容解析
|
||||
func TestParseEmailApiConfigFromJSON_OldFormat(t *testing.T) {
|
||||
jsonStr := `{
|
||||
"enabled": true,
|
||||
"url": "https://api.example.com",
|
||||
"method": "GET",
|
||||
"request_param_field": "userId",
|
||||
"response_email_field": "data[0].email",
|
||||
"body_data": {
|
||||
"form_data": [{"userId": ""}],
|
||||
"urlencoded": []
|
||||
}
|
||||
}`
|
||||
|
||||
cfg, err := parseEmailApiConfigFromJSON([]byte(jsonStr))
|
||||
if err != nil {
|
||||
t.Fatalf("解析旧格式配置失败: %v", err)
|
||||
}
|
||||
if isNewEmailApiConfig(cfg) {
|
||||
t.Error("旧格式配置不应检测为新格式")
|
||||
}
|
||||
if cfg.RequestParamField != "userId" {
|
||||
t.Errorf("RequestParamField 应为 userId,got: %s", cfg.RequestParamField)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateEmailApiConfigFields_NewFormat 测试新格式配置验证
|
||||
func TestValidateEmailApiConfigFields_NewFormat(t *testing.T) {
|
||||
cfg := request.EmailApiConfig{
|
||||
Enabled: true,
|
||||
URL: "https://api.example.com",
|
||||
Method: "GET",
|
||||
Params: []request.RequestParam{},
|
||||
ResponseEmailField: "data[0].email",
|
||||
}
|
||||
|
||||
if err := validateEmailApiConfigFields(cfg); err != nil {
|
||||
t.Errorf("有效的新格式配置不应报错: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateEmailApiConfigFields_InvalidParamType 测试 Params 不支持 int 类型
|
||||
func TestValidateEmailApiConfigFields_InvalidParamType(t *testing.T) {
|
||||
cfg := request.EmailApiConfig{
|
||||
Enabled: true,
|
||||
URL: "https://api.example.com",
|
||||
Method: "GET",
|
||||
Params: []request.RequestParam{
|
||||
{Key: "count", ValueType: "int", Value: "10"}, // Params 不支持 int
|
||||
},
|
||||
ResponseEmailField: "data[0].email",
|
||||
}
|
||||
|
||||
if err := validateEmailApiConfigFields(cfg); err == nil {
|
||||
t.Error("Params 不应支持 int 类型,应报错")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateEmailApiConfigFields_InvalidBodyType 测试 Body 不支持未知类型
|
||||
func TestValidateEmailApiConfigFields_InvalidBodyType(t *testing.T) {
|
||||
cfg := request.EmailApiConfig{
|
||||
Enabled: true,
|
||||
URL: "https://api.example.com",
|
||||
Method: "POST",
|
||||
Params: []request.RequestParam{},
|
||||
BodyData: request.BodyData{
|
||||
FormData: []request.BodyField{
|
||||
{Key: "field1", ValueType: "invalid_type", Value: "val"},
|
||||
},
|
||||
},
|
||||
ResponseEmailField: "data[0].email",
|
||||
}
|
||||
|
||||
if err := validateEmailApiConfigFields(cfg); err == nil {
|
||||
t.Error("Body 不应支持未知类型,应报错")
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildBodyFields 测试 Body 字段类型转换
|
||||
func TestBuildBodyFields(t *testing.T) {
|
||||
fields := []request.BodyField{
|
||||
{Key: "userId", ValueType: "ding_id", Value: ""},
|
||||
{Key: "appKey", ValueType: "string", Value: "mykey"},
|
||||
{Key: "count", ValueType: "int", Value: "10"},
|
||||
{Key: "enabled", ValueType: "bool", Value: "true"},
|
||||
}
|
||||
|
||||
form := buildBodyFields(fields, "USER123")
|
||||
|
||||
if form.Get("userId") != "USER123" {
|
||||
t.Errorf("ding_id 类型应替换为实际钉钉 ID,got: %s", form.Get("userId"))
|
||||
}
|
||||
if form.Get("appKey") != "mykey" {
|
||||
t.Errorf("string 类型应直接使用,got: %s", form.Get("appKey"))
|
||||
}
|
||||
if form.Get("count") != "10" {
|
||||
t.Errorf("int 类型值不正确,got: %s", form.Get("count"))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDingIDMarkerReplacement 测试 Raw 模式钉钉 ID 标记替换
|
||||
func TestDingIDMarkerReplacement(t *testing.T) {
|
||||
raw := `{"userId": "$<{[ding_id]}>", "other": "value"}`
|
||||
dingId := "USER789"
|
||||
|
||||
// 使用 resolveParamValue 测试标记替换
|
||||
replaced := resolveParamValue("string", raw, dingId)
|
||||
expected := `{"userId": "USER789", "other": "value"}`
|
||||
|
||||
if replaced != expected {
|
||||
t.Errorf("钉钉 ID 标记替换错误\ngot: %s\nwant: %s", replaced, expected)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDingIDMarkerOldFormat 测试旧格式 {{ding_id}} 占位符替换
|
||||
func TestDingIDMarkerOldFormat(t *testing.T) {
|
||||
raw := `{"userId": "{{ding_id}}", "other": "value"}`
|
||||
replaced := resolveParamValue("string", raw, "USER789")
|
||||
expected := `{"userId": "USER789", "other": "value"}`
|
||||
if replaced != expected {
|
||||
t.Errorf("旧格式占位符替换错误\ngot: %s\nwant: %s", replaced, expected)
|
||||
}
|
||||
}
|
||||
@@ -136,6 +136,7 @@ func (e *SystemIntegratedService) OAuth2CodeLogin(
|
||||
return nil, fmt.Errorf("无法从 OAuth2 用户信息中获取邮箱或用户唯一标识")
|
||||
}
|
||||
|
||||
fmt.Println("OAuth2CodeLogin", email, username)
|
||||
sysUser, err := e.findUserByEmailOrPhone(email, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -148,8 +149,48 @@ func (e *SystemIntegratedService) OAuth2CodeLogin(
|
||||
return &response.GaiaLoginResult{User: *sysUser, Token: token, RedirectURI: req.RedirectURI, State: req.State}, nil
|
||||
}
|
||||
|
||||
// DingTalkTestCallback 仅用 code 换 token,用于「测试连接」回调,不登录、不写 session
|
||||
func (e *SystemIntegratedService) DingTalkTestCallback(code string) error {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return fmt.Errorf("授权码为空")
|
||||
}
|
||||
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
|
||||
if integrate.AppKey == "" || integrate.AppSecret == "" {
|
||||
return fmt.Errorf("钉钉配置不完整")
|
||||
}
|
||||
bodyJSON, _ := json.Marshal(map[string]string{
|
||||
"clientId": integrate.AppKey,
|
||||
"clientSecret": integrate.AppSecret,
|
||||
"code": code,
|
||||
"grantType": "authorization_code",
|
||||
})
|
||||
httpReq, err := http.NewRequest("POST", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken", bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("钉钉 token 请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Error("钉钉 token 非 200", zap.Int("status", resp.StatusCode), zap.String("body", string(respBody)))
|
||||
return fmt.Errorf("钉钉返回错误: %d", resp.StatusCode)
|
||||
}
|
||||
var tokenResp map[string]interface{}
|
||||
if err = json.Unmarshal(respBody, &tokenResp); err != nil || tokenResp["accessToken"] == "" {
|
||||
return fmt.Errorf("解析钉钉 token 失败")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DingTalkCodeLogin 钉钉 code 换用户并登录(扫码/OAuth2 回调带 code)
|
||||
func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLoginReq) (*response.GaiaLoginResult, error) {
|
||||
func (e *SystemIntegratedService) DingTalkCodeLogin(
|
||||
req request.GaiaDingTalkLoginReq) (*response.GaiaLoginResult, error) {
|
||||
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
|
||||
if !integrate.Status {
|
||||
return nil, fmt.Errorf("钉钉登录未启用")
|
||||
@@ -165,7 +206,8 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
|
||||
"code": req.AuthCode,
|
||||
"grantType": "authorization_code",
|
||||
})
|
||||
httpReq, err := http.NewRequest("POST", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken", bytes.NewReader(bodyJSON))
|
||||
httpReq, err := http.NewRequest("POST",
|
||||
"https://api.dingtalk.com/v1.0/oauth2/userAccessToken", bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -179,7 +221,8 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Error("钉钉 token 非 200", zap.Int("status", resp.StatusCode), zap.String("body", string(respBody)))
|
||||
global.GVA_LOG.Error("钉钉 token 非 200", zap.Int(
|
||||
"status", resp.StatusCode), zap.String("body", string(respBody)))
|
||||
return nil, fmt.Errorf("钉钉返回错误: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -205,8 +248,56 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
|
||||
if err = json.Unmarshal(userBody, &dingUser); err != nil {
|
||||
return nil, fmt.Errorf("解析钉钉用户信息失败")
|
||||
}
|
||||
email := dingUser["email"].(string)
|
||||
username := dingUser["nick"].(string)
|
||||
|
||||
// 提取钉钉 ID(user_id 字段)
|
||||
dingId := ""
|
||||
if v, ok := dingUser["unionId"]; ok && v != nil {
|
||||
dingId, _ = v.(string)
|
||||
}
|
||||
if dingId == "" {
|
||||
if v, ok := dingUser["userId"]; ok && v != nil {
|
||||
dingId, _ = v.(string)
|
||||
}
|
||||
}
|
||||
|
||||
// 解析用户名配置
|
||||
var emailList []string
|
||||
var configMap request.DingTalkConfigRequest
|
||||
var emailConfig request.EmailApiConfig
|
||||
if integrate.Config != "" {
|
||||
if jsonErr := json.Unmarshal([]byte(integrate.Config), &configMap); jsonErr == nil {
|
||||
var rawMsg json.RawMessage
|
||||
if rawBytes, marshalErr := json.Marshal(configMap.EmailApi); marshalErr == nil {
|
||||
rawMsg = rawBytes
|
||||
if cfg, parseErr := parseEmailApiConfigFromJSON(rawMsg); parseErr == nil {
|
||||
emailConfig = cfg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 优先通过用户名 API 获取用户名(新格式)
|
||||
if emailConfig.Enabled && dingId != "" {
|
||||
emailList, err = e.callEmailApi(dingId, emailConfig)
|
||||
if err == nil && len(emailList) > 0 {
|
||||
fmt.Println("钉钉 code 换用户并登录(扫码/OAuth2 回调带 code)", emailList)
|
||||
sysUser, findErr := e.findUserByEmail(emailList)
|
||||
if findErr != nil {
|
||||
return nil, findErr
|
||||
}
|
||||
token, _, tokenErr := utils.LoginToken(sysUser)
|
||||
if tokenErr != nil {
|
||||
return nil, fmt.Errorf("签发 token 失败")
|
||||
}
|
||||
return &response.GaiaLoginResult{User: *sysUser, Token: token, RedirectURI: req.RedirectURI, State: req.State}, nil
|
||||
}
|
||||
global.GVA_LOG.Warn("DingTalkCodeLogin: 第三方邮箱 API 获取失败,尝试钉钉直接返回邮箱",
|
||||
zap.String("ding_id", dingId), zap.Error(err))
|
||||
}
|
||||
|
||||
// 回退:直接从钉钉用户信息获取邮箱
|
||||
email, _ := dingUser["email"].(string)
|
||||
username, _ := dingUser["nick"].(string)
|
||||
if username == "" {
|
||||
username = email
|
||||
}
|
||||
@@ -214,7 +305,8 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
|
||||
return nil, fmt.Errorf("钉钉未返回邮箱")
|
||||
}
|
||||
|
||||
sysUser, err := e.findUserByEmail(email)
|
||||
fmt.Println("钉钉 code 换用户并登录第三方邮箱 API 获取失败", email)
|
||||
sysUser, err := e.findUserByEmail([]string{email})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -231,7 +323,8 @@ func getStringFromMap(m map[string]interface{}, keys ...string) string {
|
||||
continue
|
||||
}
|
||||
if v, ok := m[k]; ok && v != nil {
|
||||
if s, ok := v.(string); ok {
|
||||
var s string
|
||||
if s, ok = v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
@@ -297,21 +390,21 @@ func getStringByPathOrKeys(m map[string]interface{}, path string, fallbackKeys .
|
||||
return getStringFromMap(m, fallbackKeys...)
|
||||
}
|
||||
|
||||
// findUserByEmail 按邮箱查找已存在的用户(需在 gaia.accounts 中有对应记录方可签发 JWT)
|
||||
func (e *SystemIntegratedService) findUserByEmail(email string) (*system.SysUser, error) {
|
||||
var u system.SysUser
|
||||
var mailList []string
|
||||
mailList = append(mailList, email)
|
||||
parts := strings.Split(email, "@")
|
||||
defaultMail := os.Getenv(gaia.EmailDomainEnv)
|
||||
if len(defaultMail) > 0 && len(parts) == 2 {
|
||||
mailList = append(mailList, parts[0]+"@"+defaultMail)
|
||||
}
|
||||
// findUserByEmail 按username查找已存在的用户(需在 gaia.accounts 中有对应记录方可签发 JWT)
|
||||
func (e *SystemIntegratedService) findUserByEmail(mailList []string) (*system.SysUser, error) {
|
||||
// 查询关联邮箱
|
||||
var u system.SysUser
|
||||
if len(mailList) == 1 {
|
||||
parts := strings.Split(mailList[0], "@")
|
||||
defaultMail := os.Getenv(gaia.EmailDomainEnv)
|
||||
if len(defaultMail) > 1 && len(parts) > 1 && len(parts[0]) > 0 {
|
||||
mailList = append(mailList, parts[0]+"@"+defaultMail)
|
||||
}
|
||||
}
|
||||
if err := global.GVA_DB.Where("email IN (?)", mailList).Preload(
|
||||
"Authorities").Preload("Authority").First(&u).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf(fmt.Sprintf("邮箱%s尚未开通账号,请联系管理员", email))
|
||||
return nil, fmt.Errorf("%s尚未开通账号,请联系管理员", mailList[0])
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@@ -323,20 +416,20 @@ func (e *SystemIntegratedService) findUserByEmail(email string) (*system.SysUser
|
||||
}
|
||||
|
||||
// findUserByEmailOrPhone 按邮箱或用户唯一标识(如手机号)查找用户,优先邮箱
|
||||
func (e *SystemIntegratedService) findUserByEmailOrPhone(email, userID string) (*system.SysUser, error) {
|
||||
if email != "" {
|
||||
u, err := e.findUserByEmail(email)
|
||||
if err == nil {
|
||||
func (e *SystemIntegratedService) findUserByEmailOrPhone(
|
||||
mail, userID string) (u *system.SysUser, err error) {
|
||||
if mail != "" {
|
||||
if u, err = e.findUserByEmail([]string{mail}); err == nil {
|
||||
return u, nil
|
||||
}
|
||||
// 仅当“未开通”时再尝试按 userID(phone) 查,其他错误直接返回
|
||||
if err != nil && !strings.Contains(err.Error(), "尚未开通") {
|
||||
if !strings.Contains(err.Error(), "尚未开通") {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if userID != "" {
|
||||
var u system.SysUser
|
||||
if err := global.GVA_DB.Where("phone = ?", userID).Preload("Authorities").Preload("Authority").First(&u).Error; err != nil {
|
||||
if err = global.GVA_DB.Where("phone = ?", userID).Preload(
|
||||
"Authorities").Preload("Authority").First(&u).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("该用户唯一标识尚未开通后台账号,请联系管理员")
|
||||
}
|
||||
@@ -345,7 +438,7 @@ func (e *SystemIntegratedService) findUserByEmailOrPhone(email, userID string) (
|
||||
if u.Enable != 1 {
|
||||
return nil, fmt.Errorf("账号已被禁用")
|
||||
}
|
||||
return &u, nil
|
||||
return u, nil
|
||||
}
|
||||
return nil, fmt.Errorf("无法从 OAuth2 用户信息中获取邮箱或用户唯一标识")
|
||||
}
|
||||
|
||||
@@ -14,10 +14,10 @@ import (
|
||||
|
||||
// GetLoginOptions 获取登录方式选项(供登录页展示钉钉/OAuth2 按钮,不暴露密钥)
|
||||
func (e *SystemIntegratedService) GetLoginOptions(frontendOrigin string) (res response.LoginOptionsResponse) {
|
||||
// 非本地的需要加上admin
|
||||
// 非本地的需要加上 admin(若 Referer 已带 /admin 则不再追加,避免 /admin/admin)
|
||||
integrateDing := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
|
||||
frontendOrigin = strings.TrimSuffix(frontendOrigin, "/")
|
||||
if !strings.Contains(frontendOrigin, "localhost") {
|
||||
if !strings.Contains(frontendOrigin, "localhost") && !strings.HasSuffix(frontendOrigin, "/admin") {
|
||||
frontendOrigin = frontendOrigin + "/admin"
|
||||
}
|
||||
if integrateDing.Status && integrateDing.AppKey != "" {
|
||||
@@ -72,6 +72,22 @@ func (e *SystemIntegratedService) GetLoginOptions(frontendOrigin string) (res re
|
||||
return res
|
||||
}
|
||||
|
||||
// GetDingTalkTestAuthURL 返回用于「测试连接」的钉钉授权 URL(state=dingtalk_test,回调后仅验证 code 换 token,不登录)
|
||||
func (e *SystemIntegratedService) GetDingTalkTestAuthURL(frontendOrigin string) (string, error) {
|
||||
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
|
||||
if integrate.AppKey == "" || integrate.AppSecret == "" {
|
||||
return "", fmt.Errorf("请先配置 AppKey 与 AppSecret")
|
||||
}
|
||||
frontendOrigin = strings.TrimSuffix(frontendOrigin, "/")
|
||||
if !strings.Contains(frontendOrigin, "localhost") && !strings.HasSuffix(frontendOrigin, "/admin") {
|
||||
frontendOrigin = frontendOrigin + "/admin"
|
||||
}
|
||||
callbackURI := frontendOrigin + "/#/loginCallback?provider=dingtalk"
|
||||
authURL := fmt.Sprintf("https://login.dingtalk.com/oauth2/auth?client_id=%s&response_type=code&scope=openid&redirect_uri=%s&state=dingtalk_test",
|
||||
integrate.AppKey, url.QueryEscape(callbackURI))
|
||||
return authURL, nil
|
||||
}
|
||||
|
||||
// getIntegratedConfigRaw 获取集成配置(不脱敏,仅内部使用)
|
||||
func (e *SystemIntegratedService) getIntegratedConfigRaw(classID uint) (integrate gaia.SystemIntegration) {
|
||||
if err := global.GVA_DB.Where("classify = ?", classID).First(&integrate).Error; err != nil {
|
||||
|
||||
@@ -14,21 +14,231 @@ import (
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
gaiaRequest "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
gaiaResponse "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
|
||||
"go.gnd.pw/crypto/eax"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
gaiaRequest "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
gaiaResponse "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/system"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"go.gnd.pw/crypto/eax"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// fetchAdminToken 查询一个管理员用户,生成 Dify Console API 兼容的 JWT。
|
||||
// 结果缓存到 Redis,TTL 50 分钟(JWT 有效期内复用,避免频繁生成)。
|
||||
func (s *ModelProviderService) fetchAdminToken() (token string, err error) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 优先从 Redis 读取缓存
|
||||
if cached, e := global.GVA_REDIS.Get(ctx, gaia.RedisKeyGaiaAdminConsoleToken).Result(); e == nil && cached != "" {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// 查询一个活跃管理员
|
||||
var adminUser system.SysUser
|
||||
if err = global.GVA_DB.Where("authority_id = ? AND enable = ?",
|
||||
system.AdminAuthorityId, system.UserActive).First(&adminUser).Error; err != nil {
|
||||
return "", fmt.Errorf("找不到可用的管理员账号:%w", err)
|
||||
}
|
||||
|
||||
token, _, _, err = utils.LoginTokenWithCSRF(&adminUser)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("生成管理员 token 失败:%w", err)
|
||||
}
|
||||
|
||||
// 缓存 50 分钟(JWT 缓冲时间内有效)
|
||||
global.GVA_REDIS.Set(ctx, gaia.RedisKeyGaiaAdminConsoleToken, token, 50*time.Minute)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// fetchModelPricingFromDify 通过 Dify Console API 拉取 LLM 模型定价,结果按 model 名缓存到 Redis(TTL 1 小时)。
|
||||
// Dify Console API:GET /console/api/workspaces/current/models/model-types/llm
|
||||
// 响应结构:{"data": [{"models": [{"model": "gpt-4o", "fetch_from": "...", "pricing": {"input":"0.005","output":"0.015","unit":"0.001","currency":"USD"}}]}]}
|
||||
func (s *ModelProviderService) fetchModelPricingFromDify(modelName string) (*gaia.ModelPricing, error) {
|
||||
const redisTTL = time.Hour
|
||||
cacheKey := gaia.RedisKeyGaiaModelPricingPrefix + modelName
|
||||
ctx := context.Background()
|
||||
|
||||
// 先查 Redis
|
||||
if cached, err := global.GVA_REDIS.Get(ctx, cacheKey).Result(); err == nil && cached != "" {
|
||||
var p gaia.ModelPricing
|
||||
if json.Unmarshal([]byte(cached), &p) == nil {
|
||||
return &p, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 获取管理员 token
|
||||
token, err := s.fetchAdminToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用 Dify Console API
|
||||
apiURL := strings.TrimSuffix(global.GVA_CONFIG.Gaia.Url, "/") +
|
||||
"/console/api/workspaces/current/models/model-types/llm"
|
||||
req, err := http.NewRequest(http.MethodGet, apiURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("构建定价请求失败:%w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求 Dify 定价接口失败:%w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取定价响应失败:%w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("dify 定价接口返回 %d:%s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// 解析响应,批量缓存所有模型的定价
|
||||
var apiResp gaia.DifyModelsResponse
|
||||
if err = json.Unmarshal(respBody, &apiResp); err != nil {
|
||||
return nil, fmt.Errorf("解析定价响应失败:%w", err)
|
||||
}
|
||||
|
||||
var targetPricing *gaia.ModelPricing
|
||||
for _, providerData := range apiResp.Data {
|
||||
for _, m := range providerData.Models {
|
||||
if m.Pricing == nil {
|
||||
continue
|
||||
}
|
||||
p := gaia.ModelPricing{Currency: m.Pricing.Currency}
|
||||
_, _ = fmt.Sscanf(m.Pricing.Input, "%f", &p.Input)
|
||||
_, _ = fmt.Sscanf(m.Pricing.Output, "%f", &p.Output)
|
||||
_, _ = fmt.Sscanf(m.Pricing.Unit, "%f", &p.Unit)
|
||||
if p.Unit == 0 {
|
||||
p.Unit = 0.001 // 默认按千 token 计费
|
||||
}
|
||||
|
||||
// 缓存每个模型的定价
|
||||
if b, e := json.Marshal(p); e == nil {
|
||||
global.GVA_REDIS.Set(ctx, gaia.RedisKeyGaiaModelPricingPrefix+m.Model, string(b), redisTTL)
|
||||
}
|
||||
if m.Model == modelName {
|
||||
cp := p
|
||||
targetPricing = &cp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if targetPricing != nil {
|
||||
return targetPricing, nil
|
||||
}
|
||||
// 未找到该模型的定价,写入空标记避免反复请求(TTL 10 分钟)
|
||||
global.GVA_REDIS.Set(ctx, cacheKey, "{}", 10*time.Minute)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// rmbToUSD 将人民币金额按固定汇率换算为 USD。
|
||||
func rmbToUSD(rmb float64) float64 {
|
||||
return rmb / gaia.RmbToUSDRate
|
||||
}
|
||||
|
||||
// resolvePricing 返回模型定价:优先用从 Dify 拉取的 pricing,
|
||||
// 其次查内置兜底定价表(BuiltinModelPricing),最后返回 nil。
|
||||
func resolvePricing(pricing *gaia.ModelPricing, modelName string) *gaia.ModelPricing {
|
||||
if pricing != nil && pricing.Unit > 0 {
|
||||
return pricing
|
||||
}
|
||||
// 内置定价表精确匹配
|
||||
if p, ok := gaia.BuiltinModelPricing[modelName]; ok {
|
||||
return &p
|
||||
}
|
||||
// 前缀模糊匹配(如 "qwen3.5-plus-xxx" 匹配 "qwen3.5-plus")
|
||||
lower := strings.ToLower(modelName)
|
||||
for k, p := range gaia.BuiltinModelPricing {
|
||||
if strings.HasPrefix(lower, strings.ToLower(k)) {
|
||||
cp := p
|
||||
return &cp
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// calcQuotaDelta 根据定价和 token 用量计算本次消耗的配额金额(统一以 USD 计)。
|
||||
// Dify pricing 字段语义:input/output 为每「unit」个 token 的价格,unit 通常为 0.001(千分之一),
|
||||
// 即 input=0.0014, unit=0.001 表示每千 token ¥0.0014 × (tokens/1000)。
|
||||
// 公式:cost = tokens × input × unit(因为 unit=1/1000,等价于 tokens/1000 × input)。
|
||||
// 若货币为 RMB/CNY,则除以汇率 7.26 换算为 USD,与 account_money_extend.used_quota 存储单位保持一致。
|
||||
// 若 Dify 未返回定价则查内置兜底表;均未命中时按极小默认值记账,避免多扣。
|
||||
func calcQuotaDelta(pricing *gaia.ModelPricing, modelName string, promptTokens, completionTokens int) float64 {
|
||||
p := resolvePricing(pricing, modelName)
|
||||
if p == nil {
|
||||
// 兜底:仅做记账占位,不应大量触发
|
||||
global.GVA_LOG.Warn("calcQuotaDelta 未找到模型定价,使用兜底值",
|
||||
zap.String("model", modelName),
|
||||
zap.Int("prompt_tokens", promptTokens),
|
||||
zap.Int("completion_tokens", completionTokens),
|
||||
)
|
||||
return float64(promptTokens+completionTokens) * gaia.DefaultQuotaFallbackUSDPerToken
|
||||
}
|
||||
inputCost := float64(promptTokens) * p.Input * p.Unit
|
||||
outputPrice := p.Output
|
||||
if outputPrice == 0 {
|
||||
outputPrice = p.Input
|
||||
}
|
||||
outputCost := float64(completionTokens) * outputPrice * p.Unit
|
||||
total := inputCost + outputCost
|
||||
|
||||
// RMB/CNY 定价统一换算为 USD 后再扣费,与 used_quota 存储单位保持一致
|
||||
if strings.EqualFold(p.Currency, "RMB") || strings.EqualFold(p.Currency, "CNY") {
|
||||
total = rmbToUSD(total)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// CheckAccountQuota 检查用户是否还有可用余额(total_quota - used_quota > 0)。
|
||||
// total_quota = 0 视为"未设置限额",不拦截;total_quota > 0 时才做余额校验。
|
||||
func (s *ModelProviderService) CheckAccountQuota(userID string) error {
|
||||
var row gaiaResponse.CheckAccountQuotaRow
|
||||
err := global.GVA_DB.Table("account_money_extend").
|
||||
Select("total_quota, used_quota").
|
||||
Where("account_id = ?::uuid", userID).
|
||||
First(&row).Error
|
||||
if err != nil {
|
||||
// 记录未找到:可能尚未初始化,放行
|
||||
return nil
|
||||
}
|
||||
// total_quota = 0 表示不限额,放行
|
||||
if row.TotalQuota <= 0 {
|
||||
return nil
|
||||
}
|
||||
if row.UsedQuota >= row.TotalQuota {
|
||||
return fmt.Errorf("余额不足,已用 %.6f / 总额 %.6f USD,请联系管理员充值", row.UsedQuota, row.TotalQuota)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deductAccountQuota 将消耗配额计入 account_money_extend.used_quota(原子累加)。
|
||||
func deductAccountQuota(userID string, delta float64) {
|
||||
if delta <= 0 {
|
||||
return
|
||||
}
|
||||
if err := global.GVA_DB.Exec(
|
||||
`UPDATE account_money_extend SET used_quota = used_quota + ?, updated_at = NOW() WHERE account_id = ?::uuid`,
|
||||
delta, userID,
|
||||
).Error; err != nil {
|
||||
global.GVA_LOG.Warn("deductAccountQuota 失败",
|
||||
zap.String("user_id", userID), zap.Float64("delta", delta), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// ModelProviderService 模型提供商服务,负责提供商配置、凭证获取、可用模型拉取及聊天请求代理。
|
||||
type ModelProviderService struct{}
|
||||
|
||||
@@ -186,7 +396,8 @@ func (s *ModelProviderService) getAvailableModelsFromProviderModelCredentials(pr
|
||||
Distinct("model_name").
|
||||
Pluck("model_name", &modelNames).Error
|
||||
if err != nil {
|
||||
global.GVA_LOG.Warn("从 provider_model_credentials 拉取模型列表失败", zap.String("provider", providerName), zap.Error(err))
|
||||
global.GVA_LOG.Warn("从 provider_model_credentials 拉取模型列表失败", zap.String(
|
||||
"provider", providerName), zap.Error(err))
|
||||
return nil, nil
|
||||
}
|
||||
list := make([]gaiaResponse.ModelInfo, 0, len(modelNames))
|
||||
@@ -207,6 +418,14 @@ func (s *ModelProviderService) GetAvailableModelsFromDify(providerName string) (
|
||||
if providerName == gaia.ProviderAzure {
|
||||
return s.getAvailableModelsFromProviderModelCredentials(providerName)
|
||||
}
|
||||
// AWS Bedrock 没有统一的 /v1/models 接口,模型由前端手输;直接返回空列表
|
||||
if providerName == gaia.ProviderAWS || providerName == gaia.ProviderAnthropic {
|
||||
return nil, nil
|
||||
}
|
||||
// DeepSeek 模型由前端手输(避免拉取全量列表),直接返回空列表
|
||||
if providerName == gaia.ProviderDeepSeek {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
creds, err := s.GetDifyProviderCredentials(providerName)
|
||||
if err != nil || creds.APIKey == "" {
|
||||
@@ -230,8 +449,6 @@ func (s *ModelProviderService) GetAvailableModelsFromDify(providerName string) (
|
||||
base = gaia.DefaultAPIBase[gaia.ProviderGoogle]
|
||||
}
|
||||
return s.fetchGeminiModels(client, base, creds.APIKey)
|
||||
case gaia.ProviderAnthropic:
|
||||
return nil, nil
|
||||
default:
|
||||
if creds.Endpoint != "" {
|
||||
return s.fetchOpenAICompatibleModels(client, creds.Endpoint, creds.APIKey)
|
||||
@@ -244,7 +461,8 @@ func (s *ModelProviderService) GetAvailableModelsFromDify(providerName string) (
|
||||
// 兼容两种响应格式:
|
||||
// 1) OpenAI: { "data": [ { "id": "..." }, ... ] }
|
||||
// 2) 通义: { "success": true, "output": { "models": [ { "model": "...", "name": "..." }, ... ] } }
|
||||
func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client, baseURL, apiKey string) ([]gaiaResponse.ModelInfo, error) {
|
||||
func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client, baseURL, apiKey string) (
|
||||
[]gaiaResponse.ModelInfo, error) {
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/v1/models"
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
@@ -259,7 +477,8 @@ func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client,
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Warn("拉取模型列表接口非 200", zap.String("url", url), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
|
||||
global.GVA_LOG.Warn("拉取模型列表接口非 200", zap.String("url", url), zap.Int(
|
||||
"status", resp.StatusCode), zap.String("body", string(body)))
|
||||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -298,7 +517,8 @@ func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client,
|
||||
|
||||
// fetchGeminiModels 调用 Google Gemini GET /v1beta/models?key=API_KEY,解析 models[],支持分页。
|
||||
// 认证使用 query 参数 key,响应格式:{ "models": [ { "name": "models/xxx", "baseModelId": "xxx", "displayName": "..." } ], "nextPageToken": "..." }
|
||||
func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, apiKey string) ([]gaiaResponse.ModelInfo, error) {
|
||||
func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, apiKey string) (
|
||||
[]gaiaResponse.ModelInfo, error) {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
all := make([]gaiaResponse.ModelInfo, 0)
|
||||
pageToken := ""
|
||||
@@ -320,7 +540,9 @@ func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, a
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Warn("拉取 Gemini 模型列表非 200", zap.String("url", baseURL+"/v1beta/models"), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
|
||||
global.GVA_LOG.Warn("拉取 Gemini 模型列表非 200", zap.String(
|
||||
"url", baseURL+"/v1beta/models"), zap.Int("status", resp.StatusCode), zap.String(
|
||||
"body", string(body)))
|
||||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -356,7 +578,8 @@ func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, a
|
||||
|
||||
// fetchAzureOpenAIModels 调用 Azure OpenAI GET {endpoint}/openai/models?api-version={version},解析 data[]。
|
||||
// 认证使用 api-key 请求头,响应格式:{ "data": [ { "id": "...", "object": "model" } ] }
|
||||
func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseURL, apiKey, apiVersion string) ([]gaiaResponse.ModelInfo, error) {
|
||||
func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseURL, apiKey, apiVersion string) (
|
||||
[]gaiaResponse.ModelInfo, error) {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
if apiVersion == "" {
|
||||
apiVersion = "2024-08-01-preview" // 默认 API 版本
|
||||
@@ -378,7 +601,8 @@ func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseU
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Warn("拉取 Azure OpenAI 模型列表非 200", zap.String("url", url), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
|
||||
global.GVA_LOG.Warn("拉取 Azure OpenAI 模型列表非 200", zap.String("url", url), zap.Int(
|
||||
"status", resp.StatusCode), zap.String("body", string(body)))
|
||||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -416,29 +640,36 @@ func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
|
||||
var cached string
|
||||
var firstTenant gaia.Tenants
|
||||
tenantID := firstTenant.GetSuperAdminTenantId()
|
||||
cacheKey := fmt.Sprintf("model_provider_credentials:%s", providerName)
|
||||
cacheKey := gaia.RedisKeyModelProviderCredentialsPrefix + providerName
|
||||
if cached, err = global.GVA_Dify_REDIS.Get(context.Background(), cacheKey).Result(); err == nil {
|
||||
if err = json.Unmarshal([]byte(cached), &creds); err == nil {
|
||||
global.GVA_LOG.Info("GetDifyProviderCredentials 命中缓存",
|
||||
zap.String("provider", providerName),
|
||||
zap.String("aws_region", creds.AWSRegion),
|
||||
zap.String("bedrock_proxy_url", creds.BedrockProxyURL),
|
||||
)
|
||||
return creds, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试方式1: 从 providers + provider_credentials 表查询
|
||||
var row gaia.ProviderCredential
|
||||
err = global.GVA_DB.Table("providers").
|
||||
// 将短名转为 Dify 内部 provider_name 的 LIKE 模式(避免 aws 匹配不到 bedrock_claude)
|
||||
likePattern := s.difyProviderLikePattern(providerName)
|
||||
err = global.GVA_DB.Debug().Table("providers").
|
||||
Select("provider_credentials.encrypted_config, providers.tenant_id").
|
||||
Joins("LEFT JOIN provider_credentials ON providers.credential_id = provider_credentials.id").
|
||||
Where("providers.tenant_id = ? AND providers.provider_name LIKE ? AND providers.provider_type = ? AND providers.is_valid = ?",
|
||||
tenantID, fmt.Sprintf("%%%s%%", providerName), gaia.DifyProviderTypeCustom, true).
|
||||
tenantID, likePattern, gaia.DifyProviderTypeCustom, true).
|
||||
Order("provider_credentials.updated_at DESC").
|
||||
First(&row).Error
|
||||
|
||||
// 如果方式1 未找到记录,尝试方式2: 从 provider_model_credentials 表查询
|
||||
if err != nil || row.EncryptedConfig == "" {
|
||||
var pmcRow gaia.ProviderCredential
|
||||
if pmcErr := global.GVA_DB.Table("provider_model_credentials").
|
||||
if pmcErr := global.GVA_DB.Debug().Table("provider_model_credentials").
|
||||
Select("encrypted_config, tenant_id, provider_name, updated_at").
|
||||
Where("tenant_id = ? AND provider_name LIKE ?", tenantID, fmt.Sprintf("%%%s%%", providerName)).
|
||||
Where("tenant_id = ? AND provider_name LIKE ?", tenantID, likePattern).
|
||||
Order("updated_at DESC"). // 按 updated_at 倒序,取最新的凭证
|
||||
First(&pmcRow).Error; pmcErr == nil && pmcRow.EncryptedConfig != "" {
|
||||
row = pmcRow
|
||||
@@ -453,6 +684,7 @@ func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
|
||||
// 兼容两种存储:1) 明文 JSON(如 {"openai_api_key":"...", "openai_api_base":"..."});2) Dify RSA+AES-EAX 加密后再 base64
|
||||
var base, apiVersion string
|
||||
var configMap map[string]interface{}
|
||||
fmt.Println("row.EncryptedConfig", row.EncryptedConfig)
|
||||
if err = json.Unmarshal([]byte(row.EncryptedConfig), &configMap); err == nil {
|
||||
// 解密函数用于处理加密的值
|
||||
if config, ok := configMap[gaia.ConfigKeyOpenaiAPIKey]; ok {
|
||||
@@ -468,6 +700,35 @@ func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
|
||||
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
|
||||
} else if config, ok = configMap[gaia.ConfigKeyAPIKey]; ok {
|
||||
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
|
||||
} else if _, hasAWSKey := configMap[gaia.ConfigKeyAWSAccessKeyID]; hasAWSKey {
|
||||
// AWS Bedrock 凭证:解析 aws_access_key_id / aws_secret_access_key / aws_region
|
||||
if v, ok2 := configMap[gaia.ConfigKeyAWSAccessKeyID].(string); ok2 && v != "" {
|
||||
if creds.AWSAccessKeyID, err = s.decryptConfig(v, row.TenantID); err != nil {
|
||||
return nil, fmt.Errorf("解密 aws_access_key_id 失败: %w", err)
|
||||
}
|
||||
}
|
||||
if v, ok2 := configMap[gaia.ConfigKeyAWSSecretAccessKey].(string); ok2 && v != "" {
|
||||
if creds.AWSSecretAccessKey, err = s.decryptConfig(v, row.TenantID); err != nil {
|
||||
return nil, fmt.Errorf("解密 aws_secret_access_key 失败: %w", err)
|
||||
}
|
||||
}
|
||||
if v, ok2 := configMap[gaia.ConfigKeyAWSSessionToken].(string); ok2 && v != "" {
|
||||
if creds.AWSSessionToken, err = s.decryptConfig(v, row.TenantID); err != nil {
|
||||
return nil, fmt.Errorf("解密 aws_session_token 失败: %w", err)
|
||||
}
|
||||
}
|
||||
if v, ok2 := configMap[gaia.ConfigKeyAWSRegion].(string); ok2 && v != "" {
|
||||
creds.AWSRegion = strings.TrimSpace(v)
|
||||
}
|
||||
// 可选:HTTP 代理地址(用于从受限地区中转 Bedrock 请求)。
|
||||
// 支持 "host:port" 或 "http(s)://host:port";若值被加密则先解密,失败时按原文处理。
|
||||
if v, ok2 := configMap[gaia.ConfigKeyBedrockProxyURL].(string); ok2 && strings.TrimSpace(v) != "" {
|
||||
proxyVal := strings.TrimSpace(v)
|
||||
if decrypted, decErr := s.decryptConfig(proxyVal, row.TenantID); decErr == nil && decrypted != "" {
|
||||
proxyVal = strings.TrimSpace(decrypted)
|
||||
}
|
||||
creds.BedrockProxyURL = proxyVal
|
||||
}
|
||||
} else {
|
||||
// 尝试从备选字段中查找
|
||||
for _, key := range gaia.CredentialKeyFallback {
|
||||
@@ -490,13 +751,14 @@ func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
|
||||
return nil, fmt.Errorf("解密凭证失败: %w", err)
|
||||
}
|
||||
}
|
||||
if creds.APIKey == "" {
|
||||
return nil, fmt.Errorf("未能从配置中提取API Key")
|
||||
if creds.APIKey == "" && creds.AWSAccessKeyID == "" {
|
||||
return nil, fmt.Errorf("未能从配置中提取API Key(也未找到 AWS 凭证)")
|
||||
}
|
||||
|
||||
// 缓存凭证(1小时)
|
||||
var cacheJSON []byte
|
||||
if cacheJSON, err = json.Marshal(creds); err == nil {
|
||||
fmt.Println("row.EncryptedConfig", string(cacheJSON))
|
||||
global.GVA_Dify_REDIS.Set(context.Background(), cacheKey, cacheJSON, time.Hour)
|
||||
}
|
||||
|
||||
@@ -708,7 +970,7 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
|
||||
|
||||
defer func() {
|
||||
// 记录日志
|
||||
log := gaia.ModelProxyLog{
|
||||
global.GVA_DB.Create(&gaia.ModelProxyLog{
|
||||
UserId: userID,
|
||||
ProviderName: providerName,
|
||||
ModelName: req.Model,
|
||||
@@ -717,8 +979,7 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
|
||||
Status: status,
|
||||
ErrorMessage: errorMsg,
|
||||
CreatedAt: startTime,
|
||||
}
|
||||
global.GVA_DB.Create(&log)
|
||||
})
|
||||
}()
|
||||
|
||||
// 处理流式响应
|
||||
@@ -726,7 +987,7 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if _, err := writer.Write([]byte(line + "\n")); err != nil {
|
||||
if _, err = writer.Write([]byte(line + "\n")); err != nil {
|
||||
status = "error"
|
||||
errorMsg = err.Error()
|
||||
return err
|
||||
@@ -736,14 +997,14 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
if err = scanner.Err(); err != nil {
|
||||
status = "error"
|
||||
errorMsg = err.Error()
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 非流式响应
|
||||
if _, err := io.Copy(writer, resp.Body); err != nil {
|
||||
if _, err = io.Copy(writer, resp.Body); err != nil {
|
||||
status = "error"
|
||||
errorMsg = err.Error()
|
||||
return err
|
||||
@@ -769,8 +1030,18 @@ func (s *ModelProviderService) getProviderCandidatesByModel(modelName string) []
|
||||
if strings.HasPrefix(modelLower, "gemini") || strings.Contains(modelLower, "google") {
|
||||
return []string{gaia.ProviderGoogle}
|
||||
}
|
||||
if strings.HasPrefix(modelLower, "anthropic.") {
|
||||
// anthropic.* 前缀是 AWS Bedrock 专用模型 ID 格式(如 anthropic.claude-sonnet-4-6),
|
||||
// 仅走 Bedrock 渠道,不回落到 Anthropic 直连(直连不支持此格式且可能因地区受限)。
|
||||
return []string{gaia.ProviderAWS}
|
||||
}
|
||||
if strings.Contains(modelLower, "claude") || strings.Contains(modelLower, "anthropic") {
|
||||
return []string{gaia.ProviderAnthropic}
|
||||
// 顺序即优先级:AWS Bedrock 优先(配置成本更高、可覆盖受限地区),未开启再回落到 anthropic 直连
|
||||
return []string{gaia.ProviderAWS, gaia.ProviderAnthropic}
|
||||
}
|
||||
// Kimi / Moonshot 系列经由 tongyi(百炼)渠道转发
|
||||
if strings.HasPrefix(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
||||
return []string{gaia.ProviderTongyi}
|
||||
}
|
||||
// GLM/智谱 可能配置在 tongyi(统一入口)或 zhipuai 下,先试 tongyi
|
||||
if strings.HasPrefix(modelLower, "glm") || strings.Contains(modelLower, "zhipu") || strings.Contains(modelLower, "chatglm") {
|
||||
@@ -780,6 +1051,10 @@ func (s *ModelProviderService) getProviderCandidatesByModel(modelName string) []
|
||||
if strings.HasPrefix(modelLower, "minimax") || strings.Contains(modelLower, "abab") {
|
||||
return []string{gaia.ProviderTongyi, gaia.ProviderMinimax}
|
||||
}
|
||||
// DeepSeek 系列模型走 deepseek 渠道
|
||||
if strings.HasPrefix(modelLower, "deepseek") {
|
||||
return []string{gaia.ProviderDeepSeek}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -811,9 +1086,18 @@ func (s *ModelProviderService) getProviderByModel(modelName string) (string, err
|
||||
if strings.HasPrefix(modelLower, "gemini") || strings.Contains(modelLower, "google") {
|
||||
return gaia.ProviderGoogle, nil
|
||||
}
|
||||
if strings.HasPrefix(modelLower, "anthropic.") {
|
||||
// anthropic.* 前缀是 AWS Bedrock 专用格式,直接归 aws 渠道
|
||||
return gaia.ProviderAWS, nil
|
||||
}
|
||||
if strings.Contains(modelLower, "claude") || strings.Contains(modelLower, "anthropic") {
|
||||
// 仅按名字推断时默认 anthropic;实际渠道(含 AWS Bedrock)由 resolveProviderByModel 决定
|
||||
return gaia.ProviderAnthropic, nil
|
||||
}
|
||||
// Kimi / Moonshot 默认走 tongyi(百炼)渠道
|
||||
if strings.HasPrefix(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
||||
return gaia.ProviderTongyi, nil
|
||||
}
|
||||
if strings.Contains(modelLower, "azure") {
|
||||
return gaia.ProviderAzure, nil
|
||||
}
|
||||
@@ -825,6 +1109,10 @@ func (s *ModelProviderService) getProviderByModel(modelName string) (string, err
|
||||
if strings.HasPrefix(modelLower, "minimax") || strings.Contains(modelLower, "abab") {
|
||||
return gaia.ProviderTongyi, nil
|
||||
}
|
||||
// DeepSeek 系列走 deepseek 渠道
|
||||
if strings.HasPrefix(modelLower, "deepseek") {
|
||||
return gaia.ProviderDeepSeek, nil
|
||||
}
|
||||
return "", fmt.Errorf("无法识别模型 %s 的提供商", modelName)
|
||||
}
|
||||
|
||||
@@ -900,7 +1188,8 @@ func (s *ModelProviderService) ProxyRequest(
|
||||
|
||||
// 解析 provider:头 > query 已在 handler 传入;此处从 body 取 model 仅当 body 为 JSON 且含 model 时用于推断
|
||||
xGaiaProvider := reqHeader.Get("X-Gaia-Provider")
|
||||
global.GVA_LOG.Info("ProxyRequest 解析 provider", zap.String("path", path), zap.String("X-Gaia-Provider", xGaiaProvider), zap.Int("body_len", len(body)))
|
||||
global.GVA_LOG.Info("ProxyRequest 解析 provider", zap.String("path", path), zap.String(
|
||||
"X-Gaia-Provider", xGaiaProvider), zap.Int("body_len", len(body)))
|
||||
if p := xGaiaProvider; p != "" {
|
||||
providerName = strings.TrimSpace(strings.ToLower(p))
|
||||
}
|
||||
@@ -912,7 +1201,8 @@ func (s *ModelProviderService) ProxyRequest(
|
||||
// 按“已选模型”解析实际渠道(如 gpt-5-chat 若只在 Azure 下勾选则走 azure)
|
||||
providerName, err = s.resolveProviderByModel(m)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("ProxyRequest resolveProviderByModel 失败", zap.String("model", m), zap.Error(err))
|
||||
global.GVA_LOG.Error("ProxyRequest resolveProviderByModel 失败", zap.String(
|
||||
"model", m), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
global.GVA_LOG.Info("ProxyRequest 解析得到 provider", zap.String("provider", providerName))
|
||||
@@ -935,11 +1225,29 @@ func (s *ModelProviderService) ProxyRequest(
|
||||
return err
|
||||
}
|
||||
|
||||
// AWS Bedrock 直连:不走通用 HTTP 转发,改用 SigV4 签名的 Bedrock 原生 API
|
||||
if providerName == gaia.ProviderAWS {
|
||||
return s.proxyBedrockRequest(userID, path, method, reqHeader, body, writer, creds)
|
||||
}
|
||||
|
||||
if base = s.getUpstreamBase(providerName, creds); base == "" {
|
||||
return fmt.Errorf("提供商 %s 无可用上游地址", providerName)
|
||||
}
|
||||
|
||||
// 若 body 是 JSON 且含 stream: true,注入 stream_options.include_usage = true
|
||||
// 这样上游会在 SSE 末尾的 data 行返回 usage,供后续计费解析使用。
|
||||
if len(body) > 0 {
|
||||
var bodyObj map[string]interface{}
|
||||
if json.Unmarshal(body, &bodyObj) == nil {
|
||||
if streamVal, ok := bodyObj["stream"].(bool); ok && streamVal {
|
||||
if _, hasOpt := bodyObj["stream_options"]; !hasOpt {
|
||||
bodyObj["stream_options"] = map[string]interface{}{"include_usage": true}
|
||||
if injected, e := json.Marshal(bodyObj); e == nil {
|
||||
body = injected
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
bodyReader = bytes.NewReader(body)
|
||||
}
|
||||
|
||||
@@ -975,7 +1283,6 @@ func (s *ModelProviderService) ProxyRequest(
|
||||
requestURL = base + "/" + path
|
||||
}
|
||||
|
||||
fmt.Println("path", requestURL, string(body))
|
||||
httpReq, err := http.NewRequest(method, requestURL, bodyReader)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1017,18 +1324,42 @@ func (s *ModelProviderService) ProxyRequest(
|
||||
}
|
||||
}
|
||||
var logStatus, logError string
|
||||
var promptTokens, completionTokens int
|
||||
defer func() {
|
||||
if logStatus == "" {
|
||||
logStatus = "success"
|
||||
}
|
||||
global.GVA_DB.Create(&gaia.ModelProxyLog{
|
||||
UserId: userID,
|
||||
ProviderName: providerName,
|
||||
ModelName: modelOrPath,
|
||||
Status: logStatus,
|
||||
ErrorMessage: logError,
|
||||
CreatedAt: startTime,
|
||||
UserId: userID,
|
||||
ProviderName: providerName,
|
||||
ModelName: modelOrPath,
|
||||
RequestTokens: promptTokens,
|
||||
ResponseTokens: completionTokens,
|
||||
Status: logStatus,
|
||||
ErrorMessage: logError,
|
||||
CreatedAt: startTime,
|
||||
})
|
||||
|
||||
// 计费:仅成功时扣费
|
||||
if logStatus == "success" {
|
||||
if promptTokens > 0 || completionTokens > 0 {
|
||||
// LLM 类型:按 token 计费
|
||||
pricing, _ := s.fetchModelPricingFromDify(modelOrPath)
|
||||
delta := calcQuotaDelta(pricing, modelOrPath, promptTokens, completionTokens)
|
||||
deductAccountQuota(userID, delta)
|
||||
} else if isImageOrPerRequestPath(path) {
|
||||
// 图片生成等无 usage 的接口:按请求次数计费,默认单价见 gaia.DefaultImageGenerationPriceUSD
|
||||
pricing, _ := s.fetchModelPricingFromDify(modelOrPath)
|
||||
var delta float64
|
||||
if pricing != nil && pricing.Input > 0 {
|
||||
// 若定价表有配置,input 字段用作每次请求单价
|
||||
delta = pricing.Input
|
||||
} else {
|
||||
delta = gaia.DefaultImageGenerationPriceUSD
|
||||
}
|
||||
deductAccountQuota(userID, delta)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 写回状态码与响应头(流式由上游 Content-Type 决定)
|
||||
@@ -1044,17 +1375,36 @@ func (s *ModelProviderService) ProxyRequest(
|
||||
_, _ = io.Copy(writer, resp.Body)
|
||||
return nil
|
||||
}
|
||||
// 流式响应时按行刷新,避免缓冲
|
||||
|
||||
// extractUsage 从 OpenAI 格式的 JSON 对象中提取 usage 字段
|
||||
extractUsage := func(data []byte) {
|
||||
var obj gaia.ModelUsageResponse
|
||||
if json.Unmarshal(data, &obj) == nil && obj.Usage != nil {
|
||||
if obj.Usage.PromptTokens > 0 {
|
||||
promptTokens = obj.Usage.PromptTokens
|
||||
}
|
||||
if obj.Usage.CompletionTokens > 0 {
|
||||
completionTokens = obj.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 流式响应:按行扫描,顺带从最后一条含 usage 的 data 行返回
|
||||
if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
fmt.Println("sss", scanner.Text())
|
||||
if _, err = writer.Write([]byte(scanner.Text() + "\n")); err != nil {
|
||||
line := scanner.Text()
|
||||
if _, err = writer.Write([]byte(line + "\n")); err != nil {
|
||||
logStatus, logError = "error", err.Error()
|
||||
return err
|
||||
}
|
||||
flusher.Flush()
|
||||
// 解析 SSE data 行中的 usage(stream_options.include_usage=true 时上游会附带)
|
||||
if strings.HasPrefix(line, "data:") && strings.Contains(line, `"usage"`) {
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
extractUsage([]byte(payload))
|
||||
}
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
logStatus, logError = "error", err.Error()
|
||||
@@ -1063,13 +1413,41 @@ func (s *ModelProviderService) ProxyRequest(
|
||||
return nil
|
||||
}
|
||||
}
|
||||
_, err = io.Copy(writer, resp.Body)
|
||||
|
||||
// 非流式响应:TeeReader 同时转发给客户端并读取 body 用于解析 usage
|
||||
var buf bytes.Buffer
|
||||
tee := io.TeeReader(resp.Body, &buf)
|
||||
_, err = io.Copy(writer, tee)
|
||||
if err != nil {
|
||||
logStatus, logError = "error", err.Error()
|
||||
} else {
|
||||
extractUsage(buf.Bytes())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetProxyLogs 分页查询代理日志(model_proxy_log 表)。
|
||||
func (s *ModelProviderService) GetProxyLogs(info gaiaRequest.GetProxyLogsReq) (list []map[string]interface{}, total int64, err error) {
|
||||
page, pageSize := info.Page, info.PageSize
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 20
|
||||
}
|
||||
db := global.GVA_DB.Table("model_proxy_log")
|
||||
if err = db.Count(&total).Error; err != nil {
|
||||
err = fmt.Errorf("查询日志总数失败:%w", err)
|
||||
return
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
if err = db.Order("created_at DESC").Limit(pageSize).Offset(offset).Find(&list).Error; err != nil {
|
||||
err = fmt.Errorf("查询日志列表失败:%w", err)
|
||||
return
|
||||
}
|
||||
return list, total, nil
|
||||
}
|
||||
|
||||
// isProviderEnabled 检查该提供商是否已启用(未校验具体模型列表,用于通用代理)。
|
||||
func (s *ModelProviderService) isProviderEnabled(providerName string) bool {
|
||||
var config gaia.ModelProviderConfig
|
||||
@@ -1078,3 +1456,40 @@ func (s *ModelProviderService) isProviderEnabled(providerName string) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isImageOrPerRequestPath 判断请求路径是否为按次计费的接口(图片生成、语音合成等无 usage 字段的接口)。
|
||||
func isImageOrPerRequestPath(path string) bool {
|
||||
perRequestPaths := []string{
|
||||
"images/generations",
|
||||
"images/edits",
|
||||
"images/variations",
|
||||
"audio/speech",
|
||||
"audio/transcriptions",
|
||||
"audio/translations",
|
||||
}
|
||||
lpath := strings.ToLower(path)
|
||||
for _, p := range perRequestPaths {
|
||||
if strings.Contains(lpath, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// difyProviderLikePattern 将 Gaia 短名(openai/aws/...)转换为 Dify providers 表
|
||||
// provider_name 字段的 LIKE 搜索模式。
|
||||
// Dify 内部以 "langgenius/<plugin>/<plugin>" 格式存储提供商名,例如:
|
||||
// - aws → langgenius/bedrock_claude/bedrock_claude 或 langgenius/bedrock/bedrock
|
||||
// - openai → langgenius/openai/openai
|
||||
// - anthropic → langgenius/anthropic/anthropic
|
||||
//
|
||||
// 对于 AWS,Dify 有两种插件包:bedrock_claude 和 bedrock;使用 bedrock 作为公共关键字可同时命中。
|
||||
func (s *ModelProviderService) difyProviderLikePattern(providerName string) string {
|
||||
switch providerName {
|
||||
case gaia.ProviderAWS:
|
||||
// langgenius/bedrock_claude/bedrock_claude 和 langgenius/bedrock/bedrock 都含 "bedrock"
|
||||
return "%bedrock%"
|
||||
default:
|
||||
return fmt.Sprintf("%%%s%%", providerName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package gaia
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -8,13 +12,16 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/faabiosr/cachego/file"
|
||||
"github.com/fastwego/dingding"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
gaiaResp "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
@@ -124,22 +131,52 @@ func (e *SystemIntegratedService) SetIntegratedConfig(
|
||||
// @param: req gaia.SystemIntegration
|
||||
// @return: *dingding.Client, error
|
||||
func (e *SystemIntegratedService) DingTalkConfigAvailable(req gaia.SystemIntegration) (*dingding.Client, error) {
|
||||
var err error
|
||||
// 1. 先直接调用钉钉 gettoken 接口,校验 AppKey/AppSecret 是否正确
|
||||
if strings.TrimSpace(req.AppKey) == "" || strings.TrimSpace(req.AppSecret) == "" {
|
||||
return nil, errors.New("AppKey 或 AppSecret 不能为空")
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("appkey", req.AppKey)
|
||||
params.Add("appsecret", req.AppSecret)
|
||||
|
||||
resp, err := http.Get("https://oapi.dingtalk.com/gettoken?" + params.Encode())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求钉钉 gettoken 失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("钉钉 gettoken HTTP 状态异常: %d, body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("解析钉钉 gettoken 响应失败: %w", err)
|
||||
}
|
||||
if tokenResp.ErrCode != 0 {
|
||||
return nil, fmt.Errorf("钉钉 gettoken 返回错误: errcode=%d, errmsg=%s", tokenResp.ErrCode, tokenResp.ErrMsg)
|
||||
}
|
||||
|
||||
// 2. 校验通过后再构造 client(保持原有行为,供后续使用)
|
||||
var reqs *http.Request
|
||||
dingding.ServerUrl = "https://api.dingtalk.com"
|
||||
// 特殊需要,检查可用性就不设置缓存了
|
||||
return dingding.NewClient(&dingding.DefaultAccessTokenManager{
|
||||
client := dingding.NewClient(&dingding.DefaultAccessTokenManager{
|
||||
Id: uuid.New().String(),
|
||||
Cache: file.New(os.TempDir()),
|
||||
Name: "x-acs-dingtalk-access-token",
|
||||
GetRefreshRequestFunc: func() *http.Request {
|
||||
params := url.Values{}
|
||||
params.Add("appkey", req.AppKey)
|
||||
params.Add("appsecret", req.AppSecret)
|
||||
reqs, err = http.NewRequest(http.MethodGet, "https://oapi.dingtalk.com/gettoken?"+params.Encode(), nil)
|
||||
// 这里沿用原来的 token 刷新逻辑
|
||||
reqs, _ = http.NewRequest(http.MethodGet, "https://oapi.dingtalk.com/gettoken?"+params.Encode(), nil)
|
||||
return reqs
|
||||
},
|
||||
}), err
|
||||
})
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// TestConnection 测试连接
|
||||
@@ -162,6 +199,11 @@ func (e *SystemIntegratedService) TestConnection(integrate gaia.SystemIntegratio
|
||||
global.GVA_LOG.Warn("第三方邮箱API配置验证失败", zap.Error(err))
|
||||
// 不阻止保存,只记录警告
|
||||
}
|
||||
// 验证转发集成配置
|
||||
if err := e.ValidateForwardConfig(integrate); err != nil {
|
||||
global.GVA_LOG.Warn("转发集成配置验证失败", zap.Error(err))
|
||||
// 不阻止保存,只记录警告
|
||||
}
|
||||
return nil
|
||||
case gaia.SystemIntegrationOAuth2:
|
||||
// 测试OAuth2连接
|
||||
@@ -171,79 +213,307 @@ func (e *SystemIntegratedService) TestConnection(integrate gaia.SystemIntegratio
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateEmailApiConfig 验证第三方邮箱API配置
|
||||
// @Tags System Integrated
|
||||
// @Summary 验证第三方邮箱API配置
|
||||
// @param: integrate gaia.SystemIntegration
|
||||
// @return: error
|
||||
func (e *SystemIntegratedService) ValidateEmailApiConfig(integrate gaia.SystemIntegration) error {
|
||||
// 解析Config字段
|
||||
if integrate.Config == "" {
|
||||
return nil // 配置为空不算错误
|
||||
}
|
||||
|
||||
// ParseDingTalkConfig 解析钉钉集成配置,自动处理新旧格式兼容
|
||||
func (e *SystemIntegratedService) ParseDingTalkConfig(configJSON string) (request.DingTalkConfigRequest, error) {
|
||||
var configMap request.DingTalkConfigRequest
|
||||
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
|
||||
return fmt.Errorf("解析配置失败: %s", err.Error())
|
||||
if configJSON == "" {
|
||||
return configMap, nil
|
||||
}
|
||||
|
||||
// 检查是否启用邮箱API
|
||||
if !configMap.EmailApi.Enabled {
|
||||
return nil // 未启用不需要验证
|
||||
// 先解析顶层结构
|
||||
var raw struct {
|
||||
EmailApi json.RawMessage `json:"email_api"`
|
||||
ForwardConfig request.ForwardConfig `json:"forward_config"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(configJSON), &raw); err != nil {
|
||||
return configMap, fmt.Errorf("解析钉钉配置失败: %s", err.Error())
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if configMap.EmailApi.URL == "" {
|
||||
configMap.ForwardConfig = raw.ForwardConfig
|
||||
|
||||
if raw.EmailApi != nil {
|
||||
cfg, err := parseEmailApiConfigFromJSON(raw.EmailApi)
|
||||
if err != nil {
|
||||
return configMap, err
|
||||
}
|
||||
configMap.EmailApi = cfg
|
||||
}
|
||||
|
||||
return configMap, nil
|
||||
}
|
||||
|
||||
// oldBodyDataCompat 旧格式 BodyData(兼容解析用)
|
||||
type oldBodyDataCompat struct {
|
||||
FormData []map[string]string `json:"form_data"`
|
||||
Urlencoded []map[string]string `json:"urlencoded"`
|
||||
Raw string `json:"raw"`
|
||||
}
|
||||
|
||||
// isNewEmailApiConfig 检测配置是否使用新格式(通过 params 字段判断)
|
||||
func isNewEmailApiConfig(config request.EmailApiConfig) bool {
|
||||
return config.Params != nil
|
||||
}
|
||||
|
||||
// convertOldBodyDataToNew 将旧格式 BodyData([]map[string]string)转换为新格式([]BodyField)
|
||||
func convertOldBodyDataToNew(old oldBodyDataCompat) request.BodyData {
|
||||
newData := request.BodyData{Raw: old.Raw}
|
||||
for _, kv := range old.FormData {
|
||||
for k, v := range kv {
|
||||
newData.FormData = append(newData.FormData, request.BodyField{
|
||||
Key: k,
|
||||
ValueType: request.ValueTypeString,
|
||||
Value: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, kv := range old.Urlencoded {
|
||||
for k, v := range kv {
|
||||
newData.Urlencoded = append(newData.Urlencoded, request.BodyField{
|
||||
Key: k,
|
||||
ValueType: request.ValueTypeString,
|
||||
Value: v,
|
||||
})
|
||||
}
|
||||
}
|
||||
return newData
|
||||
}
|
||||
|
||||
// parseEmailApiConfigFromJSON 解析 EmailApiConfig,自动兼容新旧格式
|
||||
// 新格式:包含 params 字段
|
||||
// 旧格式:包含 request_param_field 字段,body_data 中使用 []map[string]string
|
||||
func parseEmailApiConfigFromJSON(raw json.RawMessage) (request.EmailApiConfig, error) {
|
||||
// 先尝试解析为新格式
|
||||
var newConfig request.EmailApiConfig
|
||||
if err := json.Unmarshal(raw, &newConfig); err != nil {
|
||||
return request.EmailApiConfig{}, fmt.Errorf("解析邮箱配置失败: %s", err.Error())
|
||||
}
|
||||
|
||||
// 如果有 params 字段,说明是新格式
|
||||
if isNewEmailApiConfig(newConfig) {
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
// 旧格式:尝试解析 body_data 中的 []map[string]string
|
||||
var oldCompat struct {
|
||||
request.EmailApiConfig
|
||||
BodyData oldBodyDataCompat `json:"body_data"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &oldCompat); err == nil {
|
||||
newConfig.BodyData = convertOldBodyDataToNew(oldCompat.BodyData)
|
||||
global.GVA_LOG.Info("邮箱配置:检测到旧格式,已在内存中转换为新格式",
|
||||
zap.String("request_param_field", newConfig.RequestParamField))
|
||||
}
|
||||
|
||||
return newConfig, nil
|
||||
}
|
||||
|
||||
// validateEmailApiConfigFields 验证 EmailApiConfig 字段
|
||||
func validateEmailApiConfigFields(cfg request.EmailApiConfig) error {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
if cfg.URL == "" {
|
||||
return errors.New("邮箱API URL不能为空")
|
||||
}
|
||||
|
||||
if configMap.EmailApi.Method == "" {
|
||||
configMap.EmailApi.Method = "GET"
|
||||
if cfg.Method == "" {
|
||||
cfg.Method = "GET"
|
||||
}
|
||||
|
||||
if configMap.EmailApi.RequestParamField == "" {
|
||||
return errors.New("邮箱请求字段不能为空")
|
||||
}
|
||||
|
||||
if configMap.EmailApi.ResponseEmailField == "" {
|
||||
return errors.New("邮箱信息提取字段不能为空")
|
||||
}
|
||||
|
||||
// 验证Body类型(仅POST/PUT/DELETE需要)
|
||||
if configMap.EmailApi.Method != "GET" {
|
||||
bodyType := strings.ToLower(configMap.EmailApi.BodyType)
|
||||
if bodyType == "" {
|
||||
configMap.EmailApi.BodyType = "raw" // 默认raw
|
||||
} else if bodyType != "form-data" && bodyType != "x-www-form-urlencoded" && bodyType != "raw" {
|
||||
return fmt.Errorf("不支持的Body类型: %s,支持的类型: form-data, x-www-form-urlencoded, raw", bodyType)
|
||||
// 新格式不强制要求 request_param_field(通过 params 配置 URL 查询参数)
|
||||
if isNewEmailApiConfig(cfg) {
|
||||
// Params 只支持 string 和 ding_id 两种类型
|
||||
for i, p := range cfg.Params {
|
||||
if err := validateParamValueType(p.ValueType); err != nil {
|
||||
return fmt.Errorf("第%d个 URL 参数类型无效:%s", i+1, err.Error())
|
||||
}
|
||||
}
|
||||
// Body fields 支持 string、int、bool、ding_id
|
||||
for i, f := range cfg.BodyData.FormData {
|
||||
if err := validateValueType(f.ValueType); err != nil {
|
||||
return fmt.Errorf("form-data 第%d个字段类型无效:%s", i+1, err.Error())
|
||||
}
|
||||
if f.ValueType == request.ValueTypeInt && f.Value != "" {
|
||||
if _, err := strconv.ParseInt(f.Value, 10, 64); err != nil {
|
||||
return fmt.Errorf("form-data 第%d个字段(%s)的值不是有效整数:%s", i+1, f.Key, f.Value)
|
||||
}
|
||||
}
|
||||
if f.ValueType == request.ValueTypeBool && f.Value != "" {
|
||||
if _, err := strconv.ParseBool(f.Value); err != nil {
|
||||
return fmt.Errorf("form-data 第%d个字段(%s)的值不是有效布尔值:%s", i+1, f.Key, f.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
for i, f := range cfg.BodyData.Urlencoded {
|
||||
if err := validateValueType(f.ValueType); err != nil {
|
||||
return fmt.Errorf("urlencoded 第%d个字段类型无效:%s", i+1, err.Error())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 旧格式兼容:request_param_field 不能为空
|
||||
if cfg.RequestParamField == "" {
|
||||
return errors.New("邮箱请求字段不能为空")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证Authorization配置
|
||||
authType := strings.ToLower(configMap.EmailApi.Authorization.Type)
|
||||
if cfg.ResponseEmailField == "" {
|
||||
return errors.New("邮箱信息提取字段不能为空")
|
||||
}
|
||||
|
||||
if cfg.Method != "GET" {
|
||||
bodyType := strings.ToLower(cfg.BodyType)
|
||||
if bodyType != "" && bodyType != "form-data" && bodyType != "x-www-form-urlencoded" && bodyType != "raw" {
|
||||
return fmt.Errorf("不支持的Body类型: %s,支持的类型: form-data, x-www-form-urlencoded, raw", cfg.BodyType)
|
||||
}
|
||||
}
|
||||
|
||||
authType := strings.ToLower(cfg.Authorization.Type)
|
||||
if authType != "" && authType != "none" {
|
||||
if authType == "bearer" {
|
||||
if configMap.EmailApi.Authorization.Token == "" {
|
||||
if cfg.Authorization.Token == "" {
|
||||
return errors.New("Bearer Token不能为空")
|
||||
}
|
||||
} else if authType == "basic" {
|
||||
if configMap.EmailApi.Authorization.Username == "" || configMap.EmailApi.Authorization.Password == "" {
|
||||
if cfg.Authorization.Username == "" || cfg.Authorization.Password == "" {
|
||||
return errors.New("Basic Auth需要填写Username和Password")
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("不支持的Authorization类型: %s,支持的类型: none, bearer, basic", authType)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateValueType 验证 Body 字段的 ValueType 是否合法(支持全部四种)
|
||||
func validateValueType(vt string) error {
|
||||
switch vt {
|
||||
case "", request.ValueTypeString, request.ValueTypeInt, request.ValueTypeBool, request.ValueTypeDingID:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("不支持的值类型: %s,支持的类型: string, int, bool, ding_id", vt)
|
||||
}
|
||||
}
|
||||
|
||||
// validateParamValueType 验证 URL Params 的 ValueType 是否合法(只支持 string 和 ding_id)
|
||||
func validateParamValueType(vt string) error {
|
||||
switch vt {
|
||||
case "", request.ValueTypeString, request.ValueTypeDingID:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("URL 参数不支持的值类型: %s,支持的类型: string, ding_id", vt)
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateEmailApiConfig 验证第三方邮箱API配置
|
||||
// @Tags System Integrated
|
||||
// @Summary 验证第三方邮箱API配置
|
||||
// @param: integrate gaia.SystemIntegration
|
||||
// @return: error
|
||||
func (e *SystemIntegratedService) ValidateEmailApiConfig(integrate gaia.SystemIntegration) error {
|
||||
if integrate.Config == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var configMap struct {
|
||||
EmailApi json.RawMessage `json:"email_api"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
|
||||
return fmt.Errorf("解析配置失败: %s", err.Error())
|
||||
}
|
||||
if configMap.EmailApi == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cfg, err := parseEmailApiConfigFromJSON(configMap.EmailApi)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = validateEmailApiConfigFields(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("第三方邮箱API配置验证通过",
|
||||
zap.String("url", configMap.EmailApi.URL),
|
||||
zap.String("method", configMap.EmailApi.Method),
|
||||
zap.String("body_type", configMap.EmailApi.BodyType),
|
||||
zap.String("auth_type", configMap.EmailApi.Authorization.Type))
|
||||
zap.String("url", cfg.URL),
|
||||
zap.String("method", cfg.Method),
|
||||
zap.String("body_type", cfg.BodyType),
|
||||
zap.String("auth_type", cfg.Authorization.Type),
|
||||
zap.Bool("new_format", isNewEmailApiConfig(cfg)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestEmailApiConfig 测试邮箱 API 配置,返回详细的响应结果用于调试
|
||||
func (e *SystemIntegratedService) TestEmailApiConfig(cfg request.EmailApiConfig, testDingID string) (*gaiaResp.TestEmailApiConfigResponse, error) {
|
||||
if err := validateEmailApiConfigFields(cfg); err != nil {
|
||||
return &gaiaResp.TestEmailApiConfigResponse{
|
||||
IsValid: false,
|
||||
ErrorMessage: "配置验证失败:" + err.Error(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
dingId := strings.TrimSpace(testDingID)
|
||||
if dingId == "" {
|
||||
return &gaiaResp.TestEmailApiConfigResponse{
|
||||
IsValid: false,
|
||||
ErrorMessage: "测试钉钉 ID 不能为空,请先在弹窗中填写一个真实的 ding_id",
|
||||
}, nil
|
||||
}
|
||||
|
||||
respBody, statusCode, reqErr := e.doEmailApiRequest(dingId, cfg)
|
||||
|
||||
result := &gaiaResp.TestEmailApiConfigResponse{
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
|
||||
// 尝试解析响应 Body 为 JSON
|
||||
var bodyJSON interface{}
|
||||
if json.Unmarshal(respBody, &bodyJSON) == nil {
|
||||
result.Body = bodyJSON
|
||||
} else {
|
||||
result.Body = string(respBody)
|
||||
}
|
||||
|
||||
if reqErr != nil {
|
||||
result.IsValid = false
|
||||
result.ErrorMessage = reqErr.Error()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 尝试提取邮箱字段
|
||||
if cfg.ResponseEmailField != "" {
|
||||
if bodyMap, ok := bodyJSON.(map[string]interface{}); ok {
|
||||
email := extractJSONPathAdvanced(bodyMap, cfg.ResponseEmailField)
|
||||
if email != "" {
|
||||
result.EmailFieldPreview = cfg.ResponseEmailField + " = " + email
|
||||
result.IsValid = true
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.ErrorMessage = "未找到邮箱字段:" + cfg.ResponseEmailField
|
||||
}
|
||||
} else if bodySlice, ok := bodyJSON.([]interface{}); ok {
|
||||
email := extractJSONPathAdvanced(bodySlice, cfg.ResponseEmailField)
|
||||
if email != "" {
|
||||
result.EmailFieldPreview = cfg.ResponseEmailField + " = " + email
|
||||
result.IsValid = true
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.ErrorMessage = "未找到邮箱字段:" + cfg.ResponseEmailField
|
||||
}
|
||||
} else {
|
||||
result.IsValid = statusCode >= 200 && statusCode < 300
|
||||
}
|
||||
} else {
|
||||
result.IsValid = statusCode >= 200 && statusCode < 300
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("测试邮箱 API 配置",
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.String("email_preview", result.EmailFieldPreview),
|
||||
zap.Bool("is_valid", result.IsValid))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TestOAuth2Connection 测试OAuth2连接
|
||||
// @Tags System Integrated
|
||||
// @Summary 测试OAuth2连接
|
||||
@@ -325,3 +595,412 @@ func (e *SystemIntegratedService) TestOAuth2Connection(integrate gaia.SystemInte
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateForwardConfig 验证转发集成配置
|
||||
// @Tags System Integrated
|
||||
// @Summary 验证转发集成配置
|
||||
// @param: integrate gaia.SystemIntegration
|
||||
// @return: error
|
||||
func (e *SystemIntegratedService) ValidateForwardConfig(integrate gaia.SystemIntegration) error {
|
||||
// 解析 Config 字段
|
||||
if integrate.Config == "" {
|
||||
return nil // 配置为空不算错误
|
||||
}
|
||||
|
||||
var configMap request.DingTalkConfigRequest
|
||||
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
|
||||
return fmt.Errorf("解析配置失败:%s", err.Error())
|
||||
}
|
||||
|
||||
// 若未配置转发 Token,则认为未使用转发能力,不强制校验
|
||||
if len(configMap.ForwardConfig.Tokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用转发能力的前置条件:至少 1 个 Token + 启用并配置「第三方邮箱配置」
|
||||
if !configMap.EmailApi.Enabled || strings.TrimSpace(configMap.EmailApi.URL) == "" {
|
||||
return errors.New("使用转发能力前请先启用并配置「第三方邮箱配置」")
|
||||
}
|
||||
|
||||
// 验证 Token 数量
|
||||
if len(configMap.ForwardConfig.Tokens) > 20 {
|
||||
return errors.New("转发 Token 最多 20 个")
|
||||
}
|
||||
|
||||
// 验证每个 Token 的必填字段
|
||||
for i, token := range configMap.ForwardConfig.Tokens {
|
||||
if token.ID == "" {
|
||||
return fmt.Errorf("第%d个 Token 的 ID 不能为空", i+1)
|
||||
}
|
||||
if token.TokenHash == "" {
|
||||
return fmt.Errorf("第%d个 Token 的 TokenHash 不能为空", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
global.GVA_LOG.Info("转发集成配置验证通过",
|
||||
zap.Int("token_count", len(configMap.ForwardConfig.Tokens)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateDingIdApiConfig 验证第三方钉钉 ID 匹配 API 配置
|
||||
// @Tags System Integrated
|
||||
// @Summary 验证第三方钉钉 ID 匹配 API 配置
|
||||
// @param: integrate gaia.SystemIntegration
|
||||
// @return: error
|
||||
// extractJSONPath 按点分路径从 JSON 对象中提取字符串值,支持 "data.username" 等多层路径
|
||||
func extractJSONPath(data map[string]interface{}, path string) string {
|
||||
parts := strings.SplitN(path, ".", 2)
|
||||
val, ok := data[parts[0]]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if len(parts) == 1 {
|
||||
if s, ok := val.(string); ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
if nested, ok := val.(map[string]interface{}); ok {
|
||||
return extractJSONPath(nested, parts[1])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveParamValue 根据 ValueType 解析参数值,ding_id 类型替换为实际的钉钉 ID
|
||||
func resolveParamValue(vt, value, dingId string) string {
|
||||
switch vt {
|
||||
case request.ValueTypeDingID:
|
||||
return dingId
|
||||
default:
|
||||
// 兼容旧格式的 {{ding_id}} 占位符和新格式的 $<{[ding_id]}> 标记
|
||||
v := strings.ReplaceAll(value, "{{ding_id}}", dingId)
|
||||
v = strings.ReplaceAll(v, request.DingIDMarker, dingId)
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// buildBodyFields 将 []BodyField 按类型转换,构建 url.Values(用于 form-data 和 urlencoded)
|
||||
func buildBodyFields(fields []request.BodyField, dingId string) url.Values {
|
||||
form := url.Values{}
|
||||
for _, f := range fields {
|
||||
if f.Key == "" {
|
||||
continue
|
||||
}
|
||||
val := resolveParamValue(f.ValueType, f.Value, dingId)
|
||||
form.Set(f.Key, val)
|
||||
}
|
||||
return form
|
||||
}
|
||||
|
||||
// buildURL 根据新格式 Params 或旧格式 RequestParamField 构建带查询参数的 URL
|
||||
func buildURL(baseURL string, config request.EmailApiConfig, dingId string) string {
|
||||
if isNewEmailApiConfig(config) {
|
||||
// 新格式:遍历 Params 列表自动拼接
|
||||
params := url.Values{}
|
||||
for _, p := range config.Params {
|
||||
if p.Key == "" {
|
||||
continue
|
||||
}
|
||||
params.Set(p.Key, resolveParamValue(p.ValueType, p.Value, dingId))
|
||||
}
|
||||
if len(params) == 0 {
|
||||
return baseURL
|
||||
}
|
||||
sep := "?"
|
||||
if strings.Contains(baseURL, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
return baseURL + sep + params.Encode()
|
||||
}
|
||||
|
||||
// 旧格式:RequestParamField 字段名 + dingId 作为值
|
||||
if config.RequestParamField != "" {
|
||||
sep := "?"
|
||||
if strings.Contains(baseURL, "?") {
|
||||
sep = "&"
|
||||
}
|
||||
return baseURL + sep + url.QueryEscape(config.RequestParamField) + "=" + url.QueryEscape(dingId)
|
||||
}
|
||||
return baseURL
|
||||
}
|
||||
|
||||
// callEmailApi 调用第三方邮箱 API,使用 ding_id(用户名) 获取邮箱
|
||||
func (e *SystemIntegratedService) callEmailApi(
|
||||
dingId string, config request.EmailApiConfig) (mailList []string, err error) {
|
||||
// init
|
||||
respBody, _, err := e.doEmailApiRequest(dingId, config)
|
||||
if err != nil {
|
||||
return mailList, err
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err = json.Unmarshal(respBody, &respJSON); err != nil {
|
||||
return mailList, fmt.Errorf("解析响应 JSON 失败:%s", err.Error())
|
||||
}
|
||||
|
||||
email := extractJSONPathAdvanced(respJSON, config.ResponseEmailField)
|
||||
if email == "" {
|
||||
return mailList, fmt.Errorf("响应中未找到邮箱(路径:%s)", config.ResponseEmailField)
|
||||
}
|
||||
//
|
||||
mailList = append(mailList, email)
|
||||
parts := strings.Split(email, "@")
|
||||
defaultMail := os.Getenv(gaia.EmailDomainEnv)
|
||||
if len(defaultMail) > 1 && len(parts) > 1 && len(parts[0]) > 0 {
|
||||
mailList = append(mailList, parts[0]+"@"+defaultMail)
|
||||
}
|
||||
|
||||
return mailList, nil
|
||||
}
|
||||
|
||||
// doEmailApiRequest 构建并执行邮箱 API 请求,返回响应体字节和状态码
|
||||
func (e *SystemIntegratedService) doEmailApiRequest(dingId string, config request.EmailApiConfig) ([]byte, int, error) {
|
||||
method := strings.ToUpper(config.Method)
|
||||
if method == "" {
|
||||
method = "GET"
|
||||
}
|
||||
|
||||
var bodyReader io.Reader
|
||||
var contentType string
|
||||
if method != "GET" && method != "HEAD" {
|
||||
switch strings.ToLower(config.BodyType) {
|
||||
case "raw":
|
||||
raw := strings.ReplaceAll(config.BodyData.Raw, "{{ding_id}}", dingId)
|
||||
raw = strings.ReplaceAll(raw, request.DingIDMarker, dingId)
|
||||
bodyReader = strings.NewReader(raw)
|
||||
contentType = "application/json"
|
||||
case "form-data":
|
||||
form := buildBodyFields(config.BodyData.FormData, dingId)
|
||||
// 也处理旧格式的 urlencoded 字段(兼容)
|
||||
if len(form) == 0 {
|
||||
form = buildBodyFields(config.BodyData.Urlencoded, dingId)
|
||||
}
|
||||
bodyReader = strings.NewReader(form.Encode())
|
||||
contentType = "multipart/form-data"
|
||||
case "x-www-form-urlencoded":
|
||||
form := buildBodyFields(config.BodyData.Urlencoded, dingId)
|
||||
if len(form) == 0 {
|
||||
form = buildBodyFields(config.BodyData.FormData, dingId)
|
||||
}
|
||||
bodyReader = strings.NewReader(form.Encode())
|
||||
contentType = "application/x-www-form-urlencoded"
|
||||
}
|
||||
}
|
||||
|
||||
apiURL := buildURL(config.URL, config, dingId)
|
||||
|
||||
req, err := http.NewRequest(method, apiURL, bodyReader)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("构建请求失败:%s", err.Error())
|
||||
}
|
||||
|
||||
// 设置 Content-Type(如果 Headers 未覆盖)
|
||||
if contentType != "" {
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
}
|
||||
|
||||
// 设置 Headers(可覆盖 Content-Type)
|
||||
for k, v := range config.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// 设置 Authorization
|
||||
authType := strings.ToLower(config.Authorization.Type)
|
||||
switch authType {
|
||||
case "bearer":
|
||||
req.Header.Set("Authorization", "Bearer "+config.Authorization.Token)
|
||||
case "basic":
|
||||
req.SetBasicAuth(config.Authorization.Username, config.Authorization.Password)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("请求失败:%s", err.Error())
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, resp.StatusCode, fmt.Errorf("读取响应失败:%s", err.Error())
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return respBody, resp.StatusCode, fmt.Errorf("第三方 API 返回错误状态码:%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return respBody, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// extractJSONPathAdvanced 支持点分路径和数组索引(如 data[0].userName)
|
||||
func extractJSONPathAdvanced(data interface{}, path string) string {
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
parts := splitJSONPath(path)
|
||||
current := data
|
||||
for _, part := range parts {
|
||||
switch v := current.(type) {
|
||||
case map[string]interface{}:
|
||||
current = v[part.key]
|
||||
case []interface{}:
|
||||
if part.isIndex && part.index >= 0 && part.index < len(v) {
|
||||
current = v[part.index]
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
if current == nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
switch v := current.(type) {
|
||||
case string:
|
||||
return v
|
||||
case float64:
|
||||
return strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case bool:
|
||||
return strconv.FormatBool(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
type jsonPathPart struct {
|
||||
key string
|
||||
isIndex bool
|
||||
index int
|
||||
}
|
||||
|
||||
// splitJSONPath 将路径字符串分割为结构化的部分列表,支持 data[0].userName 格式
|
||||
func splitJSONPath(path string) []jsonPathPart {
|
||||
var parts []jsonPathPart
|
||||
// 先按点分割
|
||||
segments := strings.Split(path, ".")
|
||||
for _, seg := range segments {
|
||||
seg = strings.TrimSpace(seg)
|
||||
if seg == "" {
|
||||
continue
|
||||
}
|
||||
// 检查是否包含数组索引 data[0]
|
||||
if idx := strings.Index(seg, "["); idx != -1 {
|
||||
key := seg[:idx]
|
||||
rest := seg[idx:]
|
||||
if key != "" {
|
||||
parts = append(parts, jsonPathPart{key: key})
|
||||
}
|
||||
// 解析所有 [N] 部分
|
||||
for len(rest) > 0 && rest[0] == '[' {
|
||||
end := strings.Index(rest, "]")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
idxStr := rest[1:end]
|
||||
rest = rest[end+1:]
|
||||
if n, err := strconv.Atoi(idxStr); err == nil {
|
||||
parts = append(parts, jsonPathPart{isIndex: true, index: n})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
parts = append(parts, jsonPathPart{key: seg})
|
||||
}
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// ParseForwardToken 从 Bearer Token 中验签并提取 ding_id
|
||||
// 遍历 tokens 列表,找到签名匹配的条目
|
||||
func (e *SystemIntegratedService) ParseForwardToken(
|
||||
rawToken string, tokens []request.ForwardToken) (dingId string, err error) {
|
||||
parts := strings.SplitN(rawToken, ".", 2)
|
||||
if len(parts) != 2 {
|
||||
return "", errors.New("token 格式非法")
|
||||
}
|
||||
payloadB64, sigB64 := parts[0], parts[1]
|
||||
sigBytes, err := base64.RawURLEncoding.DecodeString(sigB64)
|
||||
if err != nil {
|
||||
return "", errors.New("签名解码失败")
|
||||
}
|
||||
for _, t := range tokens {
|
||||
if t.TokenSecret == "" {
|
||||
continue
|
||||
}
|
||||
// 直接使用原始字节作为 HMAC 密钥,兼容任意字符串格式的密钥
|
||||
secret := []byte(t.TokenSecret)
|
||||
// 验证 HMAC 签名
|
||||
mac := hmac.New(sha256.New, secret)
|
||||
mac.Write([]byte(payloadB64))
|
||||
expected := mac.Sum(nil)
|
||||
if !hmac.Equal(expected, sigBytes) {
|
||||
continue // 不匹配,试下一个
|
||||
}
|
||||
// 签名验证通过,解析 payload
|
||||
payloadBytes, err := base64.RawURLEncoding.DecodeString(payloadB64)
|
||||
if err != nil {
|
||||
return "", errors.New("payload 解码失败")
|
||||
}
|
||||
var payload struct {
|
||||
DingID string `json:"ding_id"`
|
||||
}
|
||||
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
||||
return "", errors.New("payload 解析失败")
|
||||
}
|
||||
if payload.DingID == "" {
|
||||
return "", errors.New("token 中缺少 ding_id")
|
||||
}
|
||||
return payload.DingID, nil
|
||||
}
|
||||
return "", errors.New("无效的转发 Token")
|
||||
}
|
||||
|
||||
// ResolveAccountByDingId 通过钉钉 ID 解析 gaia account_id
|
||||
// 解析顺序:Redis 缓存 → AccountDingTalkExtend 本地表 → 第三方 EmailApi(邮箱 API)
|
||||
func (e *SystemIntegratedService) ResolveAccountByDingId(
|
||||
dingId string, apiConfig request.EmailApiConfig) (string, error) {
|
||||
|
||||
// 1. 查 Redis 缓存
|
||||
ctx := context.Background()
|
||||
redisKey := gaia.RedisKeyGaiaForwardDingPrefix + dingId
|
||||
if cached, err := global.GVA_REDIS.Get(ctx, redisKey).Result(); err == nil && cached != "" {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// 2. 查本地 AccountDingTalkExtend 表
|
||||
var extend gaia.AccountDingTalkExtend
|
||||
if err := global.GVA_DB.Where("ding_talk = ?", dingId).First(&extend).Error; err == nil {
|
||||
accountID := extend.ID.String()
|
||||
global.GVA_REDIS.Set(ctx, redisKey, accountID, 24*time.Hour)
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
// 3. 第三方邮箱 API(若配置)
|
||||
if !apiConfig.Enabled || strings.TrimSpace(apiConfig.URL) == "" {
|
||||
return "", fmt.Errorf("未找到 ding_id=%s 对应的用户,且未配置第三方邮箱 API", dingId)
|
||||
}
|
||||
|
||||
email, err := e.callEmailApi(dingId, apiConfig)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("调用第三方邮箱 API 失败:%s", err.Error())
|
||||
}
|
||||
|
||||
// 4. 按邮箱查 accounts 表(匹配 email 字段)
|
||||
var account gaia.Account
|
||||
if err = global.GVA_DB.Where("email = ?", email).First(&account).Error; err != nil {
|
||||
return "", fmt.Errorf("用户 %s 不存在(来自第三方邮箱 API)", email)
|
||||
}
|
||||
|
||||
accountID := account.ID.String()
|
||||
|
||||
// 5. 写回 AccountDingTalkExtend,方便下次本地命中
|
||||
global.GVA_DB.Create(&gaia.AccountDingTalkExtend{
|
||||
ID: account.ID,
|
||||
DingTalk: dingId,
|
||||
})
|
||||
|
||||
// 6. 写 Redis 缓存
|
||||
global.GVA_REDIS.Set(ctx, redisKey, accountID, 24*time.Hour)
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
@@ -15,13 +15,6 @@ import (
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
)
|
||||
|
||||
// UserWorkerAllocation 用户工作器分配信息
|
||||
type UserWorkerAllocation struct {
|
||||
UserID uint `json:"user_id"`
|
||||
Workers int `json:"workers"`
|
||||
MaxLimit int `json:"max_limit"`
|
||||
}
|
||||
|
||||
// 全局工作池实例
|
||||
var globalWorkerPool *WorkerPool
|
||||
|
||||
@@ -29,10 +22,10 @@ var globalWorkerPool *WorkerPool
|
||||
type WorkerPool struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
totalWorkers int // 总工作器数量
|
||||
userWorkers map[uint]*UserWorkerAllocation // 每个用户的工作器分配
|
||||
userTaskChan map[uint]chan *gaia.BatchWorkflowTask // 每个用户的任务队列
|
||||
runningWorkers map[uint]int // 每个用户当前运行的worker数量
|
||||
totalWorkers int // 总工作器数量
|
||||
userWorkers map[uint]*gaia.UserWorkerAllocation // 每个用户的工作器分配
|
||||
userTaskChan map[uint]chan *gaia.BatchWorkflowTask // 每个用户的任务队列
|
||||
runningWorkers map[uint]int // 每个用户当前运行的worker数量
|
||||
wg sync.WaitGroup
|
||||
batchService *BatchWorkflowService
|
||||
running bool
|
||||
@@ -47,7 +40,7 @@ func NewWorkerPool(totalWorkers int) *WorkerPool {
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
totalWorkers: totalWorkers,
|
||||
userWorkers: make(map[uint]*UserWorkerAllocation),
|
||||
userWorkers: make(map[uint]*gaia.UserWorkerAllocation),
|
||||
userTaskChan: make(map[uint]chan *gaia.BatchWorkflowTask),
|
||||
runningWorkers: make(map[uint]int),
|
||||
batchService: &BatchWorkflowService{},
|
||||
@@ -101,15 +94,13 @@ func (wp *WorkerPool) calculateUserWorkerAllocation() {
|
||||
return
|
||||
}
|
||||
|
||||
// 第二个查询:获取这些用户的所有批量工作流的累计错误次数(不限状态)
|
||||
type UserErrorInfo struct {
|
||||
UserID uint `json:"user_id"`
|
||||
ErrorCount int `json:"error_count"`
|
||||
}
|
||||
var userErrorInfos []UserErrorInfo
|
||||
// 第二个查询:获取这些用户的累计错误次数和最小 total_rows(pending 状态工作流)
|
||||
var userErrorInfos []gaia.UserErrorInfo
|
||||
if len(activeUserIDs) > 0 {
|
||||
err = global.GVA_DB.Raw(`
|
||||
SELECT bw.user_id, COALESCE(SUM(bw.error_count), 0) as error_count
|
||||
SELECT bw.user_id,
|
||||
COALESCE(SUM(bw.error_count), 0) AS error_count,
|
||||
COALESCE(SUM(bw.total_rows), 0) AS total_rows
|
||||
FROM batch_workflows_extend bw
|
||||
WHERE bw.user_id IN (?) AND bw.status='pending'
|
||||
GROUP BY bw.user_id
|
||||
@@ -121,10 +112,12 @@ func (wp *WorkerPool) calculateUserWorkerAllocation() {
|
||||
}
|
||||
}
|
||||
|
||||
// 提取活跃用户ID列表和错误次数映射
|
||||
// 提取活跃用户ID列表、错误次数映射和 total_rows 映射
|
||||
userErrorMap := make(map[uint]int)
|
||||
userTotalRowsMap := make(map[uint]int)
|
||||
for _, info := range userErrorInfos {
|
||||
userErrorMap[info.UserID] = info.ErrorCount
|
||||
userTotalRowsMap[info.UserID] = info.TotalRows
|
||||
}
|
||||
|
||||
userCount := len(activeUserIDs)
|
||||
@@ -133,7 +126,7 @@ func (wp *WorkerPool) calculateUserWorkerAllocation() {
|
||||
for _, ch := range wp.userTaskChan {
|
||||
close(ch)
|
||||
}
|
||||
wp.userWorkers = make(map[uint]*UserWorkerAllocation)
|
||||
wp.userWorkers = make(map[uint]*gaia.UserWorkerAllocation)
|
||||
wp.userTaskChan = make(map[uint]chan *gaia.BatchWorkflowTask)
|
||||
wp.runningWorkers = make(map[uint]int)
|
||||
return
|
||||
@@ -155,226 +148,142 @@ func (wp *WorkerPool) calculateUserWorkerAllocation() {
|
||||
}
|
||||
}
|
||||
|
||||
// 检查用户数量是否超过了最大支持数量(每用户最少1个工作器)
|
||||
maxSupportedUsers := wp.totalWorkers / 1
|
||||
|
||||
// 存储新的分配计算结果
|
||||
newAllocations := make(map[uint]*UserWorkerAllocation)
|
||||
newAllocations := make(map[uint]*gaia.UserWorkerAllocation)
|
||||
|
||||
if userCount <= maxSupportedUsers {
|
||||
// 用户数量在可支持范围内,采用两阶段分配策略
|
||||
baseAllocation := wp.totalWorkers / userCount
|
||||
remainder := wp.totalWorkers % userCount
|
||||
|
||||
// 第一阶段:计算每个用户的基础分配和错误惩罚后的实际分配
|
||||
type UserAllocationInfo struct {
|
||||
UserID uint
|
||||
BaseWorkers int
|
||||
ActualWorkers int
|
||||
ErrorCount int
|
||||
PenaltyReduced int
|
||||
// 计算每个用户的优先级权重(total_rows 越小权重越大)
|
||||
// 权重:tier1=4, tier2=3, tier3=2, tier4=1
|
||||
priorityWeightOf := func(totalRows int) int {
|
||||
switch {
|
||||
case totalRows <= gaia.PriorityTier1MaxRows:
|
||||
return 4
|
||||
case totalRows <= gaia.PriorityTier2MaxRows:
|
||||
return 3
|
||||
case totalRows <= gaia.PriorityTier3MaxRows:
|
||||
return 2
|
||||
default:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
var userAllocations []UserAllocationInfo
|
||||
totalPenaltyReduced := 0
|
||||
// 计算所有用户的权重总和,用于按比例分配 worker
|
||||
totalWeight := 0
|
||||
for _, userID := range activeUserIDs {
|
||||
totalWeight += priorityWeightOf(userTotalRowsMap[userID])
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
totalWeight = userCount
|
||||
}
|
||||
|
||||
for i, userID := range activeUserIDs {
|
||||
baseWorkers := baseAllocation
|
||||
// 处理余数,前几个用户多分配一个
|
||||
if i < remainder {
|
||||
baseWorkers++
|
||||
}
|
||||
// UserAllocationInfo 单用户分配中间计算结构
|
||||
type UserAllocationInfo struct {
|
||||
UserID uint
|
||||
BaseWorkers int
|
||||
ActualWorkers int
|
||||
ErrorCount int
|
||||
PenaltyReduced int
|
||||
}
|
||||
|
||||
// 确保每个用户至少有1个并发位
|
||||
if baseWorkers < 1 {
|
||||
baseWorkers = 1
|
||||
}
|
||||
|
||||
// 应用错误惩罚:根据用户的累计错误次数减少并发位
|
||||
errorCount := userErrorMap[userID]
|
||||
actualWorkers := wp.calculateWorkerCountWithErrorPenalty(baseWorkers, errorCount)
|
||||
penaltyReduced := baseWorkers - actualWorkers
|
||||
totalPenaltyReduced += penaltyReduced
|
||||
userAllocations = append(userAllocations, UserAllocationInfo{
|
||||
UserID: userID,
|
||||
BaseWorkers: baseWorkers,
|
||||
ActualWorkers: actualWorkers,
|
||||
ErrorCount: errorCount,
|
||||
PenaltyReduced: penaltyReduced,
|
||||
})
|
||||
// 按权重比例为每个用户分配基础 worker 数,余数补给权重最高的用户
|
||||
userAllocations := make([]UserAllocationInfo, 0, userCount)
|
||||
allocatedBase := 0
|
||||
for _, userID := range activeUserIDs {
|
||||
w := priorityWeightOf(userTotalRowsMap[userID])
|
||||
base := wp.totalWorkers * w / totalWeight
|
||||
if base < 1 {
|
||||
base = 1
|
||||
}
|
||||
userAllocations = append(userAllocations, UserAllocationInfo{
|
||||
UserID: userID,
|
||||
BaseWorkers: base,
|
||||
ErrorCount: userErrorMap[userID],
|
||||
})
|
||||
allocatedBase += base
|
||||
}
|
||||
// 将剩余 worker 补给权重最大的用户(已按 activeUserIDs 顺序,这里找最大权重的)
|
||||
remainder := wp.totalWorkers - allocatedBase
|
||||
if remainder > 0 {
|
||||
// 找权重最大的用户索引
|
||||
maxW, maxIdx := 0, 0
|
||||
for i, alloc := range userAllocations {
|
||||
w := priorityWeightOf(userTotalRowsMap[alloc.UserID])
|
||||
if w > maxW {
|
||||
maxW, maxIdx = w, i
|
||||
}
|
||||
}
|
||||
userAllocations[maxIdx].BaseWorkers += remainder
|
||||
}
|
||||
|
||||
// 第二阶段:将空出来的并发位重新分配给错误较少的用户
|
||||
if totalPenaltyReduced > 0 {
|
||||
// 按错误数量排序,错误少的用户优先获得额外分配
|
||||
for i := 0; i < len(userAllocations)-1; i++ {
|
||||
for j := i + 1; j < len(userAllocations); j++ {
|
||||
if userAllocations[i].ErrorCount > userAllocations[j].ErrorCount {
|
||||
userAllocations[i], userAllocations[j] = userAllocations[j], userAllocations[i]
|
||||
// 应用错误惩罚并收集被释放的 worker
|
||||
totalPenaltyReduced := 0
|
||||
for i := range userAllocations {
|
||||
actual := wp.calculateWorkerCountWithErrorPenalty(
|
||||
userAllocations[i].BaseWorkers, userAllocations[i].ErrorCount)
|
||||
userAllocations[i].PenaltyReduced = userAllocations[i].BaseWorkers - actual
|
||||
userAllocations[i].ActualWorkers = actual
|
||||
totalPenaltyReduced += userAllocations[i].PenaltyReduced
|
||||
}
|
||||
|
||||
// 将惩罚释放的 worker 重新分配给无惩罚用户(按权重优先)
|
||||
if totalPenaltyReduced > 0 {
|
||||
// 按权重降序、错误数升序排序,优先分配给高优先级无惩罚用户
|
||||
for i := 0; i < len(userAllocations)-1; i++ {
|
||||
for j := i + 1; j < len(userAllocations); j++ {
|
||||
wi := priorityWeightOf(userTotalRowsMap[userAllocations[i].UserID])
|
||||
wj := priorityWeightOf(userTotalRowsMap[userAllocations[j].UserID])
|
||||
if wi < wj || (wi == wj && userAllocations[i].ErrorCount > userAllocations[j].ErrorCount) {
|
||||
userAllocations[i], userAllocations[j] = userAllocations[j], userAllocations[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
remainingToDistribute := totalPenaltyReduced
|
||||
eligibleUsers := 0
|
||||
for _, a := range userAllocations {
|
||||
if a.PenaltyReduced == 0 {
|
||||
eligibleUsers++
|
||||
}
|
||||
}
|
||||
if eligibleUsers > 0 {
|
||||
for i := 0; i < len(userAllocations) && remainingToDistribute > 0; i++ {
|
||||
if userAllocations[i].PenaltyReduced == 0 {
|
||||
extra := remainingToDistribute / eligibleUsers
|
||||
if extra < 1 {
|
||||
extra = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 只为没有被惩罚的用户(PenaltyReduced = 0)重新分配空闲的并发位
|
||||
// 被惩罚的用户不应该获得额外分配
|
||||
remainingToDistribute := totalPenaltyReduced
|
||||
eligibleUsers := 0
|
||||
|
||||
// 计算有资格获得额外分配的用户数量(没有被惩罚的用户)
|
||||
for _, allocation := range userAllocations {
|
||||
if allocation.PenaltyReduced == 0 {
|
||||
eligibleUsers++
|
||||
}
|
||||
}
|
||||
|
||||
if eligibleUsers > 0 {
|
||||
// 只为没有被惩罚的用户分配额外的并发位
|
||||
for i := 0; i < len(userAllocations) && remainingToDistribute > 0; i++ {
|
||||
if userAllocations[i].PenaltyReduced == 0 {
|
||||
// 为没有错误惩罚的用户分配额外的并发位
|
||||
extraWorkers := remainingToDistribute / eligibleUsers
|
||||
if extraWorkers < 1 {
|
||||
extraWorkers = 1
|
||||
}
|
||||
if extraWorkers > remainingToDistribute {
|
||||
extraWorkers = remainingToDistribute
|
||||
}
|
||||
|
||||
userAllocations[i].ActualWorkers += extraWorkers
|
||||
remainingToDistribute -= extraWorkers
|
||||
eligibleUsers--
|
||||
if extra > remainingToDistribute {
|
||||
extra = remainingToDistribute
|
||||
}
|
||||
userAllocations[i].ActualWorkers += extra
|
||||
remainingToDistribute -= extra
|
||||
eligibleUsers--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建最终分配结果
|
||||
totalFinalWorkers := 0
|
||||
for _, allocation := range userAllocations {
|
||||
newAllocations[allocation.UserID] = &UserWorkerAllocation{
|
||||
// 写入最终分配,超出总量时截断(降级场景)
|
||||
allocatedWorkers := 0
|
||||
for _, allocation := range userAllocations {
|
||||
workers := allocation.ActualWorkers
|
||||
remaining := wp.totalWorkers - allocatedWorkers
|
||||
if workers > remaining {
|
||||
workers = remaining
|
||||
}
|
||||
if workers > 0 {
|
||||
newAllocations[allocation.UserID] = &gaia.UserWorkerAllocation{
|
||||
UserID: allocation.UserID,
|
||||
Workers: allocation.ActualWorkers,
|
||||
Workers: workers,
|
||||
MaxLimit: wp.totalWorkers,
|
||||
}
|
||||
totalFinalWorkers += allocation.ActualWorkers
|
||||
allocatedWorkers += workers
|
||||
}
|
||||
} else {
|
||||
// 用户数量超过最大支持数量,采用降级分配策略(两阶段分配)
|
||||
baseAllocation := wp.totalWorkers / userCount
|
||||
remainder := wp.totalWorkers % userCount
|
||||
|
||||
// 第一阶段:计算每个用户的基础分配和错误惩罚后的实际分配
|
||||
type UserAllocationInfo struct {
|
||||
UserID uint
|
||||
BaseWorkers int
|
||||
ActualWorkers int
|
||||
ErrorCount int
|
||||
PenaltyReduced int
|
||||
if allocatedWorkers >= wp.totalWorkers {
|
||||
break
|
||||
}
|
||||
|
||||
var userAllocations []UserAllocationInfo
|
||||
totalPenaltyReduced := 0
|
||||
|
||||
for i, userID := range activeUserIDs {
|
||||
baseWorkers := baseAllocation
|
||||
// 处理余数,前几个用户多分配一个
|
||||
if i < remainder {
|
||||
baseWorkers++
|
||||
}
|
||||
|
||||
// 确保至少分配1个工作器
|
||||
if baseWorkers < 1 {
|
||||
baseWorkers = 1
|
||||
}
|
||||
|
||||
// 应用错误惩罚:根据用户的累计错误次数减少并发位
|
||||
errorCount := userErrorMap[userID]
|
||||
actualWorkers := wp.calculateWorkerCountWithErrorPenalty(baseWorkers, errorCount)
|
||||
penaltyReduced := baseWorkers - actualWorkers
|
||||
totalPenaltyReduced += penaltyReduced
|
||||
|
||||
// 添加详细的错误惩罚计算调试日志
|
||||
userAllocations = append(userAllocations, UserAllocationInfo{
|
||||
UserID: userID,
|
||||
BaseWorkers: baseWorkers,
|
||||
ActualWorkers: actualWorkers,
|
||||
ErrorCount: errorCount,
|
||||
PenaltyReduced: penaltyReduced,
|
||||
})
|
||||
}
|
||||
|
||||
// 第二阶段:将空出来的并发位重新分配给错误较少的用户
|
||||
if totalPenaltyReduced > 0 {
|
||||
// 按错误数量排序,错误少的用户优先获得额外分配
|
||||
for i := 0; i < len(userAllocations)-1; i++ {
|
||||
for j := i + 1; j < len(userAllocations); j++ {
|
||||
if userAllocations[i].ErrorCount > userAllocations[j].ErrorCount {
|
||||
userAllocations[i], userAllocations[j] = userAllocations[j], userAllocations[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 只为没有被惩罚的用户(PenaltyReduced = 0)重新分配空闲的并发位
|
||||
// 被惩罚的用户不应该获得额外分配
|
||||
remainingToDistribute := totalPenaltyReduced
|
||||
eligibleUsers := 0
|
||||
|
||||
// 计算有资格获得额外分配的用户数量(没有被惩罚的用户)
|
||||
for _, allocation := range userAllocations {
|
||||
if allocation.PenaltyReduced == 0 {
|
||||
eligibleUsers++
|
||||
}
|
||||
}
|
||||
|
||||
if eligibleUsers > 0 {
|
||||
// 只为没有被惩罚的用户分配额外的并发位
|
||||
for i := 0; i < len(userAllocations) && remainingToDistribute > 0; i++ {
|
||||
if userAllocations[i].PenaltyReduced == 0 {
|
||||
// 为没有错误惩罚的用户分配额外的并发位
|
||||
extraWorkers := remainingToDistribute / eligibleUsers
|
||||
if extraWorkers < 1 {
|
||||
extraWorkers = 1
|
||||
}
|
||||
if extraWorkers > remainingToDistribute {
|
||||
extraWorkers = remainingToDistribute
|
||||
}
|
||||
|
||||
userAllocations[i].ActualWorkers += extraWorkers
|
||||
remainingToDistribute -= extraWorkers
|
||||
eligibleUsers--
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建最终分配结果,确保不超过总工作器数量
|
||||
allocatedWorkers := 0
|
||||
for _, allocation := range userAllocations {
|
||||
workers := allocation.ActualWorkers
|
||||
|
||||
// 确保不会超过剩余的工作器数量
|
||||
remainingWorkers := wp.totalWorkers - allocatedWorkers
|
||||
if workers > remainingWorkers {
|
||||
workers = remainingWorkers
|
||||
}
|
||||
|
||||
if workers > 0 {
|
||||
newAllocations[allocation.UserID] = &UserWorkerAllocation{
|
||||
UserID: allocation.UserID,
|
||||
Workers: workers,
|
||||
MaxLimit: wp.totalWorkers,
|
||||
}
|
||||
allocatedWorkers += workers
|
||||
}
|
||||
|
||||
// 如果工作器已经分配完毕,剩余用户分配0个工作器
|
||||
if allocatedWorkers >= wp.totalWorkers {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
global.GVA_LOG.Warn(fmt.Sprintf("降级分配完成 - 总工作器: %d, 用户数: %d, 已分配: %d, 重新分配: %d, 平均每用户: %.1f个",
|
||||
wp.totalWorkers, userCount, allocatedWorkers, totalPenaltyReduced, float64(allocatedWorkers)/float64(userCount)))
|
||||
}
|
||||
if userCount > wp.totalWorkers {
|
||||
global.GVA_LOG.Warn(fmt.Sprintf("降级分配完成 - 总工作器: %d, 用户数: %d, 已分配: %d, 平均每用户: %.1f个",
|
||||
wp.totalWorkers, userCount, allocatedWorkers, float64(allocatedWorkers)/float64(userCount)))
|
||||
}
|
||||
|
||||
// 应用新的分配,只更新有变化的用户
|
||||
|
||||
@@ -171,8 +171,10 @@ func (casbinService *CasbinService) AddPolicies(db *gorm.DB, rules [][]string) e
|
||||
|
||||
func (CasbinService *CasbinService) FreshCasbin() (err error) {
|
||||
e := CasbinService.Casbin()
|
||||
err = e.LoadPolicy()
|
||||
return err
|
||||
if e == nil {
|
||||
return errors.New("casbin enforcer is nil, please check database initialization")
|
||||
}
|
||||
return e.LoadPolicy()
|
||||
}
|
||||
|
||||
//@author: [piexlmax](https://github.com/piexlmax)
|
||||
@@ -187,6 +189,10 @@ var (
|
||||
|
||||
func (casbinService *CasbinService) Casbin() *casbin.SyncedCachedEnforcer {
|
||||
once.Do(func() {
|
||||
if global.GVA_DB == nil {
|
||||
zap.L().Warn("Casbin initialization skipped: global.GVA_DB is nil")
|
||||
return
|
||||
}
|
||||
a, err := gormadapter.NewAdapterByDB(global.GVA_DB)
|
||||
if err != nil {
|
||||
zap.L().Error("适配数据库失败请检查casbin表是否为InnoDB引擎!", zap.Error(err))
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
modelSystem "github.com/flipped-aurora/gin-vue-admin/server/model/system"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/system/request"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"sort"
|
||||
)
|
||||
@@ -86,6 +88,22 @@ func RegisterInit(order int, i SubInitializer) {
|
||||
|
||||
type InitDBService struct{}
|
||||
|
||||
// IfInit 判断是否数据库初始化了
|
||||
func (initDBService *InitDBService) IfInit() (init bool) {
|
||||
var menuCount, authorityCount int64
|
||||
if global.GVA_DB != nil {
|
||||
init = true
|
||||
var menu modelSystem.SysBaseMenuBtn
|
||||
var authority modelSystem.SysAuthority
|
||||
global.GVA_DB.Model(&menu).Count(&menuCount)
|
||||
global.GVA_DB.Model(&authority).Count(&authorityCount)
|
||||
if menuCount <= 1 && authorityCount <= 1 {
|
||||
init = false
|
||||
}
|
||||
}
|
||||
return init
|
||||
}
|
||||
|
||||
// InitDB 创建数据库并初始化 总入口
|
||||
func (initDBService *InitDBService) InitDB(conf request.InitDB) (err error) {
|
||||
ctx := context.TODO()
|
||||
@@ -122,7 +140,9 @@ func (initDBService *InitDBService) InitDB(conf request.InitDB) (err error) {
|
||||
|
||||
db := ctx.Value("db").(*gorm.DB)
|
||||
global.GVA_DB = db
|
||||
|
||||
db.Exec("DELETE FROM sys_base_menus")
|
||||
db.Exec("DELETE FROM sys_authorities")
|
||||
db.Exec("DELETE FROM sys_user_authority")
|
||||
if err = initHandler.InitTables(ctx, initializers); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -133,6 +153,10 @@ func (initDBService *InitDBService) InitDB(conf request.InitDB) (err error) {
|
||||
if err = initHandler.WriteConfig(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
// 初始化完成后刷新 Casbin 策略,避免使用旧的或空的策略
|
||||
if err = CasbinServiceApp.FreshCasbin(); err != nil {
|
||||
global.GVA_LOG.Warn("refresh casbin policy after InitDB failed", zap.Error(err))
|
||||
}
|
||||
initializers = initSlice{}
|
||||
cache = map[string]*orderedInitializer{}
|
||||
return nil
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/zap"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/config"
|
||||
@@ -35,10 +36,10 @@ func (h PgsqlInitHandler) WriteConfig(ctx context.Context) error {
|
||||
|
||||
// 改成拿dify的配置,如果不是docker运行,则从dify api的.env文件中获取jwt的加密key
|
||||
if !global.GVA_CONFIG.System.DockerRun {
|
||||
var err error
|
||||
global.GVA_CONFIG.JWT.SigningKey, err = h.GetJwtSigningKeyFormDifyApiEnv()
|
||||
if err != nil {
|
||||
return err
|
||||
if secretKey, err := h.GetJwtSigningKeyFormDifyApiEnv(); err != nil {
|
||||
global.GVA_LOG.Warn("failed to load JWT signing key from Dify API .env, using existing configuration", zap.Error(err))
|
||||
} else if secretKey != "" {
|
||||
global.GVA_CONFIG.JWT.SigningKey = secretKey
|
||||
}
|
||||
}
|
||||
cs := utils.StructToMap(global.GVA_CONFIG)
|
||||
|
||||
@@ -201,6 +201,9 @@ func (i *initApi) InitializeData(ctx context.Context) (context.Context, error) {
|
||||
// Extend Start: system integration
|
||||
{ApiGroup: "应用集成配置", Method: "GET", Path: "/gaia/system/dingtalk", Description: "获取钉钉系统配置"},
|
||||
{ApiGroup: "应用集成配置", Method: "POST", Path: "/gaia/system/dingtalk", Description: "设置钉钉系统配置"},
|
||||
{ApiGroup: "应用集成配置", Method: "POST", Path: "/gaia/system/dingtalk/test-email-config", Description: "测试钉钉邮箱API配置"},
|
||||
{ApiGroup: "应用集成配置", Method: "GET", Path: "/gaia/system/dingtalk/test-auth-url", Description: "测试连接-获取钉钉授权URL"},
|
||||
{ApiGroup: "应用集成配置", Method: "POST", Path: "/gaia/system/dingtalk/test-callback", Description: "测试连接-钉钉回调验证"},
|
||||
// Extend Stop: system integration
|
||||
|
||||
// Extend Start: oauth2
|
||||
@@ -244,6 +247,12 @@ func (i *initApi) InitializeData(ctx context.Context) (context.Context, error) {
|
||||
{ApiGroup: "模型管理", Method: "PATCH", Path: "/gaia/proxy/*", Description: "中转API(第三方)-PATCH"},
|
||||
{ApiGroup: "模型管理", Method: "DELETE", Path: "/gaia/proxy/*", Description: "中转API(第三方)-DELETE"},
|
||||
// Extend Stop: model provider
|
||||
|
||||
// Extend Start: 转发集成 (forward tokens)
|
||||
{ApiGroup: "转发集成", Method: "GET", Path: "/gaia/system/forward-tokens", Description: "获取转发 Token 列表"},
|
||||
{ApiGroup: "转发集成", Method: "POST", Path: "/gaia/system/forward-tokens", Description: "新增转发 Token"},
|
||||
{ApiGroup: "转发集成", Method: "DELETE", Path: "/gaia/system/forward-tokens/:id", Description: "删除转发 Token"},
|
||||
// Extend Stop: 转发集成
|
||||
}
|
||||
if err := db.Create(&entities).Error; err != nil {
|
||||
return ctx, errors.Wrap(err, sysModel.SysApi{}.TableName()+"表数据初始化失败!")
|
||||
|
||||
@@ -287,6 +287,9 @@ func (i *initCasbin) InitializeData(ctx context.Context) (context.Context, error
|
||||
// Extend Start: system integration
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk", V2: "GET"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk", V2: "POST"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk/test-email-config", V2: "POST"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk/test-auth-url", V2: "GET"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/system/dingtalk/test-callback", V2: "POST"},
|
||||
// Extend Stop: system integration
|
||||
|
||||
// Extend Start: oauth2
|
||||
@@ -403,6 +406,15 @@ func (i *initCasbin) InitializeData(ctx context.Context) (context.Context, error
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/proxy/*", V2: "PATCH"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/proxy/*", V2: "DELETE"},
|
||||
// Extend Stop: model provider
|
||||
|
||||
// Extend Start: 转发集成 (forward tokens)
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/system/forward-tokens", V2: "GET"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/system/forward-tokens", V2: "POST"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/system/forward-tokens/:id", V2: "DELETE"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/system/forward-tokens", V2: "GET"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/system/forward-tokens", V2: "POST"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/system/forward-tokens/:id", V2: "DELETE"},
|
||||
// Extend Stop: 转发集成
|
||||
}
|
||||
if err := db.Create(&entities).Error; err != nil {
|
||||
return ctx, errors.Wrap(err, "Casbin 表 ("+i.InitializerName()+") 数据初始化失败!")
|
||||
|
||||
@@ -41,7 +41,7 @@ export const getSystemOAuth2 = () => {
|
||||
}
|
||||
|
||||
// @Tags systrm
|
||||
// @Summary 修改OAuth2集成配置
|
||||
// @Summary 修改 OAuth2 集成配置
|
||||
// @Security ApiKeyAuth
|
||||
// @Produce application/json
|
||||
// @Success 200 {string} string "{"success":true,"data":{},"msg":"返回成功"}"
|
||||
@@ -53,3 +53,67 @@ export const setSystemOAuth2 = (data) => {
|
||||
data,
|
||||
})
|
||||
}
|
||||
|
||||
// @Tags systrm
|
||||
// @Summary 获取转发 Token 列表
|
||||
// @Security ApiKeyAuth
|
||||
// @Router /gaia/system/forward-tokens [get]
|
||||
export const getForwardTokens = () => {
|
||||
return service({
|
||||
url: '/gaia/system/forward-tokens',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
// @Tags systrm
|
||||
// @Summary 新增转发 Token
|
||||
// @Security ApiKeyAuth
|
||||
// @Router /gaia/system/forward-tokens [post]
|
||||
export const createForwardToken = (data) => {
|
||||
return service({
|
||||
url: '/gaia/system/forward-tokens',
|
||||
method: 'post',
|
||||
data,
|
||||
})
|
||||
}
|
||||
|
||||
// @Tags systrm
|
||||
// @Summary 删除转发 Token
|
||||
// @Security ApiKeyAuth
|
||||
// @Router /gaia/system/forward-tokens/:seq [delete]
|
||||
export const deleteForwardToken = (seq, password) => {
|
||||
return service({
|
||||
url: `/gaia/system/forward-tokens/${seq}`,
|
||||
method: 'delete',
|
||||
data: { password },
|
||||
})
|
||||
}
|
||||
|
||||
// @Tags systrm
|
||||
// @Summary 测试第三方邮箱 API 配置
|
||||
// @Security ApiKeyAuth
|
||||
// @Router /gaia/system/dingtalk/test-email-config [post]
|
||||
export const testEmailApiConfig = (data) => {
|
||||
return service({
|
||||
url: '/gaia/system/dingtalk/test-email-config',
|
||||
method: 'post',
|
||||
data,
|
||||
})
|
||||
}
|
||||
|
||||
// 测试连接:获取钉钉授权 URL(打开后扫码完成即视为连接成功)
|
||||
export const getDingTalkTestAuthUrl = () => {
|
||||
return service({
|
||||
url: '/gaia/system/dingtalk/test-auth-url',
|
||||
method: 'get',
|
||||
})
|
||||
}
|
||||
|
||||
// 测试连接回调:仅用 code 验证,不登录
|
||||
export const dingtalkTestCallback = (data) => {
|
||||
return service({
|
||||
url: '/gaia/system/dingtalk/test-callback',
|
||||
method: 'post',
|
||||
data: { code: data.code },
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import { useUserStore } from '@/pinia/modules/user'
|
||||
import { useRouterStore } from '@/pinia/modules/router'
|
||||
import router from '@/router'
|
||||
import { gaiaOAuth2Login, dingtalkLogin } from '@/api/user_extend'
|
||||
import { dingtalkTestCallback } from '@/api/gaia/system'
|
||||
|
||||
defineOptions({
|
||||
name: 'LoginCallback',
|
||||
@@ -70,6 +71,19 @@ const callback = async () => {
|
||||
const redirectUri = sessionStorage.getItem('gaia_login_redirect_uri') || ''
|
||||
const state = sessionStorage.getItem('gaia_login_state') || getQueryParam('state') || ''
|
||||
|
||||
// 测试连接回调:仅验证 code 换 token,不登录,结果通过 postMessage 回传并关闭
|
||||
if (provider === 'dingtalk' && state === 'dingtalk_test') {
|
||||
try {
|
||||
const res = await dingtalkTestCallback({ code })
|
||||
const payload = { type: 'dingtalk_test_result', success: res?.code === 0, message: res?.msg }
|
||||
if (window.opener) window.opener.postMessage(payload, '*')
|
||||
} catch (e) {
|
||||
if (window.opener) window.opener.postMessage({ type: 'dingtalk_test_result', success: false, message: e?.message || '验证失败' }, '*')
|
||||
}
|
||||
window.close()
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
if (provider === 'dingtalk') {
|
||||
if (!hasCode) {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -146,7 +146,13 @@ const providerList = ref([])
|
||||
const providerDisplayNames = {
|
||||
openai: 'OpenAI',
|
||||
tongyi: '千问(通义)',
|
||||
google: 'Google Gemini'
|
||||
google: 'Google Gemini',
|
||||
anthropic: 'Anthropic',
|
||||
aws: 'AWS Bedrock',
|
||||
azure: 'Azure OpenAI',
|
||||
zhipuai: '智谱 AI',
|
||||
minimax: 'MiniMax',
|
||||
deepseek: 'DeepSeek'
|
||||
}
|
||||
|
||||
const getProviderDisplayName = (providerName) => {
|
||||
@@ -395,7 +401,7 @@ onMounted(() => {
|
||||
:deep(.el-select-dropdown) {
|
||||
.el-select-dropdown__item {
|
||||
padding: 8px 16px;
|
||||
|
||||
|
||||
&.is-selected {
|
||||
font-weight: 600;
|
||||
color: #409eff;
|
||||
|
||||
@@ -6,7 +6,6 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ExtendInfo(BaseSettings):
|
||||
|
||||
OAUTH2_CLIENT_ID: Optional[str] = Field(
|
||||
description="OA client id for OAuth",
|
||||
default=None,
|
||||
@@ -74,6 +73,16 @@ class ExtendInfo(BaseSettings):
|
||||
)
|
||||
# Extend: 记忆上下文功能
|
||||
|
||||
# Extend: 控制台首次安装完成后,向内网 Admin 服务触发 /init/initdb(与 DB_* 同源,无需额外密钥文件)
|
||||
ADMIN_INITDB_ENABLED: bool = Field(
|
||||
description="Dify 安装向导完成后是否请求 Admin 初始化业务库表",
|
||||
default=False,
|
||||
)
|
||||
ADMIN_INITDB_URL: str = Field(
|
||||
description="Admin InitDB 接口完整 URL(Docker 默认可为 http://admin-server:8888/admin/api/init/initdb)",
|
||||
default="http://admin-server:8888/admin/api/init/initdb",
|
||||
)
|
||||
|
||||
|
||||
class ExtendConfig(ExtendInfo):
|
||||
pass
|
||||
|
||||
@@ -170,7 +170,8 @@ class BaseApiKeyListResource(Resource):
|
||||
@marshal_with(api_key_fields)
|
||||
def put(self, resource_id):
|
||||
resource_id = str(resource_id)
|
||||
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
_get_resource(resource_id, current_tenant_id, self.resource_model)
|
||||
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
@@ -198,7 +199,7 @@ class BaseApiKeyListResource(Resource):
|
||||
)
|
||||
|
||||
if key is None:
|
||||
flask_restful.abort(404, message="API密钥未找到")
|
||||
flask_restx.abort(404, message="API密钥未找到")
|
||||
|
||||
data = request.get_json()
|
||||
|
||||
@@ -227,7 +228,7 @@ class BaseApiKeyListResource(Resource):
|
||||
merged_data = {**api_token.__dict__, **api_token_money_extend.__dict__}
|
||||
return merged_data, 200
|
||||
else:
|
||||
flask_restful.abort(500, message="更新API密钥时发生错误")
|
||||
flask_restx.abort(500, message="更新API密钥时发生错误")
|
||||
# --------------------- 二开部分End - 密钥额度限制 ---------------------
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
class StatisticTimeRangeQuery(BaseModel):
|
||||
start: str | None = Field(default=None, description="Start date (YYYY-MM-DD HH:MM)")
|
||||
end: str | None = Field(default=None, description="End date (YYYY-MM-DD HH:MM)")
|
||||
account: str | None = Field(default=None, description="Account ID filter")
|
||||
|
||||
@field_validator("start", "end", mode="before")
|
||||
@classmethod
|
||||
@@ -114,14 +115,8 @@ class DailyConversationStatistic(Resource):
|
||||
args = StatisticTimeRangeQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
converted_created_at = convert_datetime_to_date("created_at")
|
||||
sql_query = f"""SELECT
|
||||
{converted_created_at} AS date,
|
||||
COUNT(DISTINCT conversation_id) AS conversation_count
|
||||
FROM
|
||||
messages
|
||||
WHERE
|
||||
app_id = :app_id
|
||||
AND invoke_from != :invoke_from"""
|
||||
sql_query = f"""SELECT {converted_created_at} AS date, COUNT(DISTINCT conversation_id) AS conversation_count
|
||||
FROM messages WHERE app_id = :app_id AND invoke_from != :invoke_from"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER}
|
||||
assert account.timezone is not None
|
||||
|
||||
@@ -131,33 +126,10 @@ WHERE
|
||||
abort(400, description=str(e))
|
||||
|
||||
if args.account is not None and args.account:
|
||||
sql_query += ""
|
||||
# stmt = (
|
||||
# select(
|
||||
# func.date(
|
||||
# func.date_trunc("day", text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
|
||||
# ).label("date"),
|
||||
# func.count(distinct(Message.conversation_id)).label("conversation_count")
|
||||
# )
|
||||
# .select_from(Message)
|
||||
# .where(
|
||||
# Message.app_id == app_model.id,
|
||||
# or_(
|
||||
# Message.from_account_id == account.id,
|
||||
# Message.from_end_user_id.in_(
|
||||
# select(EndUser.id)
|
||||
# .where(EndUser.external_user_id == account.id)
|
||||
# .distinct()
|
||||
# )
|
||||
# )
|
||||
# )
|
||||
# .group_by(
|
||||
# func.date(
|
||||
# func.date_trunc("day", text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
|
||||
# )
|
||||
# )
|
||||
# .params(tz=account.timezone) # 绑定参数
|
||||
# )
|
||||
sql_query = f"""SELECT {converted_created_at} AS date, COUNT(DISTINCT conversation_id) AS conversation_count
|
||||
FROM messages WHERE app_id = :app_id AND invoke_from != :invoke_from AND (from_account_id = :user_id OR
|
||||
from_end_user_id IN (SELECT DISTINCT(id) FROM end_users WHERE external_user_id = :user_id))"""
|
||||
arg_dict = {"tz": account.timezone, "app_id": app_model.id, "invoke_from": InvokeFrom.DEBUGGER, "user_id": account.id}
|
||||
|
||||
if start_datetime_utc:
|
||||
sql_query += " AND created_at >= :start"
|
||||
|
||||
@@ -9,6 +9,7 @@ from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup, db
|
||||
from services.account_service import RegisterService, TenantService
|
||||
from services.admin_initdb_service import trigger_admin_initdb_if_configured
|
||||
|
||||
from .error import AlreadySetupError, NotInitValidateError
|
||||
from .init_validate import get_init_validate_status
|
||||
@@ -82,6 +83,8 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
language=payload.language,
|
||||
)
|
||||
|
||||
trigger_admin_initdb_if_configured(admin_password=payload.password)
|
||||
|
||||
return SetupResponse(result="success")
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,10 @@ def load_user_from_request(request_from_flask_login):
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
|
||||
if request.blueprint in {"console", "inner_api"}:
|
||||
# extend: start fastopenapi路由认证,/console/api/ 路径无blueprint时按console处理
|
||||
is_console_path = request.blueprint is None and request.path.startswith("/console/api/")
|
||||
if request.blueprint in {"console", "inner_api"} or is_console_path:
|
||||
# extend: stop fastopenapi路由认证,/console/api/ 路径无blueprint时按console处理
|
||||
if not auth_token:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
decoded = PassportService().verify(auth_token)
|
||||
|
||||
@@ -35,6 +35,7 @@ from models.account import (
|
||||
)
|
||||
from models.account_money_extend import AccountMoneyExtend
|
||||
from models.model import DifySetup
|
||||
from services.account_service_extend import TenantExtendService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import (
|
||||
AccountAlreadyInTenantError,
|
||||
@@ -1313,6 +1314,19 @@ class RegisterService:
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
|
||||
|
||||
# extend begin: admin 初始化不同步问题 - 在 dify 初始化完成后,自动调用 admin 初始化
|
||||
# 将 setup 用户加入到第一个工作区(admin 租户),确保后续用户信息同步时权限正确
|
||||
tenant_extend_service = TenantExtendService
|
||||
super_admin_id = tenant_extend_service.get_super_admin_id().id
|
||||
super_admin_tenant_id = tenant_extend_service.get_super_admin_tenant_id().id
|
||||
if super_admin_id and super_admin_tenant_id:
|
||||
is_create = TenantExtendService.create_default_tenant_member_if_not_exist(
|
||||
super_admin_tenant_id, account.id
|
||||
)
|
||||
if is_create:
|
||||
TenantService.switch_tenant(account, super_admin_tenant_id)
|
||||
# extend end: admin 初始化不同步问题
|
||||
|
||||
dify_setup = DifySetup(version=dify_config.project.version)
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
Trigger gin-vue-admin InitDB after Dify console setup.
|
||||
|
||||
Uses the same DB_* settings as the API container and the admin password from the setup form.
|
||||
Does not read or write local config files; outbound call only if ADMIN_INITDB_ENABLED is true.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ADMIN_ALREADY_INIT_MSG = "已存在数据库配置"
|
||||
_DEFAULT_TIMEOUT = httpx.Timeout(120.0, connect=10.0)
|
||||
|
||||
|
||||
def _admin_db_type(db_type: str) -> Literal["pgsql", "mysql"]:
|
||||
if db_type == "postgresql":
|
||||
return "pgsql"
|
||||
return "mysql"
|
||||
|
||||
|
||||
def _build_payload(admin_password: str) -> dict[str, Any]:
|
||||
db_type = _admin_db_type(dify_config.DB_TYPE)
|
||||
return {
|
||||
"adminPassword": admin_password,
|
||||
"dbType": db_type,
|
||||
"host": dify_config.DB_HOST,
|
||||
"port": str(dify_config.DB_PORT),
|
||||
"userName": dify_config.DB_USERNAME,
|
||||
"password": dify_config.DB_PASSWORD,
|
||||
"dbName": dify_config.DB_DATABASE,
|
||||
"dbPath": "",
|
||||
}
|
||||
|
||||
|
||||
def _is_acceptable_admin_response(body: dict[str, Any]) -> bool:
|
||||
code = body.get("code")
|
||||
msg = body.get("msg") or ""
|
||||
if code == 0:
|
||||
return True
|
||||
if msg == _ADMIN_ALREADY_INIT_MSG:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def trigger_admin_initdb_if_configured(*, admin_password: str) -> None:
|
||||
"""
|
||||
Best-effort POST to Admin InitDB. Logs warnings on failure; never raises.
|
||||
"""
|
||||
if not dify_config.ADMIN_INITDB_ENABLED:
|
||||
return
|
||||
if dify_config.DB_TYPE not in ("postgresql", "mysql", "oceanbase", "seekdb"):
|
||||
logger.warning(
|
||||
"skip admin initdb: unsupported DB_TYPE for admin bridge: %s",
|
||||
dify_config.DB_TYPE,
|
||||
)
|
||||
return
|
||||
|
||||
url = (dify_config.ADMIN_INITDB_URL or "").strip()
|
||||
if not url:
|
||||
logger.warning("ADMIN_INITDB_ENABLED is true but ADMIN_INITDB_URL is empty, skip admin initdb")
|
||||
return
|
||||
|
||||
payload = _build_payload(admin_password)
|
||||
try:
|
||||
with httpx.Client(timeout=_DEFAULT_TIMEOUT, follow_redirects=True) as client:
|
||||
response = client.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json", "Accept": "application/json"},
|
||||
)
|
||||
except httpx.RequestError as exc:
|
||||
logger.warning("admin initdb request failed: %s", exc)
|
||||
return
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
"admin initdb HTTP %s: %s",
|
||||
response.status_code,
|
||||
(response.text or "")[:500],
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
body = response.json()
|
||||
except ValueError:
|
||||
logger.warning("admin initdb returned non-JSON body: %s", (response.text or "")[:500])
|
||||
return
|
||||
|
||||
if not isinstance(body, dict):
|
||||
logger.warning("admin initdb returned unexpected JSON: %s", body)
|
||||
return
|
||||
|
||||
if _is_acceptable_admin_response(body):
|
||||
logger.info("admin initdb finished: %s", body.get("msg"))
|
||||
return
|
||||
|
||||
logger.warning("admin initdb rejected: %s", body)
|
||||
+4
-1
@@ -693,7 +693,7 @@ ALIBABACLOUD_MYSQL_MAX_CONNECTION=5
|
||||
ALIBABACLOUD_MYSQL_HNSW_M=6
|
||||
|
||||
# relyt configurations, only available when VECTOR_STORE is `relyt`
|
||||
RELYT_HOST=db
|
||||
RELYT_HOST=db_postgres
|
||||
RELYT_PORT=5432
|
||||
RELYT_USER=postgres
|
||||
RELYT_PASSWORD=difyai123456
|
||||
@@ -1283,8 +1283,11 @@ NGINX_SSL_PROTOCOLS=TLSv1.2 TLSv1.3
|
||||
NGINX_WORKER_PROCESSES=auto
|
||||
NGINX_CLIENT_MAX_BODY_SIZE=100M
|
||||
NGINX_KEEPALIVE_TIMEOUT=65
|
||||
NGINX_CLIENT_HEADER_TIMEOUT=1800s
|
||||
NGINX_CLIENT_BODY_TIMEOUT=1800s
|
||||
|
||||
# Proxy settings
|
||||
NGINX_PROXY_CONNECT_TIMEOUT=3600s
|
||||
NGINX_PROXY_READ_TIMEOUT=3600s
|
||||
NGINX_PROXY_SEND_TIMEOUT=3600s
|
||||
|
||||
|
||||
@@ -419,6 +419,9 @@ services:
|
||||
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
|
||||
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
|
||||
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
|
||||
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
|
||||
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
|
||||
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
|
||||
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
|
||||
|
||||
@@ -557,6 +557,9 @@ x-shared-env: &shared-api-worker-env
|
||||
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
|
||||
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
|
||||
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
|
||||
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
|
||||
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
|
||||
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
|
||||
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
|
||||
@@ -717,7 +720,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -734,6 +737,9 @@ services:
|
||||
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
|
||||
FULL_CODE_EXECUTION_ENDPOINT: ${FULL_CODE_EXECUTION_ENDPOINT:-http://sandbox-full:8194}
|
||||
ALLOW_REGISTER: ${ALLOW_REGISTER:-True}
|
||||
# 安装向导完成后由内网调用 Admin InitDB(库连接与 DB_* 一致,无需用户改本地配置文件)
|
||||
ADMIN_INITDB_ENABLED: ${ADMIN_INITDB_ENABLED:-true}
|
||||
ADMIN_INITDB_URL: ${ADMIN_INITDB_URL:-http://admin-server:8888/init/initdb}
|
||||
depends_on:
|
||||
init_permissions:
|
||||
condition: service_completed_successfully
|
||||
@@ -761,7 +767,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -800,7 +806,7 @@ services:
|
||||
# worker-gaia service
|
||||
# The Celery worker-gaia for processing the queue.
|
||||
worker-gaia:
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -826,7 +832,7 @@ services:
|
||||
# worker-dataset service
|
||||
# The Celery worker-dataset for processing the queue.
|
||||
worker-dataset:
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -852,7 +858,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-api:1.12.1.fix.3.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -882,7 +888,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-web:1.12.1
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-web:1.12.1.fix.3.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
@@ -943,8 +949,6 @@ services:
|
||||
interval: 1s
|
||||
timeout: 3s
|
||||
retries: 60
|
||||
ports:
|
||||
- 5432:5432
|
||||
|
||||
# The mysql database.
|
||||
db_mysql:
|
||||
@@ -993,8 +997,6 @@ services:
|
||||
"CMD-SHELL",
|
||||
"redis-cli -a ${REDIS_PASSWORD:-difyai123456} ping | grep -q PONG",
|
||||
]
|
||||
ports:
|
||||
- 6379:6379
|
||||
|
||||
# The DifySandbox
|
||||
sandbox:
|
||||
@@ -1171,6 +1173,9 @@ services:
|
||||
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
|
||||
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
|
||||
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
|
||||
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
|
||||
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
|
||||
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
|
||||
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
|
||||
@@ -1667,7 +1672,7 @@ services:
|
||||
|
||||
# Extend - admin-web
|
||||
admin-web:
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-web:1.12.1
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-web:1.12.1.fix.3.2
|
||||
restart: always
|
||||
ports:
|
||||
- '8081:8081'
|
||||
@@ -1679,13 +1684,22 @@ services:
|
||||
|
||||
# Extend - admin-server
|
||||
admin-server:
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-server:1.12.1
|
||||
image: ccr.ccs.tencentyun.com/yfgaia/dify-plus-admin-server:1.12.1.fix.3.3
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
<<: *shared-api-worker-env
|
||||
# JWT signing key must match API's SECRET_KEY for token compatibility
|
||||
JWT_SIGNING_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U}
|
||||
SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U}
|
||||
EMAIL_DOMAIN: ${EMAIL_DOMAIN:-}
|
||||
# 与 Dify 使用同一套数据库配置,避免密码不一致导致认证失败
|
||||
DB_TYPE: ${DB_TYPE:-postgresql}
|
||||
DB_HOST: ${DB_HOST:-db_postgres}
|
||||
DB_PORT: ${DB_PORT:-5432}
|
||||
DB_USERNAME: ${DB_USERNAME:-postgres}
|
||||
DB_PASSWORD: ${DB_PASSWORD:-difyai123456}
|
||||
DB_DATABASE: ${DB_DATABASE:-dify}
|
||||
ports:
|
||||
- '8888:8888'
|
||||
depends_on:
|
||||
|
||||
@@ -2,9 +2,6 @@ services:
|
||||
# The postgres database.
|
||||
db_postgres:
|
||||
image: postgres:15-alpine
|
||||
profiles:
|
||||
- ""
|
||||
- postgresql
|
||||
restart: always
|
||||
env_file:
|
||||
- ./middleware.env
|
||||
|
||||
@@ -554,6 +554,9 @@ x-shared-env: &shared-api-worker-env
|
||||
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
|
||||
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
|
||||
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
|
||||
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
|
||||
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
|
||||
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
|
||||
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
|
||||
@@ -1105,6 +1108,9 @@ services:
|
||||
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
|
||||
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
|
||||
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
|
||||
NGINX_CLIENT_HEADER_TIMEOUT: ${NGINX_CLIENT_HEADER_TIMEOUT:-1800s}
|
||||
NGINX_CLIENT_BODY_TIMEOUT: ${NGINX_CLIENT_BODY_TIMEOUT:-1800s}
|
||||
NGINX_PROXY_CONNECT_TIMEOUT: ${NGINX_PROXY_CONNECT_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
|
||||
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
|
||||
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
server {
|
||||
listen ${NGINX_PORT};
|
||||
server_name ${NGINX_SERVER_NAME};
|
||||
keepalive_timeout 1800s;
|
||||
client_header_timeout ${NGINX_CLIENT_HEADER_TIMEOUT};
|
||||
client_body_timeout ${NGINX_CLIENT_BODY_TIMEOUT};
|
||||
# 管理中心反向代理配置
|
||||
location = /admin {
|
||||
return 301 /admin/;
|
||||
|
||||
@@ -7,5 +7,6 @@ proxy_set_header X-Forwarded-Port $server_port;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Connection "";
|
||||
proxy_buffering off;
|
||||
proxy_connect_timeout ${NGINX_PROXY_CONNECT_TIMEOUT};
|
||||
proxy_read_timeout ${NGINX_PROXY_READ_TIMEOUT};
|
||||
proxy_send_timeout ${NGINX_PROXY_SEND_TIMEOUT};
|
||||
|
||||
@@ -153,4 +153,43 @@ INSERT INTO casbin_rule (ptype, v0, v1, v2) VALUES
|
||||
INSERT INTO sys_authority_menus (sys_authority_authority_id, sys_base_menu_id) VALUES (888, 42);
|
||||
|
||||
|
||||
-- --------------- 9. API sys_apis (转发集成:转发 Token 管理 3 条) 2026-03-09 18:08:33 ---------------
|
||||
-- 请按当前库最大 id 调整起始 id,避免冲突。例如 MAX(id)=269 则从 270 起
|
||||
INSERT INTO sys_apis (id, created_at, updated_at, deleted_at, path, description, api_group, method) VALUES
|
||||
(270, NOW(), NOW(), NULL, '/gaia/system/forward-tokens', '获取转发 Token 列表', '转发集成', 'GET'),
|
||||
(271, NOW(), NOW(), NULL, '/gaia/system/forward-tokens', '新增转发 Token', '转发集成', 'POST'),
|
||||
(272, NOW(), NOW(), NULL, '/gaia/system/forward-tokens/:id', '删除转发 Token', '转发集成', 'DELETE');
|
||||
|
||||
|
||||
-- --------------- 10. Casbin 规则 casbin_rule (转发集成 888/8881) ---------------
|
||||
INSERT INTO casbin_rule (ptype, v0, v1, v2) VALUES
|
||||
('p', '888', '/gaia/system/forward-tokens', 'GET'),
|
||||
('p', '888', '/gaia/system/forward-tokens', 'POST'),
|
||||
('p', '888', '/gaia/system/forward-tokens/:id', 'DELETE'),
|
||||
('p', '8881', '/gaia/system/forward-tokens', 'GET'),
|
||||
('p', '8881', '/gaia/system/forward-tokens', 'POST'),
|
||||
('p', '8881', '/gaia/system/forward-tokens/:id', 'DELETE');
|
||||
|
||||
|
||||
-- --------------- 11. API sys_apis (钉钉邮箱配置测试 1 条) ---------------
|
||||
-- 请按当前库最大 id 调整起始 id,避免冲突。例如 MAX(id)=272 则从 273 起
|
||||
INSERT INTO sys_apis (id, created_at, updated_at, deleted_at, path, description, api_group, method) VALUES
|
||||
(273, NOW(), NOW(), NULL, '/gaia/system/dingtalk/test-email-config', '测试钉钉邮箱API配置', '应用集成配置', 'POST');
|
||||
|
||||
|
||||
-- --------------- 12. Casbin 规则 casbin_rule (钉钉邮箱配置测试 888) ---------------
|
||||
INSERT INTO casbin_rule (ptype, v0, v1, v2) VALUES
|
||||
('p', '888', '/gaia/system/dingtalk/test-email-config', 'POST');
|
||||
|
||||
|
||||
-- --------------- 13. API sys_apis (钉钉测试连接:授权 URL + 回调验证 2 条) ---------------
|
||||
-- 请按当前库最大 id 调整起始 id,避免冲突。例如 MAX(id)=273 则从 274 起
|
||||
INSERT INTO sys_apis (id, created_at, updated_at, deleted_at, path, description, api_group, method) VALUES
|
||||
(274, NOW(), NOW(), NULL, '/gaia/system/dingtalk/test-auth-url', '测试连接-获取钉钉授权URL', '应用集成配置', 'GET'),
|
||||
(275, NOW(), NOW(), NULL, '/gaia/system/dingtalk/test-callback', '测试连接-钉钉回调验证', '应用集成配置', 'POST');
|
||||
|
||||
|
||||
-- --------------- 14. Casbin 规则 casbin_rule (钉钉测试连接 888) ---------------
|
||||
INSERT INTO casbin_rule (ptype, v0, v1, v2) VALUES
|
||||
('p', '888', '/gaia/system/dingtalk/test-auth-url', 'GET'),
|
||||
('p', '888', '/gaia/system/dingtalk/test-callback', 'POST');
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import ParamItem from '.'
|
||||
|
||||
@@ -27,7 +27,7 @@ const DayLimitItemExtend: FC<Props> = ({
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const handleParamChange = (key: string, value: number) => {
|
||||
let notOutRangeValue = parseFloat(value.toFixed(2))
|
||||
let notOutRangeValue = Number.parseFloat(value.toFixed(2))
|
||||
notOutRangeValue = Math.max(VALUE_LIMIT.min, notOutRangeValue)
|
||||
notOutRangeValue = Math.min(VALUE_LIMIT.max, notOutRangeValue)
|
||||
onChange(key, notOutRangeValue)
|
||||
@@ -36,8 +36,8 @@ const DayLimitItemExtend: FC<Props> = ({
|
||||
<ParamItem
|
||||
className={className}
|
||||
id={key}
|
||||
name={t('extend.apiKeyModal.dayLimitItemName')}
|
||||
tip={t('extend.apiKeyModal.noLimitTips') as string}
|
||||
name={t('apiKeyModal.dayLimitItemName', { ns: 'extend' })}
|
||||
tip={t('apiKeyModal.noLimitTips', { ns: 'extend' }) as string}
|
||||
{...VALUE_LIMIT}
|
||||
value={value}
|
||||
enable={enable}
|
||||
|
||||
@@ -27,7 +27,7 @@ const MonthLimitItemExtend: FC<Props> = ({
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const handleParamChange = (key: string, value: number) => {
|
||||
let notOutRangeValue = parseFloat(value.toFixed(2))
|
||||
let notOutRangeValue = Number.parseFloat(value.toFixed(2))
|
||||
notOutRangeValue = Math.max(VALUE_LIMIT.min, notOutRangeValue)
|
||||
notOutRangeValue = Math.min(VALUE_LIMIT.max, notOutRangeValue)
|
||||
onChange(key, notOutRangeValue)
|
||||
@@ -36,8 +36,8 @@ const MonthLimitItemExtend: FC<Props> = ({
|
||||
<ParamItem
|
||||
className={className}
|
||||
id={key}
|
||||
name={t('extend.apiKeyModal.monthLimitItemName')}
|
||||
tip={t('extend.apiKeyModal.noLimitTips') as string}
|
||||
name={t('apiKeyModal.monthLimitItemName', { ns: 'extend' })}
|
||||
tip={t('apiKeyModal.noLimitTips', { ns: 'extend' }) as string}
|
||||
{...VALUE_LIMIT}
|
||||
value={value}
|
||||
enable={enable}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
'use client'
|
||||
import type { CreateApiKeyResponse } from '@/models/app'
|
||||
import type { ApiKeyItemResponse, ApikeyItemResponseWithQuotaLimitExtend, CreateApiKeyResponse } from '@/models/app'
|
||||
import { PlusIcon, XMarkIcon } from '@heroicons/react/20/solid'
|
||||
import { RiDeleteBinLine } from '@remixicon/react'
|
||||
import {
|
||||
@@ -17,14 +17,18 @@ import useTimestamp from '@/hooks/use-timestamp'
|
||||
import {
|
||||
createApikey as createAppApikey,
|
||||
delApikey as delAppApikey,
|
||||
editApikey, // 二开部分 - 密钥额度
|
||||
} from '@/service/apps'
|
||||
import {
|
||||
createApikey as createDatasetApikey,
|
||||
delApikey as delDatasetApikey,
|
||||
} from '@/service/datasets'
|
||||
// 二开部分End - 密钥额度
|
||||
import { useDatasetApiKeys, useInvalidateDatasetApiKeys } from '@/service/knowledge/use-dataset'
|
||||
import { useAppApiKeys, useInvalidateAppApiKeys } from '@/service/use-apps'
|
||||
import SecretKeyGenerateModal from './secret-key-generate'
|
||||
// 二开部分Start - 密钥额度
|
||||
import SecretKeyQuotaSetExtendModal from './secret-key-quota-set-modal-extend'
|
||||
import s from './style.module.css'
|
||||
|
||||
type ISecretKeyModalProps = {
|
||||
@@ -50,6 +54,65 @@ const SecretKeyModal = ({
|
||||
const { data: datasetApiKeys, isLoading: isDatasetApiKeysLoading } = useDatasetApiKeys({ enabled: !appId && isShow })
|
||||
const apiKeysList = appId ? appApiKeys : datasetApiKeys
|
||||
const isApiKeysLoading = appId ? isAppApiKeysLoading : isDatasetApiKeysLoading
|
||||
// ---------------------- 二开部分Begin - 密钥额度 ----------------------
|
||||
const [isVisibleExtend, setVisibleExtend] = useState(false)
|
||||
const [keyItem, setKeyItem] = useState<ApikeyItemResponseWithQuotaLimitExtend>({
|
||||
created_at: '',
|
||||
id: '',
|
||||
last_used_at: '',
|
||||
token: '',
|
||||
description: '',
|
||||
day_limit_quota: -1,
|
||||
month_limit_quota: -1,
|
||||
})
|
||||
// 打开新增密钥额度编辑框
|
||||
const openSecretKeyQuotaSetModalExtend = async () => {
|
||||
setVisibleExtend(true)
|
||||
setKeyItem({
|
||||
created_at: '',
|
||||
id: '',
|
||||
last_used_at: '',
|
||||
token: '',
|
||||
description: '',
|
||||
day_limit_quota: -1,
|
||||
month_limit_quota: -1,
|
||||
})
|
||||
}
|
||||
// 打开编辑密钥额度编辑框
|
||||
const openSecretKeyQuotaEditModalExtend = async (api: ApiKeyItemResponse) => {
|
||||
setVisibleExtend(true)
|
||||
setKeyItem({
|
||||
created_at: api.created_at,
|
||||
id: api.id,
|
||||
last_used_at: api.last_used_at,
|
||||
token: api.token,
|
||||
description: api.description,
|
||||
day_limit_quota: api.day_limit_quota,
|
||||
month_limit_quota: api.month_limit_quota,
|
||||
})
|
||||
}
|
||||
// 设置密钥额度数据
|
||||
const handleSetKeyDataSetQuotas = (newKeyItems: ApikeyItemResponseWithQuotaLimitExtend) => {
|
||||
setKeyItem(newKeyItems)
|
||||
}
|
||||
// ---------------------- 二开部分End - 密钥额度 ----------------------
|
||||
|
||||
// 二开部分 Begin - 密钥额度限制编辑
|
||||
const onEdit = async () => {
|
||||
const params = {
|
||||
url: `/apps/${appId}/api-keys`,
|
||||
body: {
|
||||
id: keyItem.id,
|
||||
description: keyItem.description,
|
||||
day_limit_quota: keyItem.day_limit_quota,
|
||||
month_limit_quota: keyItem.month_limit_quota,
|
||||
},
|
||||
}
|
||||
const res = await editApikey(params)
|
||||
setVisibleExtend(false)
|
||||
setNewKey(res)
|
||||
}
|
||||
// 二开部分 Begin - 密钥额度限制编辑
|
||||
|
||||
const [delKeyID, setDelKeyId] = useState('')
|
||||
|
||||
@@ -101,6 +164,12 @@ const SecretKeyModal = ({
|
||||
<div className="w-64 shrink-0 px-3">{t('apiKeyModal.secretKey', { ns: 'appApi' })}</div>
|
||||
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.created', { ns: 'appApi' })}</div>
|
||||
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.lastUsed', { ns: 'appApi' })}</div>
|
||||
{/* ---------------------- 二开部分Begin - 密钥额度限制 ---------------------- */}
|
||||
<div className="w-[100px] shrink-0 px-3">{t('apiKeyModal.descriptionPlaceholder', { ns: 'extend' })}</div>
|
||||
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.dayLimit', { ns: 'extend' })}</div>
|
||||
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.monthLimit', { ns: 'extend' })}</div>
|
||||
<div className="w-[200px] shrink-0 px-3">{t('apiKeyModal.accumulatedLimit', { ns: 'extend' })}</div>
|
||||
{/* ---------------------- 二开部分End - 密钥额度限制 ---------------------- */}
|
||||
<div className="grow px-3"></div>
|
||||
</div>
|
||||
<div className="grow overflow-auto">
|
||||
@@ -109,6 +178,27 @@ const SecretKeyModal = ({
|
||||
<div className="w-64 shrink-0 truncate px-3 font-mono">{generateToken(api.token)}</div>
|
||||
<div className="w-[200px] shrink-0 truncate px-3">{formatTime(Number(api.created_at), t('dateTimeFormat', { ns: 'appLog' }) as string)}</div>
|
||||
<div className="w-[200px] shrink-0 truncate px-3">{api.last_used_at ? formatTime(Number(api.last_used_at), t('dateTimeFormat', { ns: 'appLog' }) as string) : t('never', { ns: 'appApi' })}</div>
|
||||
{/* ---------------------- 二开部分Begin - 密钥额度限制 ---------------------- */}
|
||||
<div className="w-[100px] shrink-0 truncate px-3">{api.description}</div>
|
||||
<div className="w-[200px] shrink-0 truncate px-3">
|
||||
$
|
||||
{api.day_used_quota}
|
||||
{' '}
|
||||
/
|
||||
{api.day_limit_quota === -1 ? t('apiKeyModal.noLimit', { ns: 'extend' }) : `$ ${api.day_limit_quota}`}
|
||||
</div>
|
||||
<div className="w-[200px] shrink-0 truncate px-3">
|
||||
$
|
||||
{api.month_used_quota}
|
||||
{' '}
|
||||
/
|
||||
{api.month_limit_quota === -1 ? t('apiKeyModal.noLimit', { ns: 'extend' }) : `$ ${api.month_limit_quota}`}
|
||||
</div>
|
||||
<div className="w-[200px] shrink-0 truncate px-3">
|
||||
$
|
||||
{api.accumulated_quota}
|
||||
</div>
|
||||
{/* ---------------------- 二开部分End - 密钥额度限制 ---------------------- */}
|
||||
<div className="flex grow space-x-2 px-3">
|
||||
<CopyFeedback content={api.token} />
|
||||
{isCurrentWorkspaceManager && (
|
||||
@@ -121,6 +211,17 @@ const SecretKeyModal = ({
|
||||
<RiDeleteBinLine className="h-4 w-4" />
|
||||
</ActionButton>
|
||||
)}
|
||||
{/* // 二开部分 End - 密钥额度限制编辑 */}
|
||||
{isCurrentWorkspaceManager && (
|
||||
<div
|
||||
className={`flex h-6 w-6 shrink-0 cursor-pointer items-center justify-center rounded-lg ${s.editIcon}`}
|
||||
onClick={() => {
|
||||
openSecretKeyQuotaEditModalExtend(api).then()
|
||||
}}
|
||||
>
|
||||
</div>
|
||||
)}
|
||||
{/* // 二开部分 Begin - 密钥额度限制编辑 */}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
@@ -147,6 +248,10 @@ const SecretKeyModal = ({
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* ----------------------二开部分Begin - 密钥额度限制---------------------- */}
|
||||
<SecretKeyQuotaSetExtendModal className="shrink-0" isShow={isVisibleExtend} onClose={() => setVisibleExtend(false)} newKey={keyItem} onChange={handleSetKeyDataSetQuotas} onCreate={keyItem.id === '' ? onCreate : onEdit} />
|
||||
{/* ----------------------二开部分End - 密钥额度限制---------------------- */}
|
||||
</Modal>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
'use client'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { ApikeyItemResponseWithQuotaLimitExtend } from '@/models/app'
|
||||
import { XMarkIcon } from '@heroicons/react/20/solid'
|
||||
import s from './style.module.css'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Modal from '@/app/components/base/modal'
|
||||
import type { ApikeyItemResponseWithQuotaLimitExtend } from '@/models/app'
|
||||
import DayLimitItemExtend from '@/app/components/base/param-item/day-limit-item-extend'
|
||||
import MonthLimitItemExtend from '@/app/components/base/param-item/month-limit-item-extend'
|
||||
import s from './style.module.css'
|
||||
|
||||
type ISecretKeyGenerateModalProps = {
|
||||
isShow: boolean
|
||||
@@ -53,35 +53,41 @@ const SecretKeyQuotaSetExtendModal = ({
|
||||
}
|
||||
|
||||
return (
|
||||
<Modal isShow={isShow} onClose={onClose} title={(newKey?.id ? '编辑' : '创建')+ `${t('appApi.apiKeyModal.apiSecretKey')}`}
|
||||
className={`px-8 ${className}`}>
|
||||
<XMarkIcon className={`w-6 h-6 absolute cursor-pointer text-gray-500 ${s.close}`} onClick={onClose}/>
|
||||
<p className='mt-1 text-[13px] text-gray-500 font-normal leading-5'>{t('extend.apiKeyModal.apiSecretKeyTips')}</p>
|
||||
<div className='my-4'>
|
||||
<Modal
|
||||
isShow={isShow}
|
||||
onClose={onClose}
|
||||
title={`${newKey?.id ? '编辑' : '创建'}${t('apiKeyModal.apiSecretKey', { ns: 'appApi' })}`}
|
||||
className={`px-8 ${className}`}
|
||||
>
|
||||
<XMarkIcon className={`absolute h-6 w-6 cursor-pointer text-gray-500 ${s.close}`} onClick={onClose} />
|
||||
<p className="mt-1 text-[13px] font-normal leading-5 text-gray-500">
|
||||
{t('apiKeyModal.apiSecretKeyTips', { ns: 'extend' })}
|
||||
</p>
|
||||
<div className="my-4">
|
||||
<input
|
||||
value={newKey?.description ?? ''}
|
||||
onChange={e => handleParamChangeDesc(e.target.value)}
|
||||
placeholder={t('extend.apiKeyModal.descriptionPlaceholder') || '密钥用途'}
|
||||
className='grow h-10 px-3 text-sm font-normal bg-gray-100 rounded-lg border border-transparent outline-none appearance-none caret-primary-600 placeholder:text-gray-400 hover:bg-gray-50 hover:border hover:border-gray-300 focus:bg-gray-50 focus:border focus:border-gray-300 focus:shadow-xs'
|
||||
placeholder={t('apiKeyModal.descriptionPlaceholder', { ns: 'extend' }) || '密钥用途'}
|
||||
className="h-10 grow appearance-none rounded-lg border border-transparent bg-gray-100 px-3 text-sm font-normal caret-primary-600 outline-none placeholder:text-gray-400 hover:border hover:border-gray-300 hover:bg-gray-50 focus:border focus:border-gray-300 focus:bg-gray-50 focus:shadow-xs"
|
||||
/>
|
||||
</div>
|
||||
<div className='my-4'>
|
||||
<div className="my-4">
|
||||
<DayLimitItemExtend
|
||||
value={newKey?.day_limit_quota ?? -1}
|
||||
onChange={handleParamChange}
|
||||
enable={true}
|
||||
/>
|
||||
</div>
|
||||
<div className='my-4'>
|
||||
<div className="my-4">
|
||||
<MonthLimitItemExtend
|
||||
value={newKey?.month_limit_quota ?? -1}
|
||||
onChange={handleParamChange}
|
||||
enable={true}
|
||||
/>
|
||||
</div>
|
||||
<div className='flex justify-end my-4'>
|
||||
<Button variant='primary' className={`flex-shrink-0 ${s.w64}`} onClick={onCreate}>
|
||||
{newKey?.id ? t('common.operation.save') : t('common.operation.create')}
|
||||
<div className="my-4 flex justify-end">
|
||||
<Button variant="primary" className={`shrink-0 ${s.w64}`} onClick={onCreate}>
|
||||
{newKey?.id ? t('operation.save', { ns: 'common' }) : t('operation.create', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -55,3 +55,17 @@
|
||||
.copyIcon.copied {
|
||||
background-image: url(./assets/copied.svg);
|
||||
}
|
||||
|
||||
/* 二开部分Begin - 密钥额度限制 */
|
||||
.editIcon {
|
||||
background-image: url(./assets/edit.svg);
|
||||
background-position: center;
|
||||
background-repeat: no-repeat;
|
||||
}
|
||||
|
||||
.editIcon:hover {
|
||||
background-image: url(./assets/edit-hover.svg);
|
||||
background-position: center;
|
||||
background-repeat: no-repeat;
|
||||
}
|
||||
/* 二开部分End - 密钥额度限制 */
|
||||
|
||||
@@ -180,12 +180,19 @@ const TextGeneration: FC<IMainProps> = ({
|
||||
const batchJobsLimit = 5 // 每页5个任务
|
||||
const [totalBatchJobs, setTotalBatchJobs] = useState(0)
|
||||
const [isLoadingBatchJobs, setIsLoadingBatchJobs] = useState(false)
|
||||
const lastRefreshTimeRef = useRef(0) // 记录上次刷新时间,避免频繁刷新
|
||||
|
||||
// 从后端获取批量工作流列表
|
||||
const loadBatchWorkflows = useCallback(async () => {
|
||||
const loadBatchWorkflows = useCallback(async (force = false) => {
|
||||
if (!appId || currentTab !== 'batch')
|
||||
return
|
||||
|
||||
// 防止过于频繁的刷新(至少间隔 1 秒)
|
||||
const now = Date.now()
|
||||
if (!force && now - lastRefreshTimeRef.current < 1000)
|
||||
return
|
||||
|
||||
lastRefreshTimeRef.current = now
|
||||
setIsLoadingBatchJobs(true)
|
||||
try {
|
||||
const result = await fetchBatchWorkflowListApi(installedAppInfo?.id, currentPage, batchJobsLimit)
|
||||
@@ -218,25 +225,9 @@ const TextGeneration: FC<IMainProps> = ({
|
||||
loadBatchWorkflows()
|
||||
}, [loadBatchWorkflows])
|
||||
|
||||
// 自动刷新批量工作流列表(每3秒)
|
||||
useEffect(() => {
|
||||
if (currentTab !== 'batch' || batchJobs.length === 0)
|
||||
return
|
||||
// 注意:不再需要自动刷新逻辑,因为每个批量任务现在自己管理进度刷新
|
||||
// 每个 BatchProgress 组件会独立轮询自己的进度(每 3 秒)
|
||||
|
||||
// 检查是否有进行中的任务
|
||||
const hasActiveJobs = batchJobs.some(job =>
|
||||
job.status === 'pending' || job.status === 'processing',
|
||||
)
|
||||
|
||||
if (!hasActiveJobs)
|
||||
return
|
||||
|
||||
const refreshInterval = setInterval(() => {
|
||||
loadBatchWorkflows()
|
||||
}, 3000) // 每3秒刷新一次
|
||||
|
||||
return () => clearInterval(refreshInterval)
|
||||
}, [currentTab, batchJobs, loadBatchWorkflows])
|
||||
|
||||
// 计算分页数据 - 现在数据已经是从后端分页获取的,不需要再切片
|
||||
const paginatedBatchJobs = batchJobs
|
||||
@@ -472,6 +463,28 @@ const TextGeneration: FC<IMainProps> = ({
|
||||
loadBatchWorkflows()
|
||||
console.log('批量任务重试成功,已刷新列表')
|
||||
}
|
||||
|
||||
// 处理单个任务进度更新(只更新列表中的对应项,不刷新整个列表)
|
||||
const handleJobUpdate = useCallback((jobId: string, updatedData: { status: string, processedRows: number, error?: string }) => {
|
||||
setBatchJobs(prevJobs => {
|
||||
// 检查是否真的有变化
|
||||
const job = prevJobs.find(j => j.id === jobId)
|
||||
if (!job)
|
||||
return prevJobs
|
||||
|
||||
// 如果没有变化,不更新
|
||||
if (job.status === updatedData.status && job.processedRows === updatedData.processedRows && job.error === updatedData.error)
|
||||
return prevJobs
|
||||
|
||||
// 有变化才更新
|
||||
return prevJobs.map(job =>
|
||||
job.id === jobId
|
||||
? { ...job, ...updatedData }
|
||||
: job
|
||||
)
|
||||
})
|
||||
}, [])
|
||||
|
||||
// Extend: Stop Batch import
|
||||
|
||||
const handleCompleted = (completionRes: string, taskId?: number, isSuccess?: boolean) => {
|
||||
@@ -662,6 +675,7 @@ const TextGeneration: FC<IMainProps> = ({
|
||||
jobData={job}
|
||||
onDownload={() => handleBatchDownload(job.id)}
|
||||
onRetrySuccess={handleRetrySuccess}
|
||||
onJobUpdate={(updatedData) => handleJobUpdate(job.id, updatedData)}
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
RiRefreshLine,
|
||||
RiStopLine,
|
||||
} from '@remixicon/react'
|
||||
import { resumeBatchApi, retryFailedTasksApi, stopBatchApi } from '@/service/web-extend' // extend: 批量运行工单
|
||||
import { fetchProgressApi, resumeBatchApi, retryFailedTasksApi, stopBatchApi } from '@/service/web-extend' // extend: 批量运行工单
|
||||
import type { BatchStatus } from '@/utils/batch-progress-manager' // extend: 批量运行工单
|
||||
import ActionButton from '@/app/components/base/action-button'
|
||||
|
||||
@@ -32,6 +32,7 @@ export type BatchProgressProps = {
|
||||
}
|
||||
onDownload: () => void
|
||||
onRetrySuccess?: () => void
|
||||
onJobUpdate?: (jobData: { status: string, processedRows: number, error?: string }) => void // 新增:任务更新回调
|
||||
}
|
||||
|
||||
const BatchProgress: FC<BatchProgressProps> = ({
|
||||
@@ -41,10 +42,62 @@ const BatchProgress: FC<BatchProgressProps> = ({
|
||||
jobData,
|
||||
onDownload,
|
||||
onRetrySuccess,
|
||||
onJobUpdate,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
// 本地进度状态,用于独立刷新
|
||||
const [localProgress, setLocalProgress] = useState({
|
||||
status: jobData.status,
|
||||
processedRows: jobData.processedRows,
|
||||
totalRows: jobData.totalRows,
|
||||
error: jobData.error,
|
||||
})
|
||||
|
||||
// 自动刷新单个任务的进度(每 3 秒)
|
||||
useEffect(() => {
|
||||
// 只在任务进行中时刷新
|
||||
if (localProgress.status !== 'pending' && localProgress.status !== 'processing')
|
||||
return
|
||||
|
||||
const refreshInterval = setInterval(async () => {
|
||||
try {
|
||||
const progress = await fetchProgressApi(batchId)
|
||||
if (progress) {
|
||||
const newStatus = progress.status as string
|
||||
const newProcessedRows = progress.processed_rows as number
|
||||
const newError = progress.error as string | undefined
|
||||
|
||||
// 只有当数据有变化时才更新
|
||||
if (
|
||||
newStatus !== localProgress.status
|
||||
|| newProcessedRows !== localProgress.processedRows
|
||||
|| newError !== localProgress.error
|
||||
) {
|
||||
const updatedProgress = {
|
||||
status: newStatus,
|
||||
processedRows: newProcessedRows,
|
||||
totalRows: progress.total_rows as number,
|
||||
error: newError,
|
||||
}
|
||||
setLocalProgress(updatedProgress)
|
||||
// 通知父组件数据已更新(用于列表级别的状态同步)
|
||||
onJobUpdate?.({
|
||||
status: newStatus,
|
||||
processedRows: newProcessedRows,
|
||||
error: newError,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Failed to fetch progress:', error)
|
||||
}
|
||||
}, 3000)
|
||||
|
||||
return () => clearInterval(refreshInterval)
|
||||
}, [batchId, localProgress.status, localProgress.processedRows, localProgress.error, onJobUpdate])
|
||||
|
||||
// 停止批量处理
|
||||
const handleStop = async () => {
|
||||
@@ -160,8 +213,8 @@ const BatchProgress: FC<BatchProgressProps> = ({
|
||||
})
|
||||
|
||||
// 计算进度
|
||||
const progress = jobData.totalRows > 0 ? (jobData.processedRows / jobData.totalRows) * 100 : 0
|
||||
const status = jobData.status as BatchStatus
|
||||
const progress = localProgress.totalRows > 0 ? (localProgress.processedRows / localProgress.totalRows) * 100 : 0
|
||||
const status = localProgress.status as BatchStatus
|
||||
const failed_count = 0 // 从列表API没有这个字段,如果需要可以后续添加
|
||||
|
||||
const getBorderColor = (status: BatchStatus) => {
|
||||
@@ -232,18 +285,18 @@ const BatchProgress: FC<BatchProgressProps> = ({
|
||||
</div>
|
||||
|
||||
{/* 详细进度信息 */}
|
||||
{jobData.totalRows > 0 && (
|
||||
{localProgress.totalRows > 0 && (
|
||||
<div className="mt-2 text-xs text-gray-500">
|
||||
{t('batchWorkflow.processed', {
|
||||
processed: jobData.processedRows || 0,
|
||||
total: jobData.totalRows || 0,
|
||||
processed: localProgress.processedRows || 0,
|
||||
total: localProgress.totalRows || 0,
|
||||
ns: 'extend',
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 错误信息显示 */}
|
||||
{jobData.error && status === 'failed' && (
|
||||
{localProgress.error && status === 'failed' && (
|
||||
<div className="mt-3 rounded-lg border border-red-200 bg-red-50 p-3">
|
||||
<div className="flex items-start space-x-2">
|
||||
<RiErrorWarningLine className="h-4 w-4 text-red-500 mt-0.5 flex-shrink-0" />
|
||||
@@ -252,7 +305,7 @@ const BatchProgress: FC<BatchProgressProps> = ({
|
||||
{t('batchWorkflow.errorOccurred', { ns: 'extend'} )}
|
||||
</div>
|
||||
<div className="text-xs text-red-700 break-words">
|
||||
{jobData.error}
|
||||
{localProgress.error}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
+2
-2
@@ -97,7 +97,7 @@
|
||||
"dompurify": "3.3.0",
|
||||
"echarts": "5.6.0",
|
||||
"echarts-for-react": "3.0.5",
|
||||
"elkjs": "0.9.3",
|
||||
"elkjs": "0.11.0",
|
||||
"embla-carousel-autoplay": "8.6.0",
|
||||
"embla-carousel-react": "8.6.0",
|
||||
"emoji-mart": "5.6.0",
|
||||
@@ -167,7 +167,7 @@
|
||||
"devDependencies": {
|
||||
"@antfu/eslint-config": "7.2.0",
|
||||
"@chromatic-com/storybook": "5.0.0",
|
||||
"@eslint-react/eslint-plugin": "2.8.1",
|
||||
"@eslint-react/eslint-plugin": "2.12.4",
|
||||
"@mdx-js/loader": "3.1.1",
|
||||
"@mdx-js/react": "3.1.1",
|
||||
"@next/bundle-analyzer": "16.1.5",
|
||||
|
||||
Reference in New Issue
Block a user