// File: dify-plus/admin/server/api/v1/gaia/model_provider.go
// Listing metadata: T · 2026-03-11 12:05:53 +08:00 · 270 lines · 8.4 KiB · Go

package gaia
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/flipped-aurora/gin-vue-admin/server/global"
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
"github.com/flipped-aurora/gin-vue-admin/server/service"
"github.com/flipped-aurora/gin-vue-admin/server/utils"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ModelProviderApi groups the HTTP handlers for Gaia model-provider
// configuration, the OpenAI-style model listing, the generic upstream proxy
// and the proxy log viewer.
type ModelProviderApi struct{}

// modelProviderService is the ModelProviderService instance taken from the
// global service group; the handlers in this file delegate to it.
var modelProviderService = service.ServiceGroupApp.GaiaServiceGroup.ModelProviderService
// GetProviderList responds with every provider configuration returned by the
// model-provider service, wrapped in the standard response envelope. On
// failure the error is logged and its text is returned to the caller.
//
// @Tags ModelProvider
// @Summary 获取提供商配置列表
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Success 200 {object} response.Response{data=[]gaiaResponse.ProviderListItem,msg=string} "获取成功"
// @Router /gaia/model-provider/list [get]
func (m *ModelProviderApi) GetProviderList(c *gin.Context) {
	providers, listErr := modelProviderService.GetProviderList()
	if listErr == nil {
		response.OkWithData(providers, c)
		return
	}
	global.GVA_LOG.Error("获取提供商配置列表失败", zap.Error(listErr))
	response.FailWithMessage("获取失败: "+listErr.Error(), c)
}
// UpdateProviderConfig binds a JSON body {provider_name, enabled, models} and
// passes it to the model-provider service. provider_name is required; a bind
// or validation failure is reported as a parameter error, and a service
// failure is logged with the provider name before being returned.
//
// @Tags ModelProvider
// @Summary 更新提供商配置
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param data body object true "提供商配置"
// @Success 200 {object} response.Response{msg=string} "更新成功"
// @Router /gaia/model-provider/update [post]
func (m *ModelProviderApi) UpdateProviderConfig(c *gin.Context) {
	// Kept as an anonymous struct: validator error messages include the
	// struct's type name, so naming it would change the text clients see.
	var body struct {
		ProviderName string   `json:"provider_name" binding:"required"`
		Enabled      bool     `json:"enabled"`
		Models       []string `json:"models"`
	}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		response.FailWithMessage("参数错误: "+bindErr.Error(), c)
		return
	}

	updateErr := modelProviderService.UpdateProviderConfig(body.ProviderName, body.Enabled, body.Models)
	if updateErr != nil {
		global.GVA_LOG.Error("更新提供商配置失败", zap.String("provider", body.ProviderName), zap.Error(updateErr))
		response.FailWithMessage("更新失败: "+updateErr.Error(), c)
		return
	}
	response.OkWithMessage("更新成功", c)
}
// GetModels writes the enabled models (OpenAI-compatible format) directly as
// the JSON body, without the standard response envelope. On failure it
// responds 500 with an OpenAI-style {"error": {"message": ...}} object.
//
// @Tags ModelProvider
// @Summary 获取开启的模型列表
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Success 200 {object} gaiaResponse.OpenAIModelsResponse "获取成功"
// @Router /gaia/models [get]
func (m *ModelProviderApi) GetModels(c *gin.Context) {
	enabled, err := modelProviderService.GetEnabledModels()
	if err == nil {
		c.JSON(http.StatusOK, enabled)
		return
	}

	global.GVA_LOG.Error("获取模型列表失败", zap.Error(err))
	errDetail := gin.H{"message": "获取模型列表失败: " + err.Error()}
	c.JSON(http.StatusInternalServerError, gin.H{"error": errDetail})
}
// proxyWithAccountId is the shared proxy logic. It forwards the current request
// to the upstream path captured by the `*path` route parameter on behalf of
// accountId, via modelProviderService.ProxyRequest. Per the original
// description this also bills the request.
// NOTE(review): billing lives in the service; confirm there.
//
// A `provider` query parameter, when non-blank, is forwarded as the
// X-Gaia-Provider header. Errors before proxying are returned as 400 with an
// OpenAI-style {"error": {"message": ...}} body. Proxy errors become a 500 only
// if nothing has been written to the response yet.
func proxyWithAccountId(c *gin.Context, accountId string) {
	path := c.Param("path")
	if path == "" || path == "/" {
		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "代理路径不能为空"}})
		return
	}

	// Work on a copy so the provider override does not mutate the inbound request.
	reqHeader := c.Request.Header.Clone()
	if q := strings.TrimSpace(c.Query("provider")); q != "" {
		reqHeader.Set("X-Gaia-Provider", q)
	}

	// The body is buffered in full: it is inspected for logging and then
	// handed to the service as a byte slice.
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "读取请求体失败"}})
		return
	}
	bodyModel := peekRequestModel(body)

	global.GVA_LOG.Info("Gaia代理请求入参",
		zap.String("account_id", accountId),
		zap.String("path", path),
		zap.String("method", c.Request.Method),
		zap.Int("body_len", len(body)),
		zap.String("body_model", bodyModel),
	)

	if err = modelProviderService.ProxyRequest(
		accountId, path, c.Request.Method, reqHeader, body, c.Writer); err != nil {
		global.GVA_LOG.Error("代理请求失败", zap.String("account_id", accountId), zap.String("path", path), zap.Error(err))
		// ProxyRequest receives c.Writer and may already have written part of
		// the upstream response; appending a JSON error then would corrupt it.
		if !c.Writer.Written() {
			c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
		}
	}
}

// peekRequestModel returns the top-level string "model" field of a JSON object
// body (exact, case-sensitive key). It returns "" when the body is empty, is
// not a JSON object, lacks "model", or "model" is not a string.
//
// Only the top level is split into raw values, so large nested payloads
// (message histories, inline images) are not decoded into interface{} trees
// just to read one field for logging.
func peekRequestModel(body []byte) string {
	if len(body) == 0 {
		return ""
	}
	var top map[string]json.RawMessage
	if err := json.Unmarshal(body, &top); err != nil {
		return ""
	}
	raw, ok := top["model"]
	if !ok {
		return ""
	}
	var model string
	if err := json.Unmarshal(raw, &model); err != nil {
		return ""
	}
	return model
}
// Proxy is the generic relay endpoint. It forwards /gaia/proxy/* requests to
// the upstream path after the prefix. Per its original description it requires
// JWT auth. The account id is the UUID of the logged-in user, as returned by
// utils.GetUserUuid.
//
// @Tags ModelProvider
// @Summary 通用中转API(按路径转发)
// @Security ApiKeyAuth
// @Param path path string true "上游路径,如 v1/chat/completions、v1/messages"
// @Router /gaia/proxy/*path [get,post,put,patch,delete]
func (m *ModelProviderApi) Proxy(c *gin.Context) {
	proxyWithAccountId(c, utils.GetUserUuid(c).String())
}
// GetAvailableModels returns the models that GetAvailableModelsFromDify reports
// for the provider named by the required provider_name query parameter. A
// missing or empty name is rejected before the service is called.
//
// @Tags ModelProvider
// @Summary 获取提供商的可用模型
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param provider_name query string true "提供商名称"
// @Success 200 {object} response.Response{data=[]gaiaResponse.ModelInfo,msg=string} "获取成功"
// @Router /gaia/model-provider/available-models [get]
func (m *ModelProviderApi) GetAvailableModels(c *gin.Context) {
	name := c.Query("provider_name")
	if len(name) == 0 {
		response.FailWithMessage("参数错误: provider_name不能为空", c)
		return
	}

	available, fetchErr := modelProviderService.GetAvailableModelsFromDify(name)
	if fetchErr != nil {
		global.GVA_LOG.Error("获取可用模型失败", zap.String("provider", name), zap.Error(fetchErr))
		response.FailWithMessage("获取失败: "+fetchErr.Error(), c)
		return
	}
	response.OkWithData(available, c)
}
// TestProviderCredentials reports whether an API key is stored for the
// provider named by the required provider_name query parameter, and returns a
// masked copy of that key.
//
// The handler does not exercise the key against the provider itself. It only
// reads the credentials via modelProviderService.GetDifyProviderCredentials.
// The response data is {"provider", "has_api_key", "api_key"}, where "api_key"
// is produced by maskProviderAPIKey.
//
// @Tags ModelProvider
// @Summary 测试提供商凭证
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param provider_name query string true "提供商名称"
// @Success 200 {object} response.Response{msg=string} "测试成功"
// @Router /gaia/model-provider/test-credentials [get]
func (m *ModelProviderApi) TestProviderCredentials(c *gin.Context) {
	providerName := c.Query("provider_name")
	if providerName == "" {
		response.FailWithMessage("参数错误: provider_name不能为空", c)
		return
	}
	creds, err := modelProviderService.GetDifyProviderCredentials(providerName)
	if err != nil {
		global.GVA_LOG.Error("获取提供商凭证失败", zap.String("provider", providerName), zap.Error(err))
		response.FailWithMessage("获取凭证失败: "+err.Error(), c)
		return
	}
	// Never echo the raw key back to the client.
	result := map[string]interface{}{
		"provider":    providerName,
		"has_api_key": creds.APIKey != "",
		"api_key":     maskProviderAPIKey(creds.APIKey),
	}
	response.OkWithData(result, c)
}

// maskProviderAPIKey returns a display-safe form of key.
//
// Keys of 8 bytes or fewer, including the empty key, become "****". Longer
// keys keep up to 4 leading and 4 trailing bytes, but never more than a
// quarter of the key on each side, so at most half of the key is shown. A
// fixed 4+4 split would expose 8 of 9 bytes of a 9-byte key.
//
// Slicing is byte-based.
// NOTE(review): assumes keys are ASCII, as API keys normally are.
func maskProviderAPIKey(key string) string {
	if len(key) <= 8 {
		return "****"
	}
	keep := len(key) / 4
	if keep > 4 {
		keep = 4
	}
	return key[:keep] + "****" + key[len(key)-keep:]
}
// GetProxyLogs returns one page of rows from the model_proxy_log table, newest
// first (ORDER BY created_at DESC). The response data is
// {"list", "total", "page", "page_size"}, and each row is a column-name map.
//
// Paging rules:
//   - page defaults to 1, and values below 1 are raised to 1.
//   - page_size defaults to 20, and any value outside 1..100 falls back to 20
//     rather than being clamped.
//   - Unparseable values fall back to the same defaults.
//
// @Tags ModelProvider
// @Summary 获取代理日志
// @Security ApiKeyAuth
// @accept application/json
// @Produce application/json
// @Param page query int false "页码"
// @Param page_size query int false "每页数量"
// @Success 200 {object} response.Response{data=map[string]interface{},msg=string} "获取成功"
// @Router /gaia/model-provider/logs [get]
func (m *ModelProviderApi) GetProxyLogs(c *gin.Context) {
	page := c.DefaultQuery("page", "1")
	pageSize := c.DefaultQuery("page_size", "20")
	var pageInt, pageSizeInt int
	if _, err := fmt.Sscanf(page, "%d", &pageInt); err != nil {
		pageInt = 1
	}
	if _, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt); err != nil {
		pageSizeInt = 20
	}
	if pageInt < 1 {
		pageInt = 1
	}
	if pageSizeInt < 1 || pageSizeInt > 100 {
		pageSizeInt = 20
	}

	// Start from an empty, non-nil slice so a page with no rows is encoded as
	// "list": [] rather than "list": null.
	logs := make([]map[string]interface{}, 0)
	var total int64
	db := global.GVA_DB.Table("model_proxy_log")

	if err := db.Count(&total).Error; err != nil {
		global.GVA_LOG.Error("获取日志总数失败", zap.Error(err))
		response.FailWithMessage("获取失败: "+err.Error(), c)
		return
	}

	offset := (pageInt - 1) * pageSizeInt
	if err := db.Order("created_at DESC").Limit(pageSizeInt).Offset(
		offset).Find(&logs).Error; err != nil {
		global.GVA_LOG.Error("获取日志列表失败", zap.Error(err))
		response.FailWithMessage("获取失败: "+err.Error(), c)
		return
	}

	result := map[string]interface{}{
		"list":      logs,
		"total":     total,
		"page":      pageInt,
		"page_size": pageSizeInt,
	}
	response.OkWithData(result, c)
}