mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-04 10:14:00 +08:00
fix: 余额不足直接拦截
This commit is contained in:
@@ -122,6 +122,13 @@ func proxyWithAccountId(c *gin.Context, accountId string) {
|
||||
zap.Int("body_len", len(body)),
|
||||
zap.String("body_model", bodyModel),
|
||||
)
|
||||
|
||||
// 余额前置检查:余额耗尽时直接拦截,不继续请求上游
|
||||
if quotaErr := modelProviderService.CheckAccountQuota(accountId); quotaErr != nil {
|
||||
c.JSON(http.StatusPaymentRequired, gin.H{"error": gin.H{"message": quotaErr.Error()}})
|
||||
return
|
||||
}
|
||||
|
||||
if err = modelProviderService.ProxyRequest(
|
||||
accountId, path, c.Request.Method, reqHeader, body, c.Writer); err != nil {
|
||||
global.GVA_LOG.Error("代理请求失败", zap.String("account_id", accountId), zap.String("path", path), zap.Error(err))
|
||||
|
||||
@@ -203,7 +203,8 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
|
||||
"code": req.AuthCode,
|
||||
"grantType": "authorization_code",
|
||||
})
|
||||
httpReq, err := http.NewRequest("POST", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken", bytes.NewReader(bodyJSON))
|
||||
httpReq, err := http.NewRequest("POST",
|
||||
"https://api.dingtalk.com/v1.0/oauth2/userAccessToken", bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -217,7 +218,8 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Error("钉钉 token 非 200", zap.Int("status", resp.StatusCode), zap.String("body", string(respBody)))
|
||||
global.GVA_LOG.Error("钉钉 token 非 200", zap.Int(
|
||||
"status", resp.StatusCode), zap.String("body", string(respBody)))
|
||||
return nil, fmt.Errorf("钉钉返回错误: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -240,7 +242,6 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
|
||||
}
|
||||
|
||||
var dingUser map[string]interface{}
|
||||
fmt.Println("sssssssss", string(userBody))
|
||||
if err = json.Unmarshal(userBody, &dingUser); err != nil {
|
||||
return nil, fmt.Errorf("解析钉钉用户信息失败")
|
||||
}
|
||||
|
||||
@@ -145,12 +145,23 @@ func (s *ModelProviderService) fetchModelPricingFromDify(modelName string) (*gai
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// calcQuotaDelta 根据定价和 token 用量计算本次消耗的配额金额。
|
||||
// 若未找到定价则回退到默认单价 0.001(每 token)。
|
||||
// rmbToUSD 将人民币金额按固定汇率换算为 USD(与 dashboard.go 保持一致,使用 7.26)。
|
||||
const rmbToUSDRate = 7.26
|
||||
|
||||
func rmbToUSD(rmb float64) float64 {
|
||||
return rmb / rmbToUSDRate
|
||||
}
|
||||
|
||||
// calcQuotaDelta 根据定价和 token 用量计算本次消耗的配额金额(统一以 USD 计)。
|
||||
// Dify pricing 字段语义:input/output 为每「unit」个 token 的价格,unit 通常为 0.001(千分之一),
|
||||
// 即 input=0.0014, unit=0.001 表示每千 token 收 $0.0014 × (tokens/1000)。
|
||||
// 公式:cost = tokens × input × unit(因为 unit=1/1000,等价于 tokens/1000 × input)。
|
||||
// 若货币为 RMB,则除以汇率 7.26 换算为 USD,与 account_money_extend.used_quota 存储单位保持一致。
|
||||
// 若未找到定价则回退到合理默认值:$0.001/1000 tokens(即每 token $0.000001)。
|
||||
func calcQuotaDelta(pricing *gaia.ModelPricing, promptTokens, completionTokens int) float64 {
|
||||
if pricing == nil || pricing.Unit == 0 {
|
||||
// 回退:按 0.001/token 统一计费
|
||||
return float64(promptTokens+completionTokens) * 0.001
|
||||
// 回退:按 $0.001 / 千token 计费(约 GPT-3.5 量级),避免按每 token 计费导致超额扣费
|
||||
return float64(promptTokens+completionTokens) * 0.001 * 0.001
|
||||
}
|
||||
inputCost := float64(promptTokens) * pricing.Input * pricing.Unit
|
||||
outputPrice := pricing.Output
|
||||
@@ -158,7 +169,38 @@ func calcQuotaDelta(pricing *gaia.ModelPricing, promptTokens, completionTokens i
|
||||
outputPrice = pricing.Input
|
||||
}
|
||||
outputCost := float64(completionTokens) * outputPrice * pricing.Unit
|
||||
return inputCost + outputCost
|
||||
total := inputCost + outputCost
|
||||
|
||||
// RMB 定价统一换算为 USD 后再扣费,与 used_quota 存储单位保持一致
|
||||
if strings.EqualFold(pricing.Currency, "RMB") || strings.EqualFold(pricing.Currency, "CNY") {
|
||||
total = rmbToUSD(total)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// CheckAccountQuota 检查用户是否还有可用余额(total_quota - used_quota > 0)。
|
||||
// total_quota = 0 视为"未设置限额",不拦截;total_quota > 0 时才做余额校验。
|
||||
func (s *ModelProviderService) CheckAccountQuota(userID string) error {
|
||||
var row struct {
|
||||
TotalQuota float64 `gorm:"column:total_quota"`
|
||||
UsedQuota float64 `gorm:"column:used_quota"`
|
||||
}
|
||||
err := global.GVA_DB.Table("account_money_extend").
|
||||
Select("total_quota, used_quota").
|
||||
Where("account_id = ?::uuid", userID).
|
||||
First(&row).Error
|
||||
if err != nil {
|
||||
// 记录未找到:可能尚未初始化,放行
|
||||
return nil
|
||||
}
|
||||
// total_quota = 0 表示不限额,放行
|
||||
if row.TotalQuota <= 0 {
|
||||
return nil
|
||||
}
|
||||
if row.UsedQuota >= row.TotalQuota {
|
||||
return fmt.Errorf("余额不足,已用 %.6f / 总额 %.6f USD,请联系管理员充值", row.UsedQuota, row.TotalQuota)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deductAccountQuota 将消耗配额计入 account_money_extend.used_quota(原子累加)。
|
||||
|
||||
Reference in New Issue
Block a user