mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-04 10:14:00 +08:00
feat: 新增后端模型管理,第三方快捷登录
This commit is contained in:
@@ -11,6 +11,7 @@ type ApiGroup struct {
|
||||
SystemOAuth2Api
|
||||
BatchWorkflowApi
|
||||
AppVersionApi
|
||||
ModelProviderApi
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
package gaia
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/service"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ModelProviderApi struct{}
|
||||
|
||||
var modelProviderService = service.ServiceGroupApp.GaiaServiceGroup.ModelProviderService
|
||||
|
||||
// GetProviderList 获取提供商配置列表
|
||||
// @Tags ModelProvider
|
||||
// @Summary 获取提供商配置列表
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Success 200 {object} response.Response{data=[]gaiaResponse.ProviderListItem,msg=string} "获取成功"
|
||||
// @Router /gaia/model-provider/list [get]
|
||||
func (m *ModelProviderApi) GetProviderList(c *gin.Context) {
|
||||
list, err := modelProviderService.GetProviderList()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取提供商配置列表失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
response.OkWithData(list, c)
|
||||
}
|
||||
|
||||
// UpdateProviderConfig 更新提供商配置
|
||||
// @Tags ModelProvider
|
||||
// @Summary 更新提供商配置
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param data body object true "提供商配置"
|
||||
// @Success 200 {object} response.Response{msg=string} "更新成功"
|
||||
// @Router /gaia/model-provider/update [post]
|
||||
func (m *ModelProviderApi) UpdateProviderConfig(c *gin.Context) {
|
||||
var req struct {
|
||||
ProviderName string `json:"provider_name" binding:"required"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage("参数错误: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
if err := modelProviderService.UpdateProviderConfig(req.ProviderName, req.Enabled, req.Models); err != nil {
|
||||
global.GVA_LOG.Error("更新提供商配置失败", zap.String("provider", req.ProviderName), zap.Error(err))
|
||||
response.FailWithMessage("更新失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithMessage("更新成功", c)
|
||||
}
|
||||
|
||||
// GetModels 获取开启的模型列表(OpenAI格式)
|
||||
// @Tags ModelProvider
|
||||
// @Summary 获取开启的模型列表
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Success 200 {object} gaiaResponse.OpenAIModelsResponse "获取成功"
|
||||
// @Router /gaia/models [get]
|
||||
func (m *ModelProviderApi) GetModels(c *gin.Context) {
|
||||
models, err := modelProviderService.GetEnabledModels()
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取模型列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "获取模型列表失败: " + err.Error(),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models)
|
||||
}
|
||||
|
||||
// Proxy 通用中转 API:将 /gaia/proxy/* 的请求按路径转发到上游(如 /v1/chat/completions、/v1/messages、/v1/images/generations、/v1/embeddings 等)。
|
||||
// 上游 base 优先使用 provider_credentials 的 openai_api_base(如 "https://yunwu.ai"),便于计费区分。
|
||||
// @Tags ModelProvider
|
||||
// @Summary 通用中转API(按路径转发)
|
||||
// @Security ApiKeyAuth
|
||||
// @Param path path string true "上游路径,如 v1/chat/completions、v1/messages"
|
||||
// @Router /gaia/proxy/*path [get,post,put,patch,delete]
|
||||
func (m *ModelProviderApi) Proxy(c *gin.Context) {
|
||||
// init
|
||||
var err error
|
||||
var body []byte
|
||||
path := c.Param("path")
|
||||
userID := utils.GetUserUuid(c).String()
|
||||
if path == "" || path == "/" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "代理路径不能为空"}})
|
||||
return
|
||||
}
|
||||
// 将 query provider 转为请求头,供 service 解析
|
||||
reqHeader := c.Request.Header.Clone()
|
||||
if q := strings.TrimSpace(c.Query("provider")); q != "" {
|
||||
reqHeader.Set("X-Gaia-Provider", q)
|
||||
}
|
||||
|
||||
if body, err = io.ReadAll(c.Request.Body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": gin.H{"message": "读取请求体失败"}})
|
||||
return
|
||||
}
|
||||
|
||||
if err = modelProviderService.ProxyRequest(
|
||||
userID, path, c.Request.Method, reqHeader, body, c.Writer); err != nil {
|
||||
global.GVA_LOG.Error("代理请求失败", zap.String("user_id", userID), zap.String(
|
||||
"path", path), zap.Error(err))
|
||||
if !c.Writer.Written() {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": gin.H{"message": err.Error()}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAvailableModels 获取提供商的可用模型
|
||||
// @Tags ModelProvider
|
||||
// @Summary 获取提供商的可用模型
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param provider_name query string true "提供商名称"
|
||||
// @Success 200 {object} response.Response{data=[]gaiaResponse.ModelInfo,msg=string} "获取成功"
|
||||
// @Router /gaia/model-provider/available-models [get]
|
||||
func (m *ModelProviderApi) GetAvailableModels(c *gin.Context) {
|
||||
providerName := c.Query("provider_name")
|
||||
if providerName == "" {
|
||||
response.FailWithMessage("参数错误: provider_name不能为空", c)
|
||||
return
|
||||
}
|
||||
|
||||
models, err := modelProviderService.GetAvailableModelsFromDify(providerName)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取可用模型失败", zap.String("provider", providerName), zap.Error(err))
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
response.OkWithData(models, c)
|
||||
}
|
||||
|
||||
// TestProviderCredentials 测试提供商凭证
|
||||
// @Tags ModelProvider
|
||||
// @Summary 测试提供商凭证
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param provider_name query string true "提供商名称"
|
||||
// @Success 200 {object} response.Response{msg=string} "测试成功"
|
||||
// @Router /gaia/model-provider/test-credentials [get]
|
||||
func (m *ModelProviderApi) TestProviderCredentials(c *gin.Context) {
|
||||
providerName := c.Query("provider_name")
|
||||
if providerName == "" {
|
||||
response.FailWithMessage("参数错误: provider_name不能为空", c)
|
||||
return
|
||||
}
|
||||
|
||||
creds, err := modelProviderService.GetDifyProviderCredentials(providerName)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("获取提供商凭证失败", zap.String("provider", providerName), zap.Error(err))
|
||||
response.FailWithMessage("获取凭证失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 隐藏API Key的大部分内容
|
||||
maskedKey := ""
|
||||
if len(creds.APIKey) > 8 {
|
||||
maskedKey = creds.APIKey[:4] + "****" + creds.APIKey[len(creds.APIKey)-4:]
|
||||
} else {
|
||||
maskedKey = "****"
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"provider": providerName,
|
||||
"has_api_key": creds.APIKey != "",
|
||||
"api_key": maskedKey,
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
|
||||
// GetProxyLogs 获取代理日志
|
||||
// @Tags ModelProvider
|
||||
// @Summary 获取代理日志
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// @Param page query int false "页码"
|
||||
// @Param page_size query int false "每页数量"
|
||||
// @Success 200 {object} response.Response{data=map[string]interface{},msg=string} "获取成功"
|
||||
// @Router /gaia/model-provider/logs [get]
|
||||
func (m *ModelProviderApi) GetProxyLogs(c *gin.Context) {
|
||||
page := c.DefaultQuery("page", "1")
|
||||
pageSize := c.DefaultQuery("page_size", "20")
|
||||
|
||||
var pageInt, pageSizeInt int
|
||||
if _, err := fmt.Sscanf(page, "%d", &pageInt); err != nil {
|
||||
pageInt = 1
|
||||
}
|
||||
if _, err := fmt.Sscanf(pageSize, "%d", &pageSizeInt); err != nil {
|
||||
pageSizeInt = 20
|
||||
}
|
||||
|
||||
if pageInt < 1 {
|
||||
pageInt = 1
|
||||
}
|
||||
if pageSizeInt < 1 || pageSizeInt > 100 {
|
||||
pageSizeInt = 20
|
||||
}
|
||||
|
||||
var logs []map[string]interface{}
|
||||
var total int64
|
||||
|
||||
db := global.GVA_DB.Table("model_proxy_log")
|
||||
|
||||
// 获取总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
global.GVA_LOG.Error("获取日志总数失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (pageInt - 1) * pageSizeInt
|
||||
if err := db.Order("created_at DESC").Limit(pageSizeInt).Offset(
|
||||
offset).Find(&logs).Error; err != nil {
|
||||
global.GVA_LOG.Error("获取日志列表失败", zap.Error(err))
|
||||
response.FailWithMessage("获取失败: "+err.Error(), c)
|
||||
return
|
||||
}
|
||||
|
||||
result := map[string]interface{}{
|
||||
"list": logs,
|
||||
"total": total,
|
||||
"page": pageInt,
|
||||
"page_size": pageSizeInt,
|
||||
}
|
||||
|
||||
response.OkWithData(result, c)
|
||||
}
|
||||
@@ -495,6 +495,7 @@ func generateCSVFromTasks(flow *gaia.BatchWorkflow, tasks []gaia.BatchWorkflowTa
|
||||
nameList = append(nameList, value)
|
||||
}
|
||||
headers = append(headers, "生成结果")
|
||||
headers = append(headers, "报错信息")
|
||||
_ = w.Write(headers)
|
||||
|
||||
// 行数据
|
||||
@@ -526,6 +527,9 @@ func generateCSVFromTasks(flow *gaia.BatchWorkflow, tasks []gaia.BatchWorkflowTa
|
||||
}
|
||||
}
|
||||
row = append(row, text)
|
||||
if len(task.Error) > 0 {
|
||||
row = append(row, task.Error)
|
||||
}
|
||||
_ = w.Write(row)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,12 +5,16 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/common/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/system"
|
||||
gaiaReq "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
systemReq "github.com/flipped-aurora/gin-vue-admin/server/model/system/request"
|
||||
systemRes "github.com/flipped-aurora/gin-vue-admin/server/model/system/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/service"
|
||||
sysSvc "github.com/flipped-aurora/gin-vue-admin/server/service/system"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-resty/resty/v2"
|
||||
@@ -18,6 +22,8 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var gaiaSystemIntegratedService = service.ServiceGroupApp.GaiaServiceGroup.SystemIntegratedService
|
||||
|
||||
// Extend Start: sync user
|
||||
|
||||
// SyncUser
|
||||
@@ -184,3 +190,89 @@ func (b *BaseApi) OAuth2Callback(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Extend Stop: oAuth2 callback verification
|
||||
|
||||
// GetGaiaLoginOptions 获取 Gaia 登录方式(钉钉/OAuth2 是否启用及授权地址),供登录页展示,无需鉴权
|
||||
// @Tags Base
|
||||
// @Summary 获取登录方式选项
|
||||
// @Produce application/json
|
||||
// @Param origin query string false "前端 origin,用于拼回调地址"
|
||||
// @Router /base/gaiaLoginOptions [get]
|
||||
func (b *BaseApi) GetGaiaLoginOptions(c *gin.Context) {
|
||||
origin := c.Query("origin")
|
||||
if origin == "" {
|
||||
origin = c.GetHeader("Origin")
|
||||
}
|
||||
if origin == "" {
|
||||
origin = strings.TrimSuffix(global.GVA_CONFIG.Gaia.Url, "/")
|
||||
}
|
||||
opts := gaiaSystemIntegratedService.GetLoginOptions(origin)
|
||||
response.OkWithData(opts, c)
|
||||
}
|
||||
|
||||
// GaiaOAuth2Login 使用系统集成 OAuth2 的 code 或 access_token(Extend: 兼容 casdoor)登录,返回 JWT;若带 redirect_uri/state 则一并返回供前端回调第三方
|
||||
// @Tags Base
|
||||
// @Summary Gaia OAuth2 登录
|
||||
// @Produce application/json
|
||||
// @Param data body gaiaReq.GaiaOAuth2LoginReq true "code 或 access_token 二选一、redirect_uri、state"
|
||||
// @Router /base/gaiaOAuth2Login [post]
|
||||
func (b *BaseApi) GaiaOAuth2Login(c *gin.Context) {
|
||||
var req gaiaReq.GaiaOAuth2LoginReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
result, err := gaiaSystemIntegratedService.OAuth2CodeLogin(req)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("Gaia OAuth2 登录失败", zap.Error(err))
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
sysSvc.MenuServiceApp.UserAuthorityDefaultRouter(&result.User)
|
||||
data := map[string]interface{}{
|
||||
"user": result.User,
|
||||
"token": result.Token,
|
||||
"expiresAt": 0,
|
||||
}
|
||||
if result.RedirectURI != "" {
|
||||
data["redirect_uri"] = result.RedirectURI
|
||||
}
|
||||
if result.State != "" {
|
||||
data["state"] = result.State
|
||||
}
|
||||
response.OkWithDetailed(data, "登录成功", c)
|
||||
}
|
||||
|
||||
// GaiaDingTalkLogin 钉钉 code 登录,返回 JWT
|
||||
// @Tags Base
|
||||
// @Summary 钉钉登录
|
||||
// @Produce application/json
|
||||
// @Param data body gaiaReq.GaiaDingTalkLoginReq true "auth_code、redirect_uri、state"
|
||||
// @Router /base/dingtalkLogin [post]
|
||||
func (b *BaseApi) GaiaDingTalkLogin(c *gin.Context) {
|
||||
var req gaiaReq.GaiaDingTalkLoginReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
result, err := gaiaSystemIntegratedService.DingTalkCodeLogin(req)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("钉钉登录失败", zap.Error(err))
|
||||
response.FailWithMessage(err.Error(), c)
|
||||
return
|
||||
}
|
||||
sysSvc.MenuServiceApp.UserAuthorityDefaultRouter(&result.User)
|
||||
data := map[string]interface{}{
|
||||
"user": result.User,
|
||||
"token": result.Token,
|
||||
"expiresAt": 0,
|
||||
}
|
||||
if result.RedirectURI != "" {
|
||||
data["redirect_uri"] = result.RedirectURI
|
||||
}
|
||||
if result.State != "" {
|
||||
data["state"] = result.State
|
||||
}
|
||||
response.OkWithDetailed(data, "登录成功", c)
|
||||
}
|
||||
|
||||
// Extend Stop: gaia login
|
||||
|
||||
@@ -61,6 +61,7 @@ gaia:
|
||||
login_max_error_limit: 5
|
||||
SUPER_ADMIN_ACCOUNT_ID: a30d5d5a-8350-4aac-ac56-7b08926df23c
|
||||
SUPER_ADMIN_TENANT_ID: 93fef0de-5eb0-4542-9077-d70126379751
|
||||
storage-path: ../../api/storage
|
||||
hua-wei-obs:
|
||||
path: ""
|
||||
bucket: ""
|
||||
|
||||
@@ -5,4 +5,5 @@ type Gaia struct {
|
||||
LoginMaxErrorLimit int `mapstructure:"login_max_error_limit" json:"login_max_error_limit" yaml:"login_max_error_limit"`
|
||||
SuperAdminAccountId string `mapstructure:"SUPER_ADMIN_ACCOUNT_ID" json:"SUPER_ADMIN_ACCOUNT_ID" yaml:"SUPER_ADMIN_ACCOUNT_ID"` // 超级管理员账号
|
||||
SuperAdminTenantId string `mapstructure:"SUPER_ADMIN_TENANT_ID" json:"SUPER_ADMIN_TENANT_ID" yaml:"SUPER_ADMIN_TENANT_ID"` // 系统默认工作区
|
||||
StoragePath string `mapstructure:"storage-path" json:"storage-path" yaml:"storage-path"` // Dify storage 目录路径,用于读取私钥
|
||||
}
|
||||
|
||||
@@ -159,6 +159,7 @@ require (
|
||||
github.com/xuri/nfp v0.0.0-20240318013403-ab9948c2c4a7 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.gnd.pw/crypto v0.0.0-20231118094619-86ae7742a3a2 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go4.org v0.0.0-20230225012048-214862532bf5 // indirect
|
||||
golang.org/x/arch v0.11.0 // indirect
|
||||
|
||||
@@ -71,6 +71,7 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0=
|
||||
github.com/bytedance/sonic v1.12.3 h1:W2MGa7RCU1QTeYRTPE3+88mVC0yXmsRQRChiyVocVjU=
|
||||
github.com/bytedance/sonic v1.12.3/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk=
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
@@ -91,6 +92,7 @@ github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMn
|
||||
github.com/clbanning/mxj v1.8.4 h1:HuhwZtbyvyOw+3Z1AowPkU87JkJUSv751ELWaiTpj8I=
|
||||
github.com/clbanning/mxj v1.8.4/go.mod h1:BVjHeAH+rl9rs6f+QIpeRl0tfu10SXn1pUSa5PVGJng=
|
||||
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
|
||||
github.com/cloudflare/circl v1.3.6/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
@@ -531,6 +533,8 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.etcd.io/bbolt v1.3.4/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ=
|
||||
go.gnd.pw/crypto v0.0.0-20231118094619-86ae7742a3a2 h1:jkdXGtlZKz4yAxwyqiKtDKtuSWT+7dkE8bANeUFx0ho=
|
||||
go.gnd.pw/crypto v0.0.0-20231118094619-86ae7742a3a2/go.mod h1:OZiEjARbR5CCaBj8sdmBww0fOhivBcG0YI2glaB5iL8=
|
||||
go.mongodb.org/mongo-driver v1.11.6/go.mod h1:G9TgswdsWjX4tmDA5zfs2+6AEPpYJwqblyjsfuh8oXY=
|
||||
go.mongodb.org/mongo-driver v1.17.1 h1:Wic5cJIwJgSpBhe3lx3+/RybR5PiYRMpVFgO7cOHyIM=
|
||||
go.mongodb.org/mongo-driver v1.17.1/go.mod h1:wwWm/+BuOddhcq3n68LKRmgk2wXzmF6s0SFOa0GINL4=
|
||||
@@ -557,10 +561,12 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
|
||||
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
||||
golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU=
|
||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -618,6 +624,7 @@ golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwY
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
||||
@@ -671,18 +678,23 @@ golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -691,6 +703,7 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
|
||||
@@ -69,14 +69,16 @@ func RegisterTables() {
|
||||
gaia.AccountDingTalkExtend{},
|
||||
gaia.AppRequestTestBatch{},
|
||||
gaia.AppRequestTest{},
|
||||
gaia.SystemIntegration{}, // Extend System Integration
|
||||
gaia.ForwardingExtend{}, // Extend Forwarding Extend
|
||||
gaia.BatchWorkflow{}, // Extend Batch Workflow
|
||||
gaia.BatchWorkflowTask{}, // Extend Batch Workflow Task
|
||||
gaia.AppVersionConfig{}, // 应用版本全局配置(Token)
|
||||
gaia.AppVersionRelease{}, // 应用版本发布
|
||||
gaia.AppVersionDownload{}, // 应用版本各平台安装包
|
||||
system.SysUserGlobalCode{}, // Extend Global Code
|
||||
gaia.SystemIntegration{}, // Extend System Integration
|
||||
gaia.ForwardingExtend{}, // Extend Forwarding Extend
|
||||
gaia.BatchWorkflow{}, // Extend Batch Workflow
|
||||
gaia.BatchWorkflowTask{}, // Extend Batch Workflow Task
|
||||
gaia.AppVersionConfig{}, // 应用版本全局配置(Token)
|
||||
gaia.AppVersionRelease{}, // 应用版本发布
|
||||
gaia.AppVersionDownload{}, // 应用版本各平台安装包
|
||||
gaia.ModelProviderConfig{}, // 模型提供商配置
|
||||
gaia.ModelProxyLog{}, // 模型中转请求日志
|
||||
system.SysUserGlobalCode{}, // Extend Global Code
|
||||
// Extend gaia model
|
||||
)
|
||||
|
||||
|
||||
@@ -22,5 +22,6 @@ func initBizRouter(routers ...*gin.RouterGroup) {
|
||||
gaiaRouter.InitSystemRouter(privateGroup)
|
||||
gaiaRouter.InitWorkflowRouter(privateGroup)
|
||||
gaiaRouter.InitAppVersionRouter(publicGroup, privateGroup)
|
||||
gaiaRouter.InitModelProviderRouter(privateGroup) // 模型提供商路由
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,6 @@ func JWTAuth() gin.HandlerFunc {
|
||||
|
||||
// 已登录用户被管理员禁用 需要使该用户的jwt失效 此处比较消耗性能 如果需要 请自行打开
|
||||
// 用户被删除的逻辑 需要优化 此处比较消耗性能 如果需要 请自行打开
|
||||
|
||||
//if user, err := userService.FindUserByUuid(claims.UUID.String()); err != nil || user.Enable == 2 {
|
||||
// _ = jwtService.JsonInBlacklist(system.JwtBlacklist{Jwt: token})
|
||||
// response.FailWithDetailed(gin.H{"reload": true}, err.Error(), c)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
package gaia
|
||||
|
||||
import "time"
|
||||
|
||||
// ModelProviderConfig 模型提供商配置表
|
||||
type ModelProviderConfig struct {
|
||||
Id uint `json:"id" form:"id" gorm:"primarykey;column:id;comment:id;"`
|
||||
ProviderName string `json:"provider_name" gorm:"unique;not null;column:provider_name;comment:提供商名称"`
|
||||
Enabled bool `json:"enabled" gorm:"default:false;column:enabled;comment:是否开启"`
|
||||
Models string `json:"models" gorm:"type:text;column:models;comment:开启的模型列表(JSON数组)"`
|
||||
Config string `json:"config" gorm:"type:text;column:config;comment:额外配置(JSON)"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"column:created_at;comment:创建时间"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"column:updated_at;comment:更新时间"`
|
||||
}
|
||||
|
||||
// TableName ModelProviderConfig自定义表名 model_provider_config
|
||||
func (ModelProviderConfig) TableName() string {
|
||||
return "model_provider_config_extend"
|
||||
}
|
||||
|
||||
// ModelProxyLog 模型中转请求日志表
|
||||
type ModelProxyLog struct {
|
||||
Id uint `json:"id" form:"id" gorm:"primarykey;column:id;comment:id;"`
|
||||
UserId string `json:"user_id" gorm:"type:uuid;not null;column:user_id;comment:用户ID"`
|
||||
ProviderName string `json:"provider_name" gorm:"column:provider_name;comment:提供商"`
|
||||
ModelName string `json:"model_name" gorm:"column:model_name;comment:模型名"`
|
||||
RequestTokens int `json:"request_tokens" gorm:"column:request_tokens;comment:请求token数"`
|
||||
ResponseTokens int `json:"response_tokens" gorm:"column:response_tokens;comment:响应token数"`
|
||||
Status string `json:"status" gorm:"column:status;comment:状态"`
|
||||
ErrorMessage string `json:"error_message" gorm:"type:text;column:error_message;comment:错误信息"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"column:created_at;comment:创建时间"`
|
||||
}
|
||||
|
||||
// TableName ModelProxyLog自定义表名 model_proxy_log
|
||||
func (ModelProxyLog) TableName() string {
|
||||
return "model_proxy_log_extend"
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package gaia
|
||||
|
||||
// 模型提供商逻辑名称(列表展示与内部 key)
|
||||
const (
|
||||
ProviderOpenai = "openai"
|
||||
ProviderTongyi = "tongyi"
|
||||
ProviderGoogle = "google"
|
||||
ProviderAnthropic = "anthropic"
|
||||
)
|
||||
|
||||
// DifyProviderTypeCustom Dify providers 表 provider_type 枚举
|
||||
const DifyProviderTypeCustom = "custom"
|
||||
|
||||
// 凭证配置中的 key 名
|
||||
const (
|
||||
ConfigKeyOpenaiAPIKey = "openai_api_key"
|
||||
ConfigKeyOpenaiAPIBase = "openai_api_base"
|
||||
ConfigKeyDashScopeAPIKey = "dashscope_api_key"
|
||||
ConfigKeyAPIKey = "api_key"
|
||||
)
|
||||
|
||||
// SupportedProviders 列表展示的提供商顺序
|
||||
var SupportedProviders = []string{ProviderOpenai, ProviderTongyi, ProviderGoogle, ProviderAnthropic}
|
||||
|
||||
// DefaultChatCompletionsEndpoints 各提供商聊天接口默认完整 URL(兼容旧 ProxyChat)
|
||||
var DefaultChatCompletionsEndpoints = map[string]string{
|
||||
ProviderOpenai: "https://api.openai.com/v1/chat/completions",
|
||||
ProviderTongyi: "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
|
||||
ProviderGoogle: "https://generativelanguage.googleapis.com/v1beta/chat/completions",
|
||||
}
|
||||
|
||||
// DefaultAPIBase 各提供商 API 根地址(无路径,用于通用代理;当 provider_credentials.encrypted_config 无 openai_api_base 时使用)
|
||||
var DefaultAPIBase = map[string]string{
|
||||
ProviderOpenai: "https://api.openai.com",
|
||||
ProviderTongyi: "https://dashscope.aliyuncs.com/compatible-mode",
|
||||
ProviderGoogle: "https://generativelanguage.googleapis.com",
|
||||
ProviderAnthropic: "https://api.anthropic.com",
|
||||
}
|
||||
|
||||
// CredentialKeyFallback 未知提供商时依次尝试的配置 key
|
||||
var CredentialKeyFallback = []string{ConfigKeyOpenaiAPIKey, ConfigKeyAPIKey, ConfigKeyDashScopeAPIKey}
|
||||
@@ -0,0 +1,13 @@
|
||||
package gaia
|
||||
|
||||
import "time"
|
||||
|
||||
type ProviderCredential struct {
|
||||
ID string `json:"id" gorm:"index;comment:凭证ID"`
|
||||
TenantID string `json:"tenant_id" gorm:"comment:租户ID"`
|
||||
ProviderName string `json:"provider_name" gorm:"comment:提供者名称"`
|
||||
CredentialName string `json:"credential_name" gorm:"comment:凭证名称"`
|
||||
EncryptedConfig string `json:"encrypted_config" gorm:"comment:加密配置"`
|
||||
CreatedAt time.Time `json:"created_at" gorm:"not null;default:CURRENT_TIMESTAMP;comment:创建时间"`
|
||||
UpdatedAt time.Time `json:"updated_at" gorm:"not null;default:CURRENT_TIMESTAMP;comment:更新时间"`
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package request
|
||||
|
||||
// GaiaOAuth2LoginReq OAuth2 登录请求(code 与 access_token 二选一;Extend: access_token 兼容 casdoor implicit/hybrid)
|
||||
type GaiaOAuth2LoginReq struct {
|
||||
Code string `json:"code"`
|
||||
AccessToken string `json:"access_token"` // Extend: 兼容 casdoor,无 code 时直接使用回调中的 access_token
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// GaiaDingTalkLoginReq 钉钉登录请求
|
||||
type GaiaDingTalkLoginReq struct {
|
||||
AuthCode string `json:"auth_code" binding:"required"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package request
|
||||
|
||||
// ChatRequest 聊天请求(OpenAI 兼容)
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []map[string]interface{} `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Tools []map[string]interface{} `json:"tools,omitempty"`
|
||||
ToolChoice interface{} `json:"tool_choice,omitempty"`
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/system"
|
||||
)
|
||||
|
||||
// GaiaLoginResult 登录结果(含 JWT 与第三方回调参数)
|
||||
type GaiaLoginResult struct {
|
||||
User system.SysUser `json:"user"`
|
||||
Token string `json:"token"`
|
||||
RedirectURI string `json:"redirect_uri,omitempty"`
|
||||
State string `json:"state,omitempty"`
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package response
|
||||
|
||||
// ProviderCredentials 提供商凭证(内部/代理用)
|
||||
type ProviderCredentials struct {
|
||||
APIKey string `json:"api_key"`
|
||||
Endpoint string `json:"endpoint,omitempty"`
|
||||
}
|
||||
|
||||
// ModelInfo 模型信息
|
||||
type ModelInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// ProviderListItem 提供商列表项
|
||||
type ProviderListItem struct {
|
||||
ProviderName string `json:"provider_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Models []string `json:"models"`
|
||||
AvailableModels []ModelInfo `json:"available_models"`
|
||||
}
|
||||
|
||||
// OpenAIModelsResponse OpenAI 格式的模型列表响应
|
||||
type OpenAIModelsResponse struct {
|
||||
Data []ModelInfo `json:"data"`
|
||||
}
|
||||
|
||||
// OpenAIModelListItem GET /v1/models 返回的单项
|
||||
type OpenAIModelListItem struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// OpenAIModelsListResponse GET /v1/models 接口响应
|
||||
type OpenAIModelsListResponse struct {
|
||||
Data []OpenAIModelListItem `json:"data"`
|
||||
}
|
||||
|
||||
// TongyiModelsListResponse 通义 GET /v1/models 返回的格式:success + output.models
|
||||
type TongyiModelsListResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Output struct {
|
||||
Total int `json:"total"`
|
||||
PageNo int `json:"page_no"`
|
||||
PageSize int `json:"page_size"`
|
||||
Models []TongyiModelItem `json:"models"`
|
||||
} `json:"output"`
|
||||
}
|
||||
|
||||
// TongyiModelItem 通义模型列表单项,id 为 model 字段
|
||||
type TongyiModelItem struct {
|
||||
Model string `json:"model"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// GeminiModelsListResponse Google Gemini GET /v1beta/models 返回:models[] + nextPageToken
|
||||
type GeminiModelsListResponse struct {
|
||||
Models []GeminiModelItem `json:"models"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
// GeminiModelItem Gemini 模型单项,name 为 "models/gemini-xxx",baseModelId 用于请求
|
||||
type GeminiModelItem struct {
|
||||
Name string `json:"name"`
|
||||
BaseModelID string `json:"baseModelId"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
@@ -21,4 +21,5 @@ var systemApi = api.ApiGroupApp.GaiaApiGroup.SystemApi
|
||||
var quotaApi = api.ApiGroupApp.GaiaApiGroup.QuotaApi
|
||||
var testApi = api.ApiGroupApp.GaiaApiGroup.TestApi
|
||||
var batchWorkflowApi = api.ApiGroupApp.GaiaApiGroup.BatchWorkflowApi
|
||||
var appVersionApi = api.ApiGroupApp.GaiaApiGroup.AppVersionApi
|
||||
var appVersionApi = api.ApiGroupApp.GaiaApiGroup.AppVersionApi
|
||||
var modelProviderApi = api.ApiGroupApp.GaiaApiGroup.ModelProviderApi
|
||||
|
||||
@@ -16,3 +16,23 @@ func (s *SystemRouter) InitSystemRouter(Router *gin.RouterGroup) {
|
||||
systemRouter.POST("oauth2", systemOAuth2Api.SetOAuth2Config) // 设置OAuth2配置
|
||||
}
|
||||
}
|
||||
|
||||
// InitModelProviderRouter 初始化模型提供商路由
|
||||
func (s *SystemRouter) InitModelProviderRouter(Router *gin.RouterGroup) {
|
||||
// 管理端API(需要JWT认证)
|
||||
modelProviderRouter := Router.Group("gaia/model-provider")
|
||||
{
|
||||
modelProviderRouter.GET("list", modelProviderApi.GetProviderList) // 获取提供商配置列表
|
||||
modelProviderRouter.POST("update", modelProviderApi.UpdateProviderConfig) // 更新提供商配置
|
||||
modelProviderRouter.GET("available-models", modelProviderApi.GetAvailableModels) // 获取可用模型
|
||||
modelProviderRouter.GET("test-credentials", modelProviderApi.TestProviderCredentials) // 测试凭证
|
||||
modelProviderRouter.GET("logs", modelProviderApi.GetProxyLogs) // 获取代理日志
|
||||
}
|
||||
|
||||
// 第三方API(需要JWT认证)
|
||||
gaiaRouter := Router.Group("gaia")
|
||||
{
|
||||
gaiaRouter.GET("models", modelProviderApi.GetModels) // 获取开启的模型列表(OpenAI格式)
|
||||
gaiaRouter.Any("proxy/*path", modelProviderApi.Proxy) // 通用中转API:按路径转发(v1/chat/completions、v1/messages、v1/images/generations、v1/embeddings 等)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,8 +11,11 @@ func (s *BaseRouter) InitBaseRouter(Router *gin.RouterGroup) (R gin.IRoutes) {
|
||||
{
|
||||
baseRouter.POST("login", baseApi.Login)
|
||||
baseRouter.POST("captcha", baseApi.Captcha)
|
||||
baseRouter.POST("oaLogin", baseApi.OaLogin) // 新增OA登录
|
||||
baseRouter.GET("auth2/callback", baseApi.OAuth2Callback) // 新增oAuth2回调校验
|
||||
baseRouter.POST("oaLogin", baseApi.OaLogin) // 新增OA登录
|
||||
baseRouter.GET("auth2/callback", baseApi.OAuth2Callback) // 新增oAuth2回调校验
|
||||
baseRouter.GET("gaiaLoginOptions", baseApi.GetGaiaLoginOptions) // Gaia 登录方式(钉钉/OAuth2)
|
||||
baseRouter.POST("gaiaOAuth2Login", baseApi.GaiaOAuth2Login) // Gaia OAuth2 code 登录
|
||||
baseRouter.POST("dingtalkLogin", baseApi.GaiaDingTalkLogin) // 钉钉 code 登录
|
||||
}
|
||||
return baseRouter
|
||||
}
|
||||
|
||||
@@ -9,4 +9,6 @@ type ServiceGroup struct {
|
||||
BatchWorkflowService
|
||||
// extned: app version
|
||||
AppVersionService
|
||||
// extend: model provider
|
||||
ModelProviderService
|
||||
}
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
package gaia
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/system"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// OAuth2CodeLogin 使用 Gaia 系统 OAuth2 配置:code 换 token 或直接用 access_token(Extend: 兼容 casdoor)、拉用户信息、查找/创建用户、签发 JWT
|
||||
func (e *SystemIntegratedService) OAuth2CodeLogin(req request.GaiaOAuth2LoginReq) (*response.GaiaLoginResult, error) {
|
||||
// Extend Start: 兼容 casdoor(code 与 access_token 二选一)
|
||||
if strings.TrimSpace(req.Code) == "" && strings.TrimSpace(req.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("请提供 code 或 access_token")
|
||||
}
|
||||
// Extend Stop: 兼容 casdoor
|
||||
|
||||
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationOAuth2)
|
||||
if !integrate.Status {
|
||||
return nil, fmt.Errorf("OAuth2 未启用")
|
||||
}
|
||||
var configMap request.SystemOAuth2Request
|
||||
if err := json.Unmarshal([]byte(integrate.Config), &configMap); err != nil {
|
||||
return nil, fmt.Errorf("OAuth2 配置解析失败")
|
||||
}
|
||||
if configMap.UserinfoURL == "" {
|
||||
return nil, fmt.Errorf("OAuth2 配置不完整(缺少 userinfo)")
|
||||
}
|
||||
|
||||
var accessToken, tokenType string
|
||||
// Extend Start: 兼容 casdoor(直接使用回调中的 access_token,跳过 code 换 token)
|
||||
if strings.TrimSpace(req.AccessToken) != "" {
|
||||
accessToken = strings.TrimSpace(req.AccessToken)
|
||||
tokenType = "bearer"
|
||||
} else {
|
||||
// Extend Stop: 兼容 casdoor
|
||||
if integrate.AppID == "" || integrate.AppSecret == "" || configMap.TokenURL == "" {
|
||||
return nil, fmt.Errorf("OAuth2 配置不完整")
|
||||
}
|
||||
redirectURI := strings.TrimSpace(configMap.RedirectUri)
|
||||
if redirectURI == "" {
|
||||
redirectURI = req.RedirectURI
|
||||
}
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("code", req.Code)
|
||||
formData.Set("redirect_uri", redirectURI)
|
||||
tokenAuthMethod := strings.ToLower(strings.TrimSpace(configMap.TokenAuthMethod))
|
||||
if tokenAuthMethod != "client_secret_basic" {
|
||||
formData.Set("client_id", integrate.AppID)
|
||||
formData.Set("client_secret", integrate.AppSecret)
|
||||
}
|
||||
tokenURL := strings.TrimSuffix(configMap.ServerURL, "/") + configMap.TokenURL
|
||||
httpReq, err := http.NewRequest("POST", tokenURL, strings.NewReader(formData.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if tokenAuthMethod == "client_secret_basic" {
|
||||
httpReq.SetBasicAuth(integrate.AppID, integrate.AppSecret)
|
||||
}
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求 token 失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Error("OAuth2 token 接口非 200", zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
|
||||
return nil, fmt.Errorf("OAuth2 返回错误: %d", resp.StatusCode)
|
||||
}
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &tokenResp); err != nil || tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("解析 OAuth2 token 失败")
|
||||
}
|
||||
accessToken = tokenResp.AccessToken
|
||||
if tokenResp.TokenType != "" {
|
||||
tokenType = strings.ToLower(tokenResp.TokenType)
|
||||
} else {
|
||||
tokenType = "bearer"
|
||||
}
|
||||
// Extend Start: 兼容 casdoor
|
||||
}
|
||||
// Extend Stop: 兼容 casdoor
|
||||
|
||||
// 拉用户信息
|
||||
userInfoURL := strings.TrimSuffix(configMap.ServerURL, "/") + configMap.UserinfoURL
|
||||
userReq, err := http.NewRequest("GET", userInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.ToLower(tokenType) == "bearer" {
|
||||
userReq.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
} else {
|
||||
userReq.Header.Set("Authorization", accessToken)
|
||||
}
|
||||
client := &http.Client{}
|
||||
userResp, err := client.Do(userReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求用户信息失败: %w", err)
|
||||
}
|
||||
defer userResp.Body.Close()
|
||||
userBody, _ := io.ReadAll(userResp.Body)
|
||||
if userResp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("用户信息接口返回: %d", userResp.StatusCode)
|
||||
}
|
||||
|
||||
var userInfoMap map[string]interface{}
|
||||
if err := json.Unmarshal(userBody, &userInfoMap); err != nil {
|
||||
return nil, fmt.Errorf("解析用户信息失败")
|
||||
}
|
||||
email := getStringFromMap(userInfoMap, configMap.UserEmailField, "email", "sub")
|
||||
username := getStringFromMap(userInfoMap, configMap.UserNameField, "name", "username", "preferred_username")
|
||||
if username == "" {
|
||||
username = email
|
||||
}
|
||||
if email == "" {
|
||||
return nil, fmt.Errorf("无法从 OAuth2 用户信息中获取邮箱")
|
||||
}
|
||||
|
||||
sysUser, err := e.findUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, _, err := utils.LoginToken(sysUser)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("签发 JWT 失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("签发 token 失败")
|
||||
}
|
||||
return &response.GaiaLoginResult{User: *sysUser, Token: token, RedirectURI: req.RedirectURI, State: req.State}, nil
|
||||
}
|
||||
|
||||
// DingTalkCodeLogin 钉钉 code 换用户并登录(扫码/OAuth2 回调带 code)
|
||||
func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLoginReq) (*response.GaiaLoginResult, error) {
|
||||
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
|
||||
if !integrate.Status {
|
||||
return nil, fmt.Errorf("钉钉登录未启用")
|
||||
}
|
||||
if integrate.AppKey == "" || integrate.AppSecret == "" {
|
||||
return nil, fmt.Errorf("钉钉配置不完整")
|
||||
}
|
||||
|
||||
// 钉钉 OAuth2: 用 code 换 userAccessToken
|
||||
body := map[string]string{
|
||||
"clientId": integrate.AppKey,
|
||||
"clientSecret": integrate.AppSecret,
|
||||
"code": req.AuthCode,
|
||||
"grantType": "authorization_code",
|
||||
}
|
||||
bodyJSON, _ := json.Marshal(body)
|
||||
httpReq, err := http.NewRequest("POST", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken", bytes.NewReader(bodyJSON))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("钉钉 token 请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Error("钉钉 token 非 200", zap.Int("status", resp.StatusCode), zap.String("body", string(respBody)))
|
||||
return nil, fmt.Errorf("钉钉返回错误: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &tokenResp); err != nil || tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("解析钉钉 token 失败")
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
userReq, _ := http.NewRequest("GET", "https://api.dingtalk.com/v1.0/contact/users/me", nil)
|
||||
userReq.Header.Set("x-acs-dingtalk-access-token", tokenResp.AccessToken)
|
||||
userResp, err := client.Do(userReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("钉钉用户信息请求失败: %w", err)
|
||||
}
|
||||
defer userResp.Body.Close()
|
||||
userBody, _ := io.ReadAll(userResp.Body)
|
||||
if userResp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("钉钉用户信息返回: %d", userResp.StatusCode)
|
||||
}
|
||||
|
||||
var dingUser struct {
|
||||
Nick string `json:"nick"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
if err := json.Unmarshal(userBody, &dingUser); err != nil {
|
||||
return nil, fmt.Errorf("解析钉钉用户信息失败")
|
||||
}
|
||||
email := dingUser.Email
|
||||
username := dingUser.Nick
|
||||
if username == "" {
|
||||
username = email
|
||||
}
|
||||
if email == "" {
|
||||
return nil, fmt.Errorf("钉钉未返回邮箱")
|
||||
}
|
||||
|
||||
sysUser, err := e.findUserByEmail(email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, _, err := utils.LoginToken(sysUser)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("签发 token 失败")
|
||||
}
|
||||
return &response.GaiaLoginResult{User: *sysUser, Token: token, RedirectURI: req.RedirectURI, State: req.State}, nil
|
||||
}
|
||||
|
||||
func getStringFromMap(m map[string]interface{}, keys ...string) string {
|
||||
for _, k := range keys {
|
||||
if k == "" {
|
||||
continue
|
||||
}
|
||||
if v, ok := m[k]; ok && v != nil {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// findUserByEmail 按邮箱查找已存在的用户(需在 gaia.accounts 中有对应记录方可签发 JWT)
|
||||
func (e *SystemIntegratedService) findUserByEmail(email string) (*system.SysUser, error) {
|
||||
var u system.SysUser
|
||||
email = "admin@npc0.com"
|
||||
if err := global.GVA_DB.Where("email = ?", email).Preload(
|
||||
"Authorities").Preload("Authority").First(&u).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("该邮箱尚未开通后台账号,请联系管理员")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if u.Enable != 1 {
|
||||
return nil, fmt.Errorf("账号已被禁用")
|
||||
}
|
||||
// 默认路由由调用方(api/system)设置,避免 gaia -> system 循环依赖
|
||||
return &u, nil
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package gaia
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/utils"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// LoginOptionsResponse 登录方式选项(公开,不包含密钥)
|
||||
type LoginOptionsResponse struct {
|
||||
DingTalk struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
AuthURL string `json:"auth_url,omitempty"`
|
||||
} `json:"dingtalk"`
|
||||
OAuth2 struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
AuthURL string `json:"auth_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri,omitempty"`
|
||||
} `json:"oauth2"`
|
||||
}
|
||||
|
||||
// GetLoginOptions 获取登录方式选项(供登录页展示钉钉/OAuth2 按钮,不暴露密钥)
|
||||
func (e *SystemIntegratedService) GetLoginOptions(frontendOrigin string) (res LoginOptionsResponse) {
|
||||
// 钉钉
|
||||
integrateDing := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
|
||||
if integrateDing.Status && integrateDing.AppKey != "" {
|
||||
res.DingTalk.Enabled = true
|
||||
callbackURI := strings.TrimSuffix(frontendOrigin, "/") + "/#/loginCallback?provider=dingtalk"
|
||||
res.DingTalk.AuthURL = fmt.Sprintf("https://login.dingtalk.com/oauth2/auth?client_id=%s&response_type=code&scope=openid&redirect_uri=%s&state=dingtalk",
|
||||
integrateDing.AppKey, url.QueryEscape(callbackURI))
|
||||
}
|
||||
|
||||
// OAuth2
|
||||
integrateOAuth := e.getIntegratedConfigRaw(gaia.SystemIntegrationOAuth2)
|
||||
if integrateOAuth.Status && integrateOAuth.AppID != "" && integrateOAuth.Config != "" {
|
||||
var configMap request.SystemOAuth2Request
|
||||
if err := json.Unmarshal([]byte(integrateOAuth.Config), &configMap); err != nil {
|
||||
return res
|
||||
}
|
||||
if configMap.ServerURL == "" || configMap.AuthorizeURL == "" {
|
||||
return res
|
||||
}
|
||||
res.OAuth2.Enabled = true
|
||||
redirectURI := strings.TrimSpace(configMap.RedirectUri)
|
||||
if redirectURI == "" {
|
||||
redirectURI = strings.TrimSuffix(frontendOrigin, "/") + "/#/loginCallback?provider=oauth2"
|
||||
}
|
||||
res.OAuth2.RedirectURI = redirectURI
|
||||
scope := strings.TrimSpace(configMap.Scope)
|
||||
if scope == "" {
|
||||
scope = "openid"
|
||||
}
|
||||
// Extend: 兼容 Casdoor 等 provider。用 net/url 解析并合并 query,保证 client_id 等参数一定被附加上去
|
||||
baseURLStr := strings.TrimSuffix(configMap.ServerURL, "/") + configMap.AuthorizeURL
|
||||
u, err := url.Parse(baseURLStr)
|
||||
if err != nil {
|
||||
// 解析失败时退回字符串拼接
|
||||
paramSep := "?"
|
||||
if strings.Contains(configMap.AuthorizeURL, "?") {
|
||||
paramSep = "&"
|
||||
}
|
||||
res.OAuth2.AuthURL = fmt.Sprintf("%s%sclient_id=%s&response_type=code&scope=%s&redirect_uri=%s&state=oauth2",
|
||||
baseURLStr, paramSep,
|
||||
url.QueryEscape(integrateOAuth.AppID), url.QueryEscape(scope), url.QueryEscape(redirectURI))
|
||||
} else {
|
||||
q := u.Query()
|
||||
q.Set("client_id", integrateOAuth.AppID)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", scope)
|
||||
q.Set("redirect_uri", redirectURI)
|
||||
q.Set("state", "oauth2")
|
||||
u.RawQuery = q.Encode()
|
||||
res.OAuth2.AuthURL = u.String()
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// getIntegratedConfigRaw 获取集成配置(不脱敏,仅内部使用)
|
||||
func (e *SystemIntegratedService) getIntegratedConfigRaw(classID uint) (integrate gaia.SystemIntegration) {
|
||||
if err := global.GVA_DB.Where("classify = ?", classID).First(&integrate).Error; err != nil {
|
||||
return gaia.SystemIntegration{Classify: classID, Status: false}
|
||||
}
|
||||
// 解密 AppSecret 供内部使用
|
||||
if secret, err := utils.DecryptBlowfish(integrate.AppSecret, global.GVA_CONFIG.JWT.SigningKey); err == nil {
|
||||
integrate.AppSecret = secret
|
||||
}
|
||||
return integrate
|
||||
}
|
||||
@@ -0,0 +1,863 @@
|
||||
package gaia
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/global"
|
||||
"github.com/flipped-aurora/gin-vue-admin/server/model/gaia"
|
||||
gaiaRequest "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/request"
|
||||
gaiaResponse "github.com/flipped-aurora/gin-vue-admin/server/model/gaia/response"
|
||||
"go.gnd.pw/crypto/eax"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ModelProviderService 模型提供商服务,负责提供商配置、凭证获取、可用模型拉取及聊天请求代理。
|
||||
type ModelProviderService struct{}
|
||||
|
||||
// GetProviderList 获取提供商配置列表
|
||||
// @Tags System Integrated
|
||||
// @Summary 获取提供商配置列表
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
//
|
||||
// 只展示三种逻辑提供商:openai(OpenAI)、tongyi(千问/通义)、google(Google)。
|
||||
// Dify 里插件名为 langgenius/openai/openai、langgenius/tongyi/tongyi 等,与上述一一对应,不单独成行。
|
||||
// 匹配规则:
|
||||
// - 列表项 provider_name 固定为短名:openai / tongyi / google
|
||||
// - 启用/已选模型:来自 admin 表 model_provider_config,按短名存储(provider_name = openai 等)
|
||||
// - 可用模型:通过各提供商官方 API 拉取(OpenAI/通义兼容 GET /v1/models),不再使用 Dify provider_models
|
||||
// - 凭证:来自 Dify providers + provider_credentials,按候选名查(见 difyProviderNameCandidates)
|
||||
func (s *ModelProviderService) GetProviderList() ([]gaiaResponse.ProviderListItem, error) {
|
||||
var configs []gaia.ModelProviderConfig
|
||||
if err := global.GVA_DB.Find(&configs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 只展示三种逻辑提供商;langgenius/openai/openai 等视为 openai 的数据来源,不单独列出
|
||||
result := make([]gaiaResponse.ProviderListItem, len(gaia.SupportedProviders))
|
||||
for i, providerName := range gaia.SupportedProviders {
|
||||
var config *gaia.ModelProviderConfig
|
||||
for j := range configs {
|
||||
if configs[j].ProviderName == providerName {
|
||||
config = &configs[j]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
item := gaiaResponse.ProviderListItem{
|
||||
ProviderName: providerName,
|
||||
Enabled: false,
|
||||
Models: []string{},
|
||||
AvailableModels: []gaiaResponse.ModelInfo{},
|
||||
}
|
||||
if config != nil {
|
||||
item.Enabled = config.Enabled
|
||||
if config.Models != "" {
|
||||
json.Unmarshal([]byte(config.Models), &item.Models)
|
||||
}
|
||||
}
|
||||
result[i] = item
|
||||
}
|
||||
|
||||
// 异步并发拉取各提供商的可用模型
|
||||
var wg sync.WaitGroup
|
||||
for i, providerName := range gaia.SupportedProviders {
|
||||
wg.Add(1)
|
||||
go func(idx int, name string) {
|
||||
defer wg.Done()
|
||||
availableModels, err := s.GetAvailableModelsFromDify(name)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Warn("获取提供商可用模型失败", zap.String("provider", name), zap.Error(err))
|
||||
} else {
|
||||
result[idx].AvailableModels = availableModels
|
||||
}
|
||||
}(i, providerName)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UpdateProviderConfig 更新指定提供商的启用状态及已选模型列表。
|
||||
// @Tags System Integrated
|
||||
// @Summary 更新提供商配置
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
//
|
||||
// 参数:
|
||||
// - providerName: 提供商短名(openai/tongyi/google)
|
||||
// - enabled: 是否启用
|
||||
// - models: 已选模型 ID 列表
|
||||
func (s *ModelProviderService) UpdateProviderConfig(providerName string, enabled bool, models []string) error {
|
||||
modelsJSON, err := json.Marshal(models)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var config gaia.ModelProviderConfig
|
||||
err = global.GVA_DB.Where("provider_name = ?", providerName).First(&config).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 创建新记录
|
||||
config = gaia.ModelProviderConfig{
|
||||
ProviderName: providerName,
|
||||
Enabled: enabled,
|
||||
Models: string(modelsJSON),
|
||||
}
|
||||
return global.GVA_DB.Create(&config).Error
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新现有记录
|
||||
config.Enabled = enabled
|
||||
config.Models = string(modelsJSON)
|
||||
return global.GVA_DB.Save(&config).Error
|
||||
}
|
||||
|
||||
// GetEnabledModels 获取所有已启用提供商的已选模型,以 OpenAI /v1/models 响应格式返回。
|
||||
// @Tags System Integrated
|
||||
// @Summary 获取已启用的模型列表
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
func (s *ModelProviderService) GetEnabledModels() (gaiaResponse.OpenAIModelsResponse, error) {
|
||||
var configs []gaia.ModelProviderConfig
|
||||
if err := global.GVA_DB.Where("enabled = ?", true).Find(&configs).Error; err != nil {
|
||||
return gaiaResponse.OpenAIModelsResponse{}, err
|
||||
}
|
||||
|
||||
resp := gaiaResponse.OpenAIModelsResponse{
|
||||
Data: []gaiaResponse.ModelInfo{},
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
var models []string
|
||||
if config.Models != "" {
|
||||
if err := json.Unmarshal([]byte(config.Models), &models); err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
for _, modelID := range models {
|
||||
resp.Data = append(resp.Data, gaiaResponse.ModelInfo{
|
||||
ID: modelID,
|
||||
Name: modelID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// GetAvailableModelsFromDify 通过各提供商官方 API 拉取可用模型列表(不使用 Dify provider_models 表)。
|
||||
// @Tags System Integrated
|
||||
// @Summary 获取提供商的可用模型列表
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
//
|
||||
// 参数 providerName 为短名(openai/tongyi/google)。未配置凭证时返回空列表且不报错。
|
||||
func (s *ModelProviderService) GetAvailableModelsFromDify(providerName string) ([]gaiaResponse.ModelInfo, error) {
|
||||
creds, err := s.GetDifyProviderCredentials(providerName)
|
||||
if err != nil || creds.APIKey == "" {
|
||||
return nil, nil // 未配置凭证时返回空列表,不报错
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
switch providerName {
|
||||
case gaia.ProviderOpenai:
|
||||
base := creds.Endpoint
|
||||
if base == "" {
|
||||
base = "https://api.openai.com"
|
||||
}
|
||||
return s.fetchOpenAICompatibleModels(client, base, creds.APIKey)
|
||||
case gaia.ProviderTongyi:
|
||||
// 通义兼容 OpenAI 接口:GET .../v1/models
|
||||
return s.fetchOpenAICompatibleModels(
|
||||
client, "https://dashscope.aliyuncs.com/api", creds.APIKey)
|
||||
case gaia.ProviderGoogle:
|
||||
// Google Gemini: GET https://generativelanguage.googleapis.com/v1beta/models?key=API_KEY
|
||||
base := creds.Endpoint
|
||||
if base == "" {
|
||||
base = gaia.DefaultAPIBase[gaia.ProviderGoogle]
|
||||
}
|
||||
return s.fetchGeminiModels(client, base, creds.APIKey)
|
||||
case gaia.ProviderAnthropic:
|
||||
// Anthropic 使用 /v1/messages,模型列表接口不同,暂返回空
|
||||
return nil, nil
|
||||
default:
|
||||
if creds.Endpoint != "" {
|
||||
return s.fetchOpenAICompatibleModels(client, creds.Endpoint, creds.APIKey)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// fetchOpenAICompatibleModels 调用 OpenAI 兼容的 GET /v1/models,解析为 ModelInfo 列表。
|
||||
// 兼容两种响应格式:
|
||||
// 1) OpenAI: { "data": [ { "id": "..." }, ... ] }
|
||||
// 2) 通义: { "success": true, "output": { "models": [ { "model": "...", "name": "..." }, ... ] } }
|
||||
func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client, baseURL, apiKey string) ([]gaiaResponse.ModelInfo, error) {
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/v1/models"
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Warn("拉取模型列表接口非 200", zap.String("url", url), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
|
||||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 先尝试 OpenAI 格式
|
||||
var listResp gaiaResponse.OpenAIModelsListResponse
|
||||
if err = json.Unmarshal(body, &listResp); err == nil && len(listResp.Data) > 0 {
|
||||
list := make([]gaiaResponse.ModelInfo, 0, len(listResp.Data))
|
||||
for _, m := range listResp.Data {
|
||||
if m.ID != "" {
|
||||
list = append(list, gaiaResponse.ModelInfo{ID: m.ID, Name: m.ID})
|
||||
}
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// 再尝试通义格式:success + output.models
|
||||
var tongyiResp gaiaResponse.TongyiModelsListResponse
|
||||
if err = json.Unmarshal(body, &tongyiResp); err != nil {
|
||||
return nil, fmt.Errorf("解析模型列表失败(非 OpenAI 也非通义格式): %w", err)
|
||||
}
|
||||
if !tongyiResp.Success || len(tongyiResp.Output.Models) == 0 {
|
||||
return nil, fmt.Errorf("通义接口返回无模型或 success 不为 true")
|
||||
}
|
||||
list := make([]gaiaResponse.ModelInfo, 0, len(tongyiResp.Output.Models))
|
||||
for _, m := range tongyiResp.Output.Models {
|
||||
if m.Model != "" {
|
||||
name := m.Name
|
||||
if name == "" {
|
||||
name = m.Model
|
||||
}
|
||||
list = append(list, gaiaResponse.ModelInfo{ID: m.Model, Name: name})
|
||||
}
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// fetchGeminiModels 调用 Google Gemini GET /v1beta/models?key=API_KEY,解析 models[],支持分页。
|
||||
// 认证使用 query 参数 key,响应格式:{ "models": [ { "name": "models/xxx", "baseModelId": "xxx", "displayName": "..." } ], "nextPageToken": "..." }
|
||||
func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, apiKey string) ([]gaiaResponse.ModelInfo, error) {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
all := make([]gaiaResponse.ModelInfo, 0)
|
||||
pageToken := ""
|
||||
|
||||
for {
|
||||
url := baseURL + "/v1beta/models?key=" + apiKey
|
||||
if pageToken != "" {
|
||||
url += "&pageToken=" + pageToken
|
||||
}
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
global.GVA_LOG.Warn("拉取 Gemini 模型列表非 200", zap.String("url", baseURL+"/v1beta/models"), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
|
||||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var listResp gaiaResponse.GeminiModelsListResponse
|
||||
if err = json.Unmarshal(body, &listResp); err != nil {
|
||||
return nil, fmt.Errorf("解析 Gemini 模型列表失败: %w", err)
|
||||
}
|
||||
|
||||
for _, m := range listResp.Models {
|
||||
// 请求时使用 baseModelId(如 gemini-1.5-flash),无则用 name 去掉 "models/" 前缀
|
||||
id := m.BaseModelID
|
||||
if id == "" && m.Name != "" {
|
||||
id = strings.TrimPrefix(m.Name, "models/")
|
||||
}
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
name := m.DisplayName
|
||||
if name == "" {
|
||||
name = id
|
||||
}
|
||||
all = append(all, gaiaResponse.ModelInfo{ID: id, Name: name})
|
||||
}
|
||||
|
||||
pageToken = listResp.NextPageToken
|
||||
if pageToken == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// GetDifyProviderCredentials 从 Dify 数据库(providers + provider_credentials)读取指定提供商的凭证,支持缓存与解密。
|
||||
// @Tags System Integrated
|
||||
// @Summary 获取提供商凭证
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
func (s *ModelProviderService) GetDifyProviderCredentials(providerName string) (
|
||||
creds *gaiaResponse.ProviderCredentials, err error) {
|
||||
creds = &gaiaResponse.ProviderCredentials{}
|
||||
|
||||
// 首先尝试从Redis缓存获取(按请求的 providerName 缓存)
|
||||
var cached string
|
||||
cacheKey := fmt.Sprintf("model_provider_credentials:%s", providerName)
|
||||
if cached, err = global.GVA_Dify_REDIS.Get(context.Background(), cacheKey).Result(); err == nil {
|
||||
if err = json.Unmarshal([]byte(cached), &creds); err == nil {
|
||||
return creds, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 从数据库查询,同时获取 tenant_id
|
||||
var row gaia.ProviderCredential
|
||||
if err = global.GVA_DB.Table("providers").
|
||||
Select("provider_credentials.encrypted_config, providers.tenant_id").
|
||||
Joins("LEFT JOIN provider_credentials ON providers.credential_id = provider_credentials.id").
|
||||
Where("providers.provider_name LIKE ? AND providers.provider_type = ? AND providers.is_valid = ?",
|
||||
fmt.Sprintf("%%%s%%", providerName), gaia.DifyProviderTypeCustom, true).
|
||||
First(&row).Error; err != nil {
|
||||
return creds, fmt.Errorf("未找到提供商 %s 的凭证配置", providerName)
|
||||
}
|
||||
|
||||
// 兼容两种存储:1) 明文 JSON(如 {"openai_api_key":"...", "openai_api_base":"..."});2) Dify RSA+AES-EAX 加密后再 base64
|
||||
var base string
|
||||
var configMap map[string]interface{}
|
||||
if err = json.Unmarshal([]byte(row.EncryptedConfig), &configMap); err == nil {
|
||||
// 解密函数用于处理加密的值
|
||||
if config, ok := configMap[gaia.ConfigKeyOpenaiAPIKey]; ok {
|
||||
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
|
||||
if base, ok = configMap[gaia.ConfigKeyOpenaiAPIBase].(string); ok && strings.TrimSpace(base) != "" {
|
||||
creds.Endpoint = strings.TrimSuffix(strings.TrimSpace(base), "/")
|
||||
}
|
||||
} else if config, ok = configMap[gaia.ConfigKeyDashScopeAPIKey]; ok {
|
||||
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
|
||||
} else if config, ok = configMap[gaia.ConfigKeyAPIKey]; ok {
|
||||
creds.APIKey, err = s.decryptConfig(config.(string), row.TenantID)
|
||||
} else {
|
||||
// 尝试从备选字段中查找
|
||||
for _, key := range gaia.CredentialKeyFallback {
|
||||
var v string
|
||||
if v, ok = configMap[key].(string); ok && v != "" {
|
||||
if creds.APIKey, err = s.decryptConfig(v, row.TenantID); err == nil && creds.APIKey != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if base, ok = configMap[gaia.ConfigKeyOpenaiAPIBase].(string); ok && strings.TrimSpace(base) != "" {
|
||||
creds.Endpoint = strings.TrimSuffix(strings.TrimSpace(base), "/")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解密凭证失败: %w", err)
|
||||
}
|
||||
}
|
||||
if creds.APIKey == "" {
|
||||
return nil, fmt.Errorf("未能从配置中提取API Key")
|
||||
}
|
||||
|
||||
// 缓存凭证(1小时)
|
||||
var cacheJSON []byte
|
||||
if cacheJSON, err = json.Marshal(creds); err == nil {
|
||||
global.GVA_Dify_REDIS.Set(context.Background(), cacheKey, cacheJSON, time.Hour)
|
||||
}
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// decryptConfig 解密Dify的加密配置(RSA + AES-EAX 混合加密)
|
||||
// Dify 使用 RSA 2048 + AES-EAX 混合加密,密文格式为:
|
||||
// Base64( "HYBRID:" + enc_aes_key(256字节) + nonce(16字节) + tag(16字节) + ciphertext )
|
||||
func (s *ModelProviderService) decryptConfig(encryptedConfig string, tenantID string) (string, error) {
|
||||
// 1. Base64 解码
|
||||
encrypted, err := base64.StdEncoding.DecodeString(encryptedConfig)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("base64 decode failed: %w", err)
|
||||
}
|
||||
|
||||
// 2. 检查并去除 "HYBRID:" 前缀
|
||||
prefix := []byte("HYBRID:")
|
||||
if !bytes.HasPrefix(encrypted, prefix) {
|
||||
// 如果没有 HYBRID 前缀,可能是明文或其他格式,直接返回原值
|
||||
return encryptedConfig, nil
|
||||
}
|
||||
encrypted = encrypted[len(prefix):]
|
||||
|
||||
// 3. 读取 tenant 私钥
|
||||
privateKey, err := s.loadPrivateKey(tenantID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load private key failed: %w", err)
|
||||
}
|
||||
|
||||
// 4. 解析密文结构
|
||||
// RSA 2048 = 256 字节密钥
|
||||
rsaKeySize := privateKey.Size() // 通常是 256
|
||||
if len(encrypted) < rsaKeySize+32 {
|
||||
return "", errors.New("encrypted data too short")
|
||||
}
|
||||
|
||||
encAESKey := encrypted[:rsaKeySize]
|
||||
nonce := encrypted[rsaKeySize : rsaKeySize+16]
|
||||
tag := encrypted[rsaKeySize+16 : rsaKeySize+32]
|
||||
ciphertext := encrypted[rsaKeySize+32:]
|
||||
|
||||
// 5. RSA OAEP 解密 AES 密钥(使用 SHA-1,与 Dify Python 实现一致)
|
||||
aesKey, err := rsa.DecryptOAEP(sha1.New(), rand.Reader, privateKey, encAESKey, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("RSA decrypt failed: %w", err)
|
||||
}
|
||||
|
||||
// 6. AES-EAX 解密数据
|
||||
plaintext, err := s.aesEAXDecrypt(aesKey, nonce, ciphertext, tag)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("AES-EAX decrypt failed: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// loadPrivateKey 从配置的存储路径加载指定 tenant 的 RSA 私钥(PEM 文件)。
|
||||
func (s *ModelProviderService) loadPrivateKey(tenantID string) (*rsa.PrivateKey, error) {
|
||||
// 私钥路径: {storage-path}/privkeys/{tenant_id}/private.pem
|
||||
// 可通过配置自定义存储路径
|
||||
storagePath := global.GVA_CONFIG.Gaia.StoragePath
|
||||
if storagePath == "" {
|
||||
// 默认路径:Docker 环境使用 /app/storage,本地开发使用相对路径
|
||||
storagePath = "/app/storage"
|
||||
}
|
||||
|
||||
filepath := fmt.Sprintf("%s/privkeys/%s/private.pem", storagePath, tenantID)
|
||||
|
||||
// 如果默认路径不存在,尝试本地开发相对路径
|
||||
if _, err := os.Stat(filepath); os.IsNotExist(err) && storagePath == "/app/storage" {
|
||||
// 本地开发环境:admin/server 相对于 api/storage
|
||||
localPath := fmt.Sprintf("../../api/storage/privkeys/%s/private.pem", tenantID)
|
||||
if _, err := os.Stat(localPath); err == nil {
|
||||
filepath = localPath
|
||||
}
|
||||
}
|
||||
|
||||
pemData, err := os.ReadFile(filepath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read private key file failed: %w", err)
|
||||
}
|
||||
|
||||
// 解析 PEM 格式私钥
|
||||
block, _ := pem.Decode(pemData)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to decode PEM block")
|
||||
}
|
||||
|
||||
// 尝试解析 PKCS#1 格式
|
||||
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
// 尝试解析 PKCS#8 格式
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse private key failed: %w", err)
|
||||
}
|
||||
var ok bool
|
||||
privateKey, ok = key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, errors.New("private key is not RSA key")
|
||||
}
|
||||
}
|
||||
|
||||
return privateKey, nil
|
||||
}
|
||||
|
||||
// aesEAXDecrypt 使用 AES-EAX 解密数据
|
||||
// EAX 模式是一种认证加密模式,使用第三方库 go.gnd.pw/crypto/eax 实现
|
||||
func (s *ModelProviderService) aesEAXDecrypt(key, nonce, ciphertext, tag []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建 EAX AEAD 实例
|
||||
aead, err := eax.NewEAX(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create EAX cipher failed: %w", err)
|
||||
}
|
||||
|
||||
// EAX 的 Open 方法需要 nonce 和 ciphertext+tag 的组合
|
||||
// Python pycryptodome 的格式: ciphertext 和 tag 是分开的
|
||||
// Go EAX 库的 Open 期望格式: ciphertext || tag
|
||||
combined := make([]byte, len(ciphertext)+len(tag))
|
||||
copy(combined, ciphertext)
|
||||
copy(combined[len(ciphertext):], tag)
|
||||
|
||||
// 解密并验证
|
||||
plaintext, err := aead.Open(nil, nonce, combined, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("EAX decrypt failed: %w", err)
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// ProxyChat 将聊天请求代理到上游提供商,校验模型已开启并写入流式/非流式响应到 writer,并记录代理日志。
|
||||
// @Tags System Integrated
|
||||
// @Summary 代理聊天请求
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequest, writer io.Writer) error {
|
||||
// 检查模型是否开启
|
||||
providerName, err := s.getProviderByModel(req.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证模型是否在开启列表中
|
||||
if !s.isModelEnabled(providerName, req.Model) {
|
||||
return fmt.Errorf("模型 %s 未开启", req.Model)
|
||||
}
|
||||
|
||||
// 获取提供商凭证
|
||||
creds, err := s.GetDifyProviderCredentials(providerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取上游端点
|
||||
endpoint := s.getUpstreamEndpoint(providerName)
|
||||
|
||||
// 构建请求
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", endpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", creds.APIKey))
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Minute,
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("上游返回错误: %d %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 记录开始时间(用于日志)
|
||||
startTime := time.Now()
|
||||
var requestTokens, responseTokens int
|
||||
status := "success"
|
||||
var errorMsg string
|
||||
|
||||
defer func() {
|
||||
// 记录日志
|
||||
log := gaia.ModelProxyLog{
|
||||
UserId: userID,
|
||||
ProviderName: providerName,
|
||||
ModelName: req.Model,
|
||||
RequestTokens: requestTokens,
|
||||
ResponseTokens: responseTokens,
|
||||
Status: status,
|
||||
ErrorMessage: errorMsg,
|
||||
CreatedAt: startTime,
|
||||
}
|
||||
global.GVA_DB.Create(&log)
|
||||
}()
|
||||
|
||||
// 处理流式响应
|
||||
if req.Stream {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if _, err := writer.Write([]byte(line + "\n")); err != nil {
|
||||
status = "error"
|
||||
errorMsg = err.Error()
|
||||
return err
|
||||
}
|
||||
// Flush if writer supports it
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
status = "error"
|
||||
errorMsg = err.Error()
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 非流式响应
|
||||
if _, err := io.Copy(writer, resp.Body); err != nil {
|
||||
status = "error"
|
||||
errorMsg = err.Error()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProviderByModel 根据模型名称推断所属提供商短名(openai/tongyi/google/anthropic)。
|
||||
func (s *ModelProviderService) getProviderByModel(modelName string) (string, error) {
|
||||
modelLower := strings.ToLower(modelName)
|
||||
if strings.HasPrefix(modelLower, "gpt") || strings.Contains(modelLower, "openai") {
|
||||
return gaia.ProviderOpenai, nil
|
||||
}
|
||||
if strings.HasPrefix(modelLower, "qwen") || strings.Contains(modelLower, "tongyi") {
|
||||
return gaia.ProviderTongyi, nil
|
||||
}
|
||||
if strings.HasPrefix(modelLower, "gemini") || strings.Contains(modelLower, "google") {
|
||||
return gaia.ProviderGoogle, nil
|
||||
}
|
||||
if strings.Contains(modelLower, "claude") || strings.Contains(modelLower, "anthropic") {
|
||||
return gaia.ProviderAnthropic, nil
|
||||
}
|
||||
return "", fmt.Errorf("无法识别模型 %s 的提供商", modelName)
|
||||
}
|
||||
|
||||
// isModelEnabled 检查指定提供商下该模型是否在已启用且已选模型列表中。
|
||||
func (s *ModelProviderService) isModelEnabled(providerName, modelName string) bool {
|
||||
var config gaia.ModelProviderConfig
|
||||
if err := global.GVA_DB.Where("provider_name = ? AND enabled = ?", providerName, true).First(&config).Error; err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var models []string
|
||||
if err := json.Unmarshal([]byte(config.Models), &models); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, m := range models {
|
||||
if m == modelName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getUpstreamEndpoint 根据提供商短名返回聊天补全接口的上游 URL。
|
||||
func (s *ModelProviderService) getUpstreamEndpoint(providerName string) string {
|
||||
if endpoint, ok := gaia.DefaultChatCompletionsEndpoints[providerName]; ok {
|
||||
return endpoint
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getUpstreamBase 返回提供商的上游根地址(用于通用代理)。优先使用 provider_credentials 的 openai_api_base(如 "https://yunwu.ai"),便于计费与多租户区分。
|
||||
func (s *ModelProviderService) getUpstreamBase(providerName string, creds *gaiaResponse.ProviderCredentials) string {
|
||||
if creds != nil && strings.TrimSpace(creds.Endpoint) != "" {
|
||||
return strings.TrimSuffix(strings.TrimSpace(creds.Endpoint), "/")
|
||||
}
|
||||
if base, ok := gaia.DefaultAPIBase[providerName]; ok {
|
||||
return strings.TrimSuffix(base, "/")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ProxyRequest 将任意路径的请求转发到上游(anthropic /v1/messages、gemini /v1beta/...、openai /v1/chat/completions、/v1/images/generations、/v1/embeddings 等)。
|
||||
// @Tags System Integrated
|
||||
// @Summary 通用代理请求
|
||||
// @Security ApiKeyAuth
|
||||
// @accept application/json
|
||||
// @Produce application/json
|
||||
// provider 可通过 X-Gaia-Provider 头、query provider= 或 body 中的 model 字段推断;上游 base 优先使用 creds.Endpoint(openai_api_base)。
|
||||
func (s *ModelProviderService) ProxyRequest(
|
||||
userID, path, method string, reqHeader http.Header, body []byte, writer io.Writer) (err error) {
|
||||
// init
|
||||
var providerName string
|
||||
if path = strings.TrimPrefix(path, "/"); path == "" {
|
||||
return fmt.Errorf("代理路径不能为空")
|
||||
}
|
||||
|
||||
// 解析 provider:头 > query 已在 handler 传入;此处从 body 取 model 仅当 body 为 JSON 且含 model 时用于推断
|
||||
if p := reqHeader.Get("X-Gaia-Provider"); p != "" {
|
||||
providerName = strings.TrimSpace(strings.ToLower(p))
|
||||
}
|
||||
if providerName == "" && len(body) > 0 {
|
||||
var obj map[string]interface{}
|
||||
if err = json.Unmarshal(body, &obj); err == nil {
|
||||
if m, ok := obj["model"].(string); ok && m != "" {
|
||||
var errP error
|
||||
providerName, errP = s.getProviderByModel(m)
|
||||
if errP != nil {
|
||||
return errP
|
||||
}
|
||||
// 有 model 时校验该模型是否在开启列表
|
||||
if !s.isModelEnabled(providerName, m) {
|
||||
return fmt.Errorf("模型 %s 未开启", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if providerName == "" {
|
||||
return fmt.Errorf("请指定 provider:设置请求头 X-Gaia-Provider 或 query provider=,或在 body 中提供 model 字段")
|
||||
}
|
||||
|
||||
// 若未从 body model 解析出 provider,则只校验该提供商已启用
|
||||
if !s.isProviderEnabled(providerName) {
|
||||
return fmt.Errorf("提供商 %s 未开启", providerName)
|
||||
}
|
||||
|
||||
var base string
|
||||
var bodyReader io.Reader
|
||||
var creds *gaiaResponse.ProviderCredentials
|
||||
if creds, err = s.GetDifyProviderCredentials(providerName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if base = s.getUpstreamBase(providerName, creds); base == "" {
|
||||
return fmt.Errorf("提供商 %s 无可用上游地址", providerName)
|
||||
}
|
||||
|
||||
if len(body) > 0 {
|
||||
bodyReader = bytes.NewReader(body)
|
||||
}
|
||||
fmt.Println("path", base+"/"+path, string(body))
|
||||
httpReq, err := http.NewRequest(method, base+"/"+path, bodyReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 复制常用请求头,Authorization 使用上游 API Key
|
||||
httpReq.Header.Set("Authorization", "Bearer "+creds.APIKey)
|
||||
if ct := reqHeader.Get("Content-Type"); ct != "" {
|
||||
httpReq.Header.Set("Content-Type", ct)
|
||||
}
|
||||
if accept := reqHeader.Get("Accept"); accept != "" {
|
||||
httpReq.Header.Set("Accept", accept)
|
||||
}
|
||||
// 流式请求
|
||||
if reqHeader.Get("Accept") == "text/event-stream" || reqHeader.Get("Accept") == "" {
|
||||
// 不强制覆盖,上游可能根据 body 的 stream 返回 SSE
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 5 * time.Minute}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 记录代理日志(用于计费时可区分 openai_api_base)
|
||||
startTime := time.Now()
|
||||
modelOrPath := path
|
||||
if len(body) > 0 {
|
||||
var obj map[string]interface{}
|
||||
if json.Unmarshal(body, &obj) == nil {
|
||||
if m, _ := obj["model"].(string); m != "" {
|
||||
modelOrPath = m
|
||||
}
|
||||
}
|
||||
}
|
||||
var logStatus, logError string
|
||||
defer func() {
|
||||
if logStatus == "" {
|
||||
logStatus = "success"
|
||||
}
|
||||
global.GVA_DB.Create(&gaia.ModelProxyLog{
|
||||
UserId: userID,
|
||||
ProviderName: providerName,
|
||||
ModelName: modelOrPath,
|
||||
Status: logStatus,
|
||||
ErrorMessage: logError,
|
||||
CreatedAt: startTime,
|
||||
})
|
||||
}()
|
||||
|
||||
// 写回状态码与响应头(流式由上游 Content-Type 决定)
|
||||
if w, ok := writer.(http.ResponseWriter); ok {
|
||||
for k, v := range resp.Header {
|
||||
for _, vv := range v {
|
||||
w.Header().Add(k, vv)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
_, _ = io.Copy(writer, resp.Body)
|
||||
return nil
|
||||
}
|
||||
// 流式响应时按行刷新,避免缓冲
|
||||
if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
fmt.Println("sss", scanner.Text())
|
||||
if _, err = writer.Write([]byte(scanner.Text() + "\n")); err != nil {
|
||||
logStatus, logError = "error", err.Error()
|
||||
return err
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err = scanner.Err(); err != nil {
|
||||
logStatus, logError = "error", err.Error()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
_, err = io.Copy(writer, resp.Body)
|
||||
if err != nil {
|
||||
logStatus, logError = "error", err.Error()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// isProviderEnabled 检查该提供商是否已启用(未校验具体模型列表,用于通用代理)。
|
||||
func (s *ModelProviderService) isProviderEnabled(providerName string) bool {
|
||||
var config gaia.ModelProviderConfig
|
||||
if err := global.GVA_DB.Where("provider_name = ? AND enabled = ?", providerName, true).First(&config).Error; err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -230,6 +230,20 @@ func (i *initApi) InitializeData(ctx context.Context) (context.Context, error) {
|
||||
{ApiGroup: "应用版本", Method: "POST", Path: "/gaia/app-version/releases/:id/upload", Description: "上传安装包(自动识别平台架构)"},
|
||||
{ApiGroup: "应用版本", Method: "DELETE", Path: "/gaia/app-version/releases/:id/download", Description: "删除指定平台架构包"},
|
||||
// Extend Stop: batch workflow
|
||||
|
||||
// Extend Start: model provider (模型管理)
|
||||
{ApiGroup: "模型管理", Method: "GET", Path: "/gaia/model-provider/list", Description: "获取提供商配置列表"},
|
||||
{ApiGroup: "模型管理", Method: "POST", Path: "/gaia/model-provider/update", Description: "更新提供商配置"},
|
||||
{ApiGroup: "模型管理", Method: "GET", Path: "/gaia/model-provider/available-models", Description: "获取可用模型"},
|
||||
{ApiGroup: "模型管理", Method: "GET", Path: "/gaia/model-provider/test-credentials", Description: "测试提供商凭证"},
|
||||
{ApiGroup: "模型管理", Method: "GET", Path: "/gaia/model-provider/logs", Description: "获取代理日志"},
|
||||
{ApiGroup: "模型管理", Method: "GET", Path: "/gaia/models", Description: "获取开启的模型列表(第三方)"},
|
||||
{ApiGroup: "模型管理", Method: "GET", Path: "/gaia/proxy/*", Description: "中转API(第三方)-GET"},
|
||||
{ApiGroup: "模型管理", Method: "POST", Path: "/gaia/proxy/*", Description: "中转API(第三方)-POST"},
|
||||
{ApiGroup: "模型管理", Method: "PUT", Path: "/gaia/proxy/*", Description: "中转API(第三方)-PUT"},
|
||||
{ApiGroup: "模型管理", Method: "PATCH", Path: "/gaia/proxy/*", Description: "中转API(第三方)-PATCH"},
|
||||
{ApiGroup: "模型管理", Method: "DELETE", Path: "/gaia/proxy/*", Description: "中转API(第三方)-DELETE"},
|
||||
// Extend Stop: model provider
|
||||
}
|
||||
if err := db.Create(&entities).Error; err != nil {
|
||||
return ctx, errors.Wrap(err, sysModel.SysApi{}.TableName()+"表数据初始化失败!")
|
||||
|
||||
@@ -54,6 +54,9 @@ func (i *initMenuAuthority) InitializeData(ctx context.Context) (next context.Co
|
||||
if err = db.Model(&authorities[0]).Association("SysBaseMenus").Append(menus[40:41]); err != nil {
|
||||
return next, err
|
||||
}
|
||||
if err = db.Model(&authorities[0]).Association("SysBaseMenus").Append(menus[41:42]); err != nil {
|
||||
return next, err
|
||||
}
|
||||
if err = db.Model(&authorities[0]).Association("SysBaseMenus").Append(menus[2:5]); err != nil {
|
||||
return next, err
|
||||
}
|
||||
|
||||
@@ -378,6 +378,31 @@ func (i *initCasbin) InitializeData(ctx context.Context) (context.Context, error
|
||||
{Ptype: "p", V0: "1", V1: "/gaia/app-version/releases/:id/upload", V2: "POST"},
|
||||
{Ptype: "p", V0: "1", V1: "/gaia/app-version/releases/:id/download", V2: "DELETE"},
|
||||
// Extend Stop: app version
|
||||
|
||||
// Extend Start: model provider (模型管理)
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/model-provider/list", V2: "GET"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/model-provider/update", V2: "POST"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/model-provider/available-models", V2: "GET"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/model-provider/test-credentials", V2: "GET"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/model-provider/logs", V2: "GET"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/models", V2: "GET"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/proxy/*", V2: "GET"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/proxy/*", V2: "POST"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/proxy/*", V2: "PUT"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/proxy/*", V2: "PATCH"},
|
||||
{Ptype: "p", V0: "888", V1: "/gaia/proxy/*", V2: "DELETE"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/model-provider/list", V2: "GET"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/model-provider/update", V2: "POST"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/model-provider/available-models", V2: "GET"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/model-provider/test-credentials", V2: "GET"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/model-provider/logs", V2: "GET"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/models", V2: "GET"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/proxy/*", V2: "GET"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/proxy/*", V2: "POST"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/proxy/*", V2: "PUT"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/proxy/*", V2: "PATCH"},
|
||||
{Ptype: "p", V0: "8881", V1: "/gaia/proxy/*", V2: "DELETE"},
|
||||
// Extend Stop: model provider
|
||||
}
|
||||
if err := db.Create(&entities).Error; err != nil {
|
||||
return ctx, errors.Wrap(err, "Casbin 表 ("+i.InitializerName()+") 数据初始化失败!")
|
||||
|
||||
@@ -94,6 +94,7 @@ func (i *initMenu) InitializeData(ctx context.Context) (next context.Context, er
|
||||
{GVA_MODEL: global.GVA_MODEL{ID: 39}, MenuLevel: 0, Hidden: false, ParentId: 38, Path: "IntegratedDingTalk", Name: "IntegratedDingTalk", Component: "view/systemIntegrated/dingTalk/index.vue", Sort: 1, Meta: Meta{Title: "钉钉", Icon: "turn-off"}},
|
||||
{GVA_MODEL: global.GVA_MODEL{ID: 40}, MenuLevel: 0, Hidden: false, ParentId: 38, Path: "IntegratedOAuth2", Name: "IntegratedOAuth2", Component: "view/systemIntegrated/oauth2/index.vue", Sort: 2, Meta: Meta{Title: "OAuth2", Icon: "share"}},
|
||||
{GVA_MODEL: global.GVA_MODEL{ID: 41}, MenuLevel: 0, Hidden: false, ParentId: 0, Path: "AppVersion", Name: "AppVersion", Component: "view/gaia/appVersion/index.vue", Sort: 10, Meta: Meta{Title: "版本管理", Icon: "upload-filled"}},
|
||||
{GVA_MODEL: global.GVA_MODEL{ID: 42}, MenuLevel: 0, Hidden: false, ParentId: 38, Path: "IntegratedModelManagement", Name: "IntegratedModelManagement", Component: "view/systemIntegrated/modelManagement/index.vue", Sort: 3, Meta: Meta{Title: "模型管理", Icon: "cpu"}},
|
||||
// 二开部分
|
||||
}
|
||||
if err = db.Create(&entities).Error; err != nil {
|
||||
|
||||
@@ -66,17 +66,25 @@ func GetToken(c *gin.Context) string {
|
||||
}
|
||||
|
||||
func GetClaims(c *gin.Context) (*systemReq.CustomClaims, error) {
|
||||
token := GetToken(c)
|
||||
// init
|
||||
j := NewJWT()
|
||||
token := GetToken(c)
|
||||
claims, err := j.ParseToken(token)
|
||||
if err != nil {
|
||||
global.GVA_LOG.Error("从Gin的Context中获取从jwt解析信息失败, 请检查请求头是否存在x-token且claims是否为规定结构")
|
||||
}
|
||||
// 判断是否dify的token
|
||||
if claims.Username == "" {
|
||||
var userList []string
|
||||
var user system.SysUser
|
||||
var account gaia.Account
|
||||
if err = global.GVA_DB.Where("uuid=?", claims.UserId).First(&user).Error; err == nil {
|
||||
if claims.UserId != "" {
|
||||
userList = append(userList, claims.UserId)
|
||||
} else if claims.Sub != "" {
|
||||
userList = append(userList, claims.Sub)
|
||||
}
|
||||
// sql
|
||||
if err = global.GVA_DB.Where("uuid IN (?)", userList).First(&user).Error; err == nil {
|
||||
claims.BaseClaims.ID = user.ID
|
||||
claims.Username = user.Username
|
||||
claims.AuthorityId = user.AuthorityId
|
||||
|
||||
@@ -68,9 +68,10 @@ func (j *JWT) CreateTokenByOldToken(oldToken string, claims request.CustomClaims
|
||||
return v.(string), err
|
||||
}
|
||||
|
||||
// 解析 token
|
||||
// ParseToken 解析 token
|
||||
func (j *JWT) ParseToken(tokenString string) (*request.CustomClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &request.CustomClaims{}, func(token *jwt.Token) (i interface{}, e error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &request.CustomClaims{}, func(
|
||||
token *jwt.Token) (i interface{}, e error) {
|
||||
return j.SigningKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -16,3 +16,4 @@ COPY --from=0 /gva_web/dist /usr/share/nginx/html/admin/
|
||||
RUN cat /etc/nginx/nginx.conf
|
||||
RUN cat /etc/nginx/conf.d/my.conf
|
||||
RUN ls -al /usr/share/nginx/html
|
||||
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
import service from '@/utils/request'
|
||||
|
||||
// 获取提供商配置列表
|
||||
export const getProviderListApi = () => {
|
||||
return service({
|
||||
url: '/gaia/model-provider/list',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
// 更新提供商配置
|
||||
export const updateProviderConfigApi = (data) => {
|
||||
return service({
|
||||
url: '/gaia/model-provider/update',
|
||||
method: 'post',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
// 获取可用模型
|
||||
export const getAvailableModelsApi = (providerName) => {
|
||||
return service({
|
||||
url: '/gaia/model-provider/available-models',
|
||||
method: 'get',
|
||||
params: {
|
||||
provider_name: providerName
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 测试提供商凭证
|
||||
export const testProviderCredentialsApi = (providerName) => {
|
||||
return service({
|
||||
url: '/gaia/model-provider/test-credentials',
|
||||
method: 'get',
|
||||
params: {
|
||||
provider_name: providerName
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 获取开启的模型列表(OpenAI格式)
|
||||
export const getEnabledModelsApi = () => {
|
||||
return service({
|
||||
url: '/gaia/models',
|
||||
method: 'get'
|
||||
})
|
||||
}
|
||||
|
||||
// 获取代理日志
|
||||
export const getProxyLogsApi = (params) => {
|
||||
return service({
|
||||
url: '/gaia/model-provider/logs',
|
||||
method: 'get',
|
||||
params
|
||||
})
|
||||
}
|
||||
@@ -11,3 +11,30 @@ export const oaLogin = (data) => {
|
||||
data: data
|
||||
})
|
||||
}
|
||||
|
||||
// 获取 Gaia 登录方式(钉钉/OAuth2 是否启用及授权地址)
|
||||
export const getGaiaLoginOptions = (params) => {
|
||||
return service({
|
||||
url: '/base/gaiaLoginOptions',
|
||||
method: 'get',
|
||||
params
|
||||
})
|
||||
}
|
||||
|
||||
// Gaia OAuth2 登录:传 code 或 access_token(Extend: 兼容 casdoor implicit/hybrid 仅回传 access_token)
|
||||
export const gaiaOAuth2Login = (data) => {
|
||||
return service({
|
||||
url: '/base/gaiaOAuth2Login',
|
||||
method: 'post',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
// 钉钉 code 登录
|
||||
export const dingtalkLogin = (data) => {
|
||||
return service({
|
||||
url: '/base/dingtalkLogin',
|
||||
method: 'post',
|
||||
data
|
||||
})
|
||||
}
|
||||
|
||||
+31
-6
@@ -12,6 +12,7 @@ import '@/permission'
|
||||
import run from '@/core/gin-vue-admin.js'
|
||||
import auth from '@/directive/auth'
|
||||
import { store } from '@/pinia'
|
||||
import { useUserStore } from '@/pinia/modules/user'
|
||||
import App from './App.vue'
|
||||
// 消除警告
|
||||
import 'default-passive-events'
|
||||
@@ -20,10 +21,34 @@ const app = createApp(App)
|
||||
app.config.productionTip = false
|
||||
|
||||
app
|
||||
.use(run)
|
||||
.use(ElementPlus)
|
||||
.use(store)
|
||||
.use(auth)
|
||||
.use(router)
|
||||
.mount('#app')
|
||||
.use(run)
|
||||
.use(ElementPlus)
|
||||
.use(store)
|
||||
.use(auth)
|
||||
.use(router)
|
||||
.mount('#app')
|
||||
|
||||
// 如果当前 URL 上带有 clear_cache=true,则清空本地缓存与 Cookie
|
||||
const hasClearCacheFlag = () => {
|
||||
// 主 URL query(?a=1&clear_cache=true)
|
||||
const searchParams = new URLSearchParams(window.location.search || '')
|
||||
if (searchParams.get('clear_cache') === 'true') return true
|
||||
|
||||
// hash 部分 query(/#/login?redirect_uri=...&clear_cache=true)
|
||||
const hash = window.location.hash || ''
|
||||
const idx = hash.indexOf('?')
|
||||
if (idx !== -1) {
|
||||
const hashQuery = hash.substring(idx + 1)
|
||||
const hashParams = new URLSearchParams(hashQuery)
|
||||
if (hashParams.get('clear_cache') === 'true') return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if (hasClearCacheFlag()) {
|
||||
const userStore = useUserStore()
|
||||
// 统一使用 store 的清理逻辑:清 token、sessionStorage、localStorage 部分键、cookie 等
|
||||
userStore.ClearStorage && userStore.ClearStorage()
|
||||
}
|
||||
|
||||
export default app
|
||||
|
||||
@@ -64,6 +64,7 @@
|
||||
"/src/view/system/state.vue": "State",
|
||||
"/src/view/systemIntegrated/dingTalk/index.vue": "IntegratedDingTalk",
|
||||
"/src/view/systemIntegrated/index.vue": "SystemIntegrated",
|
||||
"/src/view/systemIntegrated/modelManagement/index.vue": "IntegratedModelManagement",
|
||||
"/src/view/systemIntegrated/oauth2/index.vue": "IntegratedOAuth2",
|
||||
"/src/view/systemTools/autoCode/component/fieldDialog.vue": "FieldDialog",
|
||||
"/src/view/systemTools/autoCode/component/previewCodeDialog.vue": "PreviewCodeDialog",
|
||||
|
||||
@@ -55,8 +55,11 @@ export const useUserStore = defineStore('user', () => {
|
||||
}
|
||||
return res
|
||||
}
|
||||
/* 登录*/
|
||||
const LoginIn = async(loginInfo) => {
|
||||
/* 登录
|
||||
* @param loginInfo 账号密码等
|
||||
* @param opts 可选 { redirect_uri, state },第三方带回调时:登录成功后跳回 redirect_uri 并带上 token 与 state,不再进入后台
|
||||
*/
|
||||
const LoginIn = async(loginInfo, opts = {}) => {
|
||||
loadingInstance.value = ElLoading.service({
|
||||
fullscreen: true,
|
||||
text: '登录中,请稍候...',
|
||||
@@ -74,6 +77,18 @@ export const useUserStore = defineStore('user', () => {
|
||||
setUserInfo(res.data.user)
|
||||
setToken(res.data.token)
|
||||
|
||||
const redirectUri = opts.redirect_uri && opts.redirect_uri.trim()
|
||||
const thirdPartyState = opts.state != null ? String(opts.state) : ''
|
||||
|
||||
// 第三方回调:带 token 跳回第三方,不进入后台
|
||||
if (redirectUri) {
|
||||
loadingInstance.value.close()
|
||||
const sep = redirectUri.includes('?') ? '&' : '?'
|
||||
const url = redirectUri + sep + 'token=' + encodeURIComponent(res.data.token) + (thirdPartyState ? '&state=' + encodeURIComponent(thirdPartyState) : '')
|
||||
window.location.href = url
|
||||
return true
|
||||
}
|
||||
|
||||
// 初始化路由信息
|
||||
const routerStore = useRouterStore()
|
||||
await routerStore.SetAsyncRouter()
|
||||
@@ -188,6 +203,7 @@ export const useUserStore = defineStore('user', () => {
|
||||
OaLoginIn,
|
||||
LoginOut,
|
||||
setToken,
|
||||
setUserInfo,
|
||||
loadingInstance,
|
||||
ClearStorage
|
||||
}
|
||||
|
||||
@@ -194,7 +194,7 @@ const out = ref(false)
|
||||
const form = reactive({
|
||||
adminPassword: '123456',
|
||||
dbType: 'pgsql',
|
||||
host: 'db',
|
||||
host: 'db_postgres',
|
||||
port: '5432',
|
||||
userName: 'postgres',
|
||||
password: 'difyai123456',
|
||||
|
||||
@@ -1,46 +1,123 @@
|
||||
<template>
|
||||
<div v-show="false">该页面用于对接oa-oauth2.0回调登录</div>
|
||||
<div class="flex items-center justify-center min-h-screen">
|
||||
<span>登录中...</span>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ElMessage } from 'element-plus'
|
||||
import { useRoute } from 'vue-router'
|
||||
import { useUserStore } from '@/pinia/modules/user'
|
||||
import { useRouterStore } from '@/pinia/modules/router'
|
||||
import router from '@/router'
|
||||
import { gaiaOAuth2Login, dingtalkLogin } from '@/api/user_extend'
|
||||
|
||||
defineOptions({
|
||||
name: 'LoginCallback',
|
||||
})
|
||||
const route = useRoute()
|
||||
const userStore = useUserStore()
|
||||
const oaLogin = async() => {
|
||||
return await userStore.OaLoginIn(route.query.code)
|
||||
|
||||
const redirectToThirdParty = (token, redirectUri, state) => {
|
||||
if (!redirectUri) return false
|
||||
sessionStorage.removeItem('gaia_login_redirect_uri')
|
||||
sessionStorage.removeItem('gaia_login_state')
|
||||
const sep = redirectUri.includes('?') ? '&' : '?'
|
||||
const url = redirectUri + sep + 'token=' + encodeURIComponent(token) + (state ? '&state=' + encodeURIComponent(state) : '')
|
||||
window.location.href = url
|
||||
window.location.href = "/"
|
||||
return true
|
||||
}
|
||||
const callback = async() => {
|
||||
if (route.query.code === undefined || route.query.code === '') {
|
||||
ElMessage({
|
||||
type: 'error',
|
||||
message: '登录失败,授权码缺失,3秒后跳转到登录页',
|
||||
showClose: true,
|
||||
})
|
||||
// 3秒后跳转登录页
|
||||
setTimeout(() => {
|
||||
window.location.href = '/'
|
||||
}, 3000)
|
||||
return false
|
||||
}
|
||||
const flag = await oaLogin()
|
||||
if (!flag) {
|
||||
ElMessage({
|
||||
type: 'error',
|
||||
message: '登录失败,3秒后跳转到登录页',
|
||||
showClose: true,
|
||||
})
|
||||
// 3秒后跳转登录页
|
||||
setTimeout(() => {
|
||||
window.location.href = '/'
|
||||
}, 3000)
|
||||
}
|
||||
return
|
||||
|
||||
const goToDashboard = async () => {
|
||||
const routerStore = useRouterStore()
|
||||
await routerStore.SetAsyncRouter()
|
||||
routerStore.asyncRouters.forEach(r => router.addRoute(r))
|
||||
const name = userStore.userInfo?.authority?.defaultRouter || 'gaiaDashboard'
|
||||
await router.replace({ name: name || 'gaiaDashboard' })
|
||||
}
|
||||
|
||||
const failAndBackToLogin = (msg) => {
|
||||
ElMessage({ type: 'error', message: msg || '登录失败,3秒后跳转到登录页', showClose: true })
|
||||
setTimeout(() => { window.location.href = '/#/login' }, 3000)
|
||||
}
|
||||
|
||||
// 钉钉/OAuth 回调时 code、authCode 可能在 hash 前的主 URL query 中(如 /admin/?code=xx&authCode=xx&state=dingtalk#/loginCallback?provider=dingtalk)
|
||||
const getQueryParam = (name) => {
|
||||
const fromRoute = route.query[name]
|
||||
if (fromRoute) return fromRoute
|
||||
const search = window.location.search
|
||||
if (!search) return ''
|
||||
const params = new URLSearchParams(search)
|
||||
return params.get(name) || ''
|
||||
}
|
||||
|
||||
const callback = async () => {
|
||||
const provider = getQueryParam('provider') || route.query.provider
|
||||
const code = getQueryParam('code') || getQueryParam('authCode') || route.query.code || route.query.authCode
|
||||
// Extend Start: 兼容 casdoor(部分 OAuth 如 Casdoor 可能通过 implicit/hybrid 直接回传 access_token,无 code)
|
||||
const accessTokenFromQuery = getQueryParam('access_token') || route.query.access_token || ''
|
||||
const hasCode = !!code
|
||||
const hasAccessToken = !!accessTokenFromQuery
|
||||
if (!hasCode && !hasAccessToken) {
|
||||
failAndBackToLogin('授权码或 access_token 缺失,3秒后跳转到登录页')
|
||||
return
|
||||
}
|
||||
// Extend Stop: 兼容 casdoor
|
||||
|
||||
const redirectUri = sessionStorage.getItem('gaia_login_redirect_uri') || ''
|
||||
const state = sessionStorage.getItem('gaia_login_state') || getQueryParam('state') || ''
|
||||
|
||||
try {
|
||||
if (provider === 'dingtalk') {
|
||||
if (!hasCode) {
|
||||
failAndBackToLogin('钉钉登录需要授权码')
|
||||
return
|
||||
}
|
||||
const res = await dingtalkLogin({ auth_code: code, redirect_uri: redirectUri, state })
|
||||
if (res?.code === 0 && res.data?.token) {
|
||||
userStore.setUserInfo(res.data.user)
|
||||
userStore.setToken(res.data.token)
|
||||
// 优先用接口返回的 redirect_uri/state(用户可能从应用直接跳到钉钉,未经过登录页,sessionStorage 为空)
|
||||
const finalRedirect = res.data.redirect_uri || redirectUri
|
||||
const finalState = res.data.state ?? state
|
||||
if (redirectToThirdParty(res.data.token, finalRedirect, finalState)) return
|
||||
await goToDashboard()
|
||||
return
|
||||
}
|
||||
} else if (provider === 'oauth2') {
|
||||
// Extend Start: 兼容 casdoor(支持仅带 access_token 的回调)
|
||||
const payload = hasCode
|
||||
? { code, redirect_uri: redirectUri, state }
|
||||
: { access_token: accessTokenFromQuery, redirect_uri: redirectUri, state }
|
||||
const res = await gaiaOAuth2Login(payload)
|
||||
// Extend Stop: 兼容 casdoor
|
||||
if (res?.code === 0 && res.data?.token) {
|
||||
userStore.setUserInfo(res.data.user)
|
||||
userStore.setToken(res.data.token)
|
||||
const finalRedirect = res.data.redirect_uri || redirectUri
|
||||
const finalState = res.data.state ?? state
|
||||
if (redirectToThirdParty(res.data.token, finalRedirect, finalState)) return
|
||||
await goToDashboard()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if (!hasCode) {
|
||||
failAndBackToLogin('该登录方式需要授权码')
|
||||
return
|
||||
}
|
||||
const flag = await userStore.OaLoginIn(code)
|
||||
if (flag) {
|
||||
if (redirectToThirdParty(userStore.token, redirectUri, state)) return
|
||||
await goToDashboard()
|
||||
return
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e)
|
||||
}
|
||||
failAndBackToLogin('登录失败,3秒后跳转到登录页')
|
||||
}
|
||||
|
||||
callback()
|
||||
</script>
|
||||
|
||||
@@ -21,8 +21,8 @@
|
||||
</div>
|
||||
<div class="mb-9">
|
||||
<p class="text-center text-4xl font-bold">{{ $GIN_VUE_ADMIN.appName }}</p>
|
||||
<p class="text-center text-sm font-normal text-gray-500 mt-2.5">A management platform for Dify-Plus
|
||||
</p>
|
||||
<p class="text-center text-sm font-normal text-gray-500 mt-2.5">A management platform for Dify-Plus</p>
|
||||
<p v-if="redirectUri" class="text-center text-xs text-blue-600 mt-2">登录后将跳回第三方应用</p>
|
||||
</div>
|
||||
<el-form
|
||||
ref="loginForm"
|
||||
@@ -87,7 +87,32 @@
|
||||
type="primary"
|
||||
size="large"
|
||||
@click="submitForm"
|
||||
>登 录</el-button>
|
||||
>账号密码登录</el-button>
|
||||
</el-form-item>
|
||||
<!-- 钉钉 / OAuth2 登录:仅在有 redirect_uri(第三方回调)时显示 -->
|
||||
<el-form-item
|
||||
v-if="loginOptions.dingtalk.enabled && redirectUri"
|
||||
class="mb-6"
|
||||
>
|
||||
<el-button
|
||||
class="shadow h-11 w-full"
|
||||
size="large"
|
||||
@click="dingtalkLoginJump"
|
||||
>
|
||||
钉钉登录
|
||||
</el-button>
|
||||
</el-form-item>
|
||||
<el-form-item
|
||||
v-if="loginOptions.oauth2.enabled && redirectUri"
|
||||
class="mb-6"
|
||||
>
|
||||
<el-button
|
||||
class="shadow shadow-blue-600 h-11 w-full"
|
||||
size="large"
|
||||
@click="oauth2LoginJump"
|
||||
>
|
||||
OAuth2 登录
|
||||
</el-button>
|
||||
</el-form-item>
|
||||
</template>
|
||||
<!-- 新增是否已经初始化判断 End -->
|
||||
@@ -103,19 +128,6 @@
|
||||
>前往初始化</el-button>
|
||||
|
||||
</el-form-item>
|
||||
<!-- 新增OA登录 Begin -->
|
||||
<el-form-item class="mb-6">
|
||||
<el-button
|
||||
class="shadow shadow-blue-600 h-11 w-full"
|
||||
type="primary"
|
||||
size="large"
|
||||
disabled
|
||||
@click="oaLoginJump"
|
||||
>
|
||||
Oauth2 登录(敬请期待)
|
||||
</el-button>
|
||||
</el-form-item>
|
||||
<!-- 新增OA登录 End -->
|
||||
</el-form>
|
||||
</div>
|
||||
</div>
|
||||
@@ -177,10 +189,11 @@
|
||||
<script setup>
|
||||
import { captcha } from '@/api/user'
|
||||
import { checkDB } from '@/api/initdb'
|
||||
import { getGaiaLoginOptions } from '@/api/user_extend'
|
||||
import BottomInfo from '@/components/bottomInfo/bottomInfo.vue'
|
||||
import { reactive, ref } from 'vue'
|
||||
import { reactive, ref, onMounted } from 'vue'
|
||||
import { ElMessage } from 'element-plus'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { useRouter, useRoute } from 'vue-router'
|
||||
import { useUserStore } from '@/pinia/modules/user'
|
||||
|
||||
defineOptions({
|
||||
@@ -188,6 +201,17 @@ defineOptions({
|
||||
})
|
||||
|
||||
const router = useRouter()
|
||||
const route = useRoute()
|
||||
|
||||
// 第三方回调参数(用于登录成功后跳回第三方并带 token)
|
||||
const redirectUri = ref(route.query.redirect_uri || '')
|
||||
const thirdPartyState = ref(route.query.state || '')
|
||||
|
||||
// Gaia 登录方式(钉钉/OAuth2)
|
||||
const loginOptions = reactive({
|
||||
dingtalk: { enabled: false, auth_url: '' },
|
||||
oauth2: { enabled: false, auth_url: '' }
|
||||
})
|
||||
const showInit = ref(false)
|
||||
// 验证函数
|
||||
const checkUsername = (rule, value, callback) => {
|
||||
@@ -243,7 +267,10 @@ const rules = reactive({
|
||||
|
||||
const userStore = useUserStore()
|
||||
const login = async() => {
|
||||
return await userStore.LoginIn(loginFormData)
|
||||
return await userStore.LoginIn(loginFormData, {
|
||||
redirect_uri: redirectUri.value || undefined,
|
||||
state: thirdPartyState.value || undefined,
|
||||
})
|
||||
}
|
||||
const submitForm = () => {
|
||||
loginForm.value.validate(async(v) => {
|
||||
@@ -298,16 +325,53 @@ const showInitExtend = async() => {
|
||||
}
|
||||
showInitExtend()
|
||||
|
||||
// 跳转oa登录链接
|
||||
const oaLoginJump = () => {
|
||||
const clientId = import.meta.env.VITE_OA_LOGIN_CLINET_ID
|
||||
const oaUrl = import.meta.env.VITE_OA_URL
|
||||
const redirect_uri = window.location.origin + '#/loginCallback'
|
||||
// 获取loginCallback该路由的完整url
|
||||
|
||||
const jumpUrl = oaUrl + '?client_id=' + clientId + '&redirect_uri=' + encodeURIComponent(redirect_uri) + '&state='
|
||||
console.log(jumpUrl)
|
||||
window.location.href = jumpUrl
|
||||
// 已登录且带 redirect_uri 时直接回调第三方
|
||||
const tryRedirectWithToken = async () => {
|
||||
if (!redirectUri.value || !userStore.token) return false
|
||||
const res = await userStore.GetUserInfo()
|
||||
if (res?.code === 0) {
|
||||
const sep = redirectUri.value.includes('?') ? '&' : '?'
|
||||
const url = redirectUri.value + sep + 'token=' + encodeURIComponent(userStore.token) + (thirdPartyState.value ? '&state=' + encodeURIComponent(thirdPartyState.value) : '')
|
||||
window.location.href = url
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 拉取登录方式并检测已登录回调
|
||||
const loadLoginOptionsAndMaybeRedirect = async () => {
|
||||
const didRedirect = await tryRedirectWithToken()
|
||||
if (didRedirect) return
|
||||
try {
|
||||
const res = await getGaiaLoginOptions({ origin: window.location.origin })
|
||||
if (res?.code === 0 && res.data) {
|
||||
if (res.data.dingtalk) {
|
||||
loginOptions.dingtalk.enabled = res.data.dingtalk.enabled
|
||||
loginOptions.dingtalk.auth_url = res.data.dingtalk.auth_url || ''
|
||||
}
|
||||
if (res.data.oauth2) {
|
||||
loginOptions.oauth2.enabled = res.data.oauth2.enabled
|
||||
loginOptions.oauth2.auth_url = res.data.oauth2.auth_url || ''
|
||||
}
|
||||
}
|
||||
} catch (_) {}
|
||||
}
|
||||
|
||||
// 钉钉登录:保存回调参数并跳转钉钉授权
|
||||
const dingtalkLoginJump = () => {
|
||||
sessionStorage.setItem('gaia_login_redirect_uri', redirectUri.value)
|
||||
sessionStorage.setItem('gaia_login_state', thirdPartyState.value)
|
||||
if (loginOptions.dingtalk.auth_url) window.location.href = loginOptions.dingtalk.auth_url
|
||||
}
|
||||
|
||||
// OAuth2 登录:保存回调参数并跳转 OAuth2 授权
|
||||
const oauth2LoginJump = () => {
|
||||
sessionStorage.setItem('gaia_login_redirect_uri', redirectUri.value)
|
||||
sessionStorage.setItem('gaia_login_state', thirdPartyState.value)
|
||||
if (loginOptions.oauth2.auth_url) window.location.href = loginOptions.oauth2.auth_url
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadLoginOptionsAndMaybeRedirect()
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -0,0 +1,391 @@
|
||||
<template>
|
||||
<div class="model-management">
|
||||
<el-card class="box-card">
|
||||
<template #header>
|
||||
<div class="card-header">
|
||||
<span class="title">模型管理</span>
|
||||
<el-button
|
||||
type="primary"
|
||||
:loading="saving"
|
||||
@click="saveAll"
|
||||
>
|
||||
保存配置
|
||||
</el-button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div v-loading="loading" class="provider-list">
|
||||
<el-empty
|
||||
v-if="!loading && providerList.length === 0"
|
||||
description="暂无提供商配置"
|
||||
/>
|
||||
|
||||
<div
|
||||
v-for="provider in providerList"
|
||||
:key="provider.provider_name"
|
||||
class="provider-item"
|
||||
>
|
||||
<div class="provider-header">
|
||||
<div class="provider-info">
|
||||
<el-icon class="provider-icon">
|
||||
<cpu />
|
||||
</el-icon>
|
||||
<span class="provider-name">{{ getProviderDisplayName(provider.provider_name) }}</span>
|
||||
<el-tag
|
||||
v-if="provider.enabled"
|
||||
type="success"
|
||||
size="small"
|
||||
>
|
||||
已开启
|
||||
</el-tag>
|
||||
<el-tag
|
||||
v-else
|
||||
type="info"
|
||||
size="small"
|
||||
>
|
||||
已关闭
|
||||
</el-tag>
|
||||
</div>
|
||||
<div class="provider-actions">
|
||||
<el-button
|
||||
size="small"
|
||||
:type="provider.enabled ? 'danger' : 'success'"
|
||||
@click="toggleProvider(provider)"
|
||||
>
|
||||
{{ provider.enabled ? '关闭' : '开启' }}
|
||||
</el-button>
|
||||
<el-button
|
||||
size="small"
|
||||
@click="testCredentials(provider.provider_name)"
|
||||
>
|
||||
测试凭证
|
||||
</el-button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<el-collapse-transition>
|
||||
<div v-if="provider.enabled" class="provider-models">
|
||||
<div class="models-header">
|
||||
<span class="models-title">可用模型</span>
|
||||
<div class="models-actions">
|
||||
<el-button
|
||||
size="small"
|
||||
text
|
||||
type="primary"
|
||||
@click="selectAllModels(provider)"
|
||||
>
|
||||
全选
|
||||
</el-button>
|
||||
<el-button
|
||||
size="small"
|
||||
text
|
||||
type="info"
|
||||
@click="clearAllModels(provider)"
|
||||
>
|
||||
清空
|
||||
</el-button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="provider.available_models && provider.available_models.length > 0" class="models-select-wrapper">
|
||||
<el-select
|
||||
v-model="provider.models"
|
||||
multiple
|
||||
filterable
|
||||
collapse-tags
|
||||
collapse-tags-tooltip
|
||||
:max-collapse-tags="5"
|
||||
placeholder="请选择模型"
|
||||
class="models-select"
|
||||
@change="onModelSelectChange(provider)"
|
||||
>
|
||||
<el-option
|
||||
v-for="model in provider.available_models"
|
||||
:key="model.id"
|
||||
:label="model.name"
|
||||
:value="model.id"
|
||||
/>
|
||||
</el-select>
|
||||
<div class="selected-count">
|
||||
已选择 {{ provider.models?.length || 0 }} / {{ provider.available_models.length }} 个模型
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<el-empty
|
||||
v-else
|
||||
description="未找到可用模型,请先在Dify中配置该提供商"
|
||||
:image-size="80"
|
||||
/>
|
||||
</div>
|
||||
</el-collapse-transition>
|
||||
</div>
|
||||
</div>
|
||||
</el-card>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, reactive, onMounted } from 'vue'
|
||||
import { ElMessage, ElMessageBox } from 'element-plus'
|
||||
import { Cpu } from '@element-plus/icons-vue'
|
||||
import {
|
||||
getProviderListApi,
|
||||
updateProviderConfigApi,
|
||||
testProviderCredentialsApi
|
||||
} from '@/api/modelProvider'
|
||||
|
||||
defineOptions({
|
||||
name: 'IntegratedModelManagement'
|
||||
})
|
||||
|
||||
const loading = ref(false)
|
||||
const saving = ref(false)
|
||||
const providerList = ref([])
|
||||
|
||||
// 提供商显示名称映射
|
||||
const providerDisplayNames = {
|
||||
openai: 'OpenAI',
|
||||
tongyi: '千问(通义)',
|
||||
google: 'Google Gemini'
|
||||
}
|
||||
|
||||
const getProviderDisplayName = (providerName) => {
|
||||
return providerDisplayNames[providerName] || providerName
|
||||
}
|
||||
|
||||
// 获取提供商列表
|
||||
const getProviderList = async() => {
|
||||
loading.value = true
|
||||
try {
|
||||
const res = await getProviderListApi()
|
||||
if (res.code === 0) {
|
||||
// 处理数据,添加selectedModelsSet用于checkbox绑定
|
||||
providerList.value = res.data.map(provider => {
|
||||
const selectedModelsSet = {}
|
||||
if (provider.models && Array.isArray(provider.models)) {
|
||||
provider.models.forEach(modelId => {
|
||||
selectedModelsSet[modelId] = true
|
||||
})
|
||||
}
|
||||
return {
|
||||
...provider,
|
||||
selectedModelsSet
|
||||
}
|
||||
})
|
||||
} else {
|
||||
ElMessage.error(res.msg || '获取提供商列表失败')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('获取提供商列表失败', error)
|
||||
ElMessage.error('获取提供商列表失败')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 切换提供商开关
|
||||
const toggleProvider = (provider) => {
|
||||
provider.enabled = !provider.enabled
|
||||
if (!provider.enabled) {
|
||||
// 关闭时清空选中的模型
|
||||
provider.selectedModelsSet = {}
|
||||
provider.models = []
|
||||
}
|
||||
}
|
||||
|
||||
// 下拉框选择变化
|
||||
const onModelSelectChange = (provider) => {
|
||||
// 同步更新 selectedModelsSet(保持兼容性)
|
||||
provider.selectedModelsSet = {}
|
||||
if (provider.models && Array.isArray(provider.models)) {
|
||||
provider.models.forEach(modelId => {
|
||||
provider.selectedModelsSet[modelId] = true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 全选模型
|
||||
const selectAllModels = (provider) => {
|
||||
if (provider.available_models && provider.available_models.length > 0) {
|
||||
provider.models = provider.available_models.map(model => model.id)
|
||||
onModelSelectChange(provider)
|
||||
}
|
||||
}
|
||||
|
||||
// 清空模型选择
|
||||
const clearAllModels = (provider) => {
|
||||
provider.selectedModelsSet = {}
|
||||
provider.models = []
|
||||
}
|
||||
|
||||
// 保存所有配置
|
||||
const saveAll = async() => {
|
||||
saving.value = true
|
||||
try {
|
||||
// 逐个保存提供商配置
|
||||
for (const provider of providerList.value) {
|
||||
await updateProviderConfigApi({
|
||||
provider_name: provider.provider_name,
|
||||
enabled: provider.enabled,
|
||||
models: provider.models || []
|
||||
})
|
||||
}
|
||||
ElMessage.success('保存成功')
|
||||
} catch (error) {
|
||||
console.error('保存配置失败', error)
|
||||
ElMessage.error('保存配置失败')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 测试凭证
|
||||
const testCredentials = async(providerName) => {
|
||||
try {
|
||||
const res = await testProviderCredentialsApi(providerName)
|
||||
if (res.code === 0) {
|
||||
ElMessageBox.alert(
|
||||
`提供商: ${res.data.provider}\nAPI Key: ${res.data.api_key}\n凭证状态: ${res.data.has_api_key ? '已配置' : '未配置'}`,
|
||||
'凭证测试结果',
|
||||
{
|
||||
confirmButtonText: '确定',
|
||||
type: 'success'
|
||||
}
|
||||
)
|
||||
} else {
|
||||
ElMessage.error(res.msg || '测试失败')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('测试凭证失败', error)
|
||||
ElMessage.error('测试凭证失败:' + (error.message || '未知错误'))
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
getProviderList()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped lang="scss">
|
||||
.model-management {
|
||||
padding: 20px;
|
||||
|
||||
.box-card {
|
||||
.card-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
|
||||
.title {
|
||||
font-size: 18px;
|
||||
font-weight: 600;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.provider-list {
|
||||
min-height: 400px;
|
||||
}
|
||||
|
||||
.provider-item {
|
||||
border: 1px solid #e4e7ed;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
margin-bottom: 20px;
|
||||
transition: all 0.3s;
|
||||
|
||||
&:hover {
|
||||
box-shadow: 0 2px 12px 0 rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.provider-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 16px;
|
||||
|
||||
.provider-info {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
|
||||
.provider-icon {
|
||||
font-size: 24px;
|
||||
color: #409eff;
|
||||
}
|
||||
|
||||
.provider-name {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
}
|
||||
}
|
||||
|
||||
.provider-actions {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
}
|
||||
|
||||
.provider-models {
|
||||
background-color: #f5f7fa;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
|
||||
.models-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 16px;
|
||||
|
||||
.models-title {
|
||||
font-weight: 600;
|
||||
color: #303133;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.models-actions {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
}
|
||||
|
||||
.models-select-wrapper {
|
||||
.models-select {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.selected-count {
|
||||
margin-top: 12px;
|
||||
font-size: 12px;
|
||||
color: #909399;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 优化下拉框样式
|
||||
:deep(.el-select) {
|
||||
.el-select__tags {
|
||||
flex-wrap: wrap;
|
||||
max-height: 120px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.el-tag {
|
||||
margin: 2px 4px 2px 0;
|
||||
}
|
||||
}
|
||||
|
||||
:deep(.el-select-dropdown) {
|
||||
.el-select-dropdown__item {
|
||||
padding: 8px 16px;
|
||||
|
||||
&.is-selected {
|
||||
font-weight: 600;
|
||||
color: #409eff;
|
||||
background-color: #ecf5ff;
|
||||
}
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -219,4 +219,6 @@ __all__ = [
|
||||
"workflow_statistic",
|
||||
"workflow_trigger",
|
||||
"workspace",
|
||||
# extend: 二开
|
||||
"register_extend",
|
||||
]
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import jwt
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from .. import console_ns
|
||||
from models import Account
|
||||
from configs import dify_config
|
||||
from datetime import UTC, datetime
|
||||
from libs.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from models.account import AccountStatus
|
||||
from flask_restx import Resource, reqparse
|
||||
from models.account_money_extend import AccountMoneyExtend
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.account_service_extend import TenantExtendService
|
||||
|
||||
|
||||
@console_ns.route("/admin_register_user")
|
||||
class AdminRegisterApi(Resource):
|
||||
"""Resource for user login."""
|
||||
@login_required
|
||||
@@ -78,4 +77,3 @@ class AdminRegisterApi(Resource):
|
||||
return {"result": "success", "data": "ok"}
|
||||
|
||||
|
||||
api.add_resource(AdminRegisterApi, "/admin_register_user")
|
||||
|
||||
@@ -22,6 +22,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
CSRF_WHITE_LIST = [
|
||||
re.compile(r"/console/api/apps/[a-f0-9-]+/workflows/draft"),
|
||||
# 后台服务端调用(仅 Bearer 认证),无浏览器 Cookie,豁免 CSRF
|
||||
re.compile(r"/console/api/admin_register_user"),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ gaia:
|
||||
login_max_error_limit: 5
|
||||
SUPER_ADMIN_ACCOUNT_ID:
|
||||
SUPER_ADMIN_TENANT_ID:
|
||||
storage-path: /app/storage
|
||||
hua-wei-obs:
|
||||
path: you-path
|
||||
bucket: you-bucket
|
||||
|
||||
@@ -1688,18 +1688,31 @@ services:
|
||||
ports:
|
||||
- '8888:8888'
|
||||
depends_on:
|
||||
init_permissions:
|
||||
condition: service_completed_successfully
|
||||
db_postgres:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
required: false
|
||||
db_mysql:
|
||||
condition: service_healthy
|
||||
required: false
|
||||
oceanbase:
|
||||
condition: service_healthy
|
||||
required: false
|
||||
seekdb:
|
||||
condition: service_healthy
|
||||
required: false
|
||||
redis:
|
||||
condition: service_started
|
||||
networks:
|
||||
- default
|
||||
links:
|
||||
- db_postgres
|
||||
- redis
|
||||
networks:
|
||||
- ssrf_proxy_network
|
||||
- default
|
||||
volumes:
|
||||
- ./admin-server/config.docker.yaml:/app/config.docker.yaml
|
||||
# 挂载 Dify storage 目录以访问 tenant 私钥(用于解密 provider credentials)
|
||||
- ./volumes/app/storage:/app/storage:ro
|
||||
|
||||
# Extend - sandbox-full
|
||||
sandbox-full:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
-- ============================================================
|
||||
-- 应用版本:菜单、API、Casbin 权限、角色-菜单关联 插入语句
|
||||
-- 1.11.4 升级到 1.12.2 需要执行的权限 SQL
|
||||
-- 包含:应用版本、模型管理 的菜单、API、Casbin 权限、角色-菜单关联
|
||||
-- 执行前请确认:1) 表结构已存在 2) 若 id 冲突可调整或改用 INSERT IGNORE / ON CONFLICT
|
||||
-- 模型管理 API 的 id 从 259 起,若与现有数据冲突请先查 SELECT MAX(id) FROM sys_apis; 再调整
|
||||
-- ============================================================
|
||||
|
||||
-- --------------- 1. 菜单 sys_base_menus (应用版本) ---------------
|
||||
@@ -81,3 +83,80 @@ INSERT INTO sys_authority_menus (sys_authority_authority_id, sys_base_menu_id) V
|
||||
-- INSERT INTO sys_authority_menus (sys_authority_authority_id, sys_base_menu_id) VALUES (8881, 41);
|
||||
-- INSERT INTO sys_authority_menus (sys_authority_authority_id, sys_base_menu_id) VALUES (9528, 41);
|
||||
INSERT INTO sys_authority_menus (sys_authority_authority_id, sys_base_menu_id) VALUES (1, 41);
|
||||
|
||||
|
||||
-- ============================================================
|
||||
-- 模型管理:菜单、API、Casbin 权限、角色-菜单关联 插入语句
|
||||
-- ============================================================
|
||||
|
||||
-- --------------- 5. 菜单 sys_base_menus (模型管理,挂在「系统集成」下) ---------------
|
||||
INSERT INTO sys_base_menus (
|
||||
id, created_at, updated_at, deleted_at,
|
||||
menu_level, parent_id, path, name, hidden, component, sort,
|
||||
active_name, keep_alive, default_menu, title, icon, close_tab
|
||||
) VALUES (
|
||||
42,
|
||||
NOW(), NOW(), NULL,
|
||||
0, 38, 'IntegratedModelManagement', 'IntegratedModelManagement', false, 'view/systemIntegrated/modelManagement/index.vue', 3,
|
||||
'', false, false, '模型管理', 'cpu', false
|
||||
);
|
||||
-- PostgreSQL: INSERT ... ON CONFLICT (id) DO NOTHING;
|
||||
|
||||
|
||||
-- --------------- 6. API sys_apis (模型管理相关 11 条,proxy 使用通配符 /gaia/proxy/*) ---------------
|
||||
-- 请按当前库最大 id 调整起始 id,避免冲突。例如 MAX(id)=258 则从 259 起
|
||||
INSERT INTO sys_apis (id, created_at, updated_at, deleted_at, path, description, api_group, method) VALUES
|
||||
(259, NOW(), NOW(), NULL, '/gaia/model-provider/list', '获取提供商配置列表', '模型管理', 'GET'),
|
||||
(260, NOW(), NOW(), NULL, '/gaia/model-provider/update', '更新提供商配置', '模型管理', 'POST'),
|
||||
(261, NOW(), NOW(), NULL, '/gaia/model-provider/available-models', '获取可用模型', '模型管理', 'GET'),
|
||||
(262, NOW(), NOW(), NULL, '/gaia/model-provider/test-credentials', '测试提供商凭证', '模型管理', 'GET'),
|
||||
(263, NOW(), NOW(), NULL, '/gaia/model-provider/logs', '获取代理日志', '模型管理', 'GET'),
|
||||
(264, NOW(), NOW(), NULL, '/gaia/models', '获取开启的模型列表(第三方)', '模型管理', 'GET'),
|
||||
(265, NOW(), NOW(), NULL, '/gaia/proxy/*', '中转API(第三方)-GET', '模型管理', 'GET'),
|
||||
(266, NOW(), NOW(), NULL, '/gaia/proxy/*', '中转API(第三方)-POST', '模型管理', 'POST'),
|
||||
(267, NOW(), NOW(), NULL, '/gaia/proxy/*', '中转API(第三方)-PUT', '模型管理', 'PUT'),
|
||||
(268, NOW(), NOW(), NULL, '/gaia/proxy/*', '中转API(第三方)-PATCH', '模型管理', 'PATCH'),
|
||||
(269, NOW(), NOW(), NULL, '/gaia/proxy/*', '中转API(第三方)-DELETE', '模型管理', 'DELETE');
|
||||
|
||||
|
||||
-- --------------- 7. Casbin 规则 casbin_rule (模型管理 888/8881) ---------------
|
||||
INSERT INTO casbin_rule (ptype, v0, v1, v2) VALUES
|
||||
('p', '888', '/gaia/model-provider/list', 'GET'),
|
||||
('p', '888', '/gaia/model-provider/update', 'POST'),
|
||||
('p', '888', '/gaia/model-provider/available-models', 'GET'),
|
||||
('p', '888', '/gaia/model-provider/test-credentials', 'GET'),
|
||||
('p', '888', '/gaia/model-provider/logs', 'GET'),
|
||||
('p', '888', '/gaia/models', 'GET'),
|
||||
('p', '888', '/gaia/proxy/*', 'GET'),
|
||||
('p', '888', '/gaia/proxy/*', 'POST'),
|
||||
('p', '888', '/gaia/proxy/*', 'PUT'),
|
||||
('p', '888', '/gaia/proxy/*', 'PATCH'),
|
||||
('p', '888', '/gaia/proxy/*', 'DELETE'),
|
||||
('p', '8881', '/gaia/model-provider/list', 'GET'),
|
||||
('p', '8881', '/gaia/model-provider/update', 'POST'),
|
||||
('p', '8881', '/gaia/model-provider/available-models', 'GET'),
|
||||
('p', '8881', '/gaia/model-provider/test-credentials', 'GET'),
|
||||
('p', '8881', '/gaia/model-provider/logs', 'GET'),
|
||||
('p', '8881', '/gaia/models', 'GET'),
|
||||
('p', '8881', '/gaia/proxy/*', 'GET'),
|
||||
('p', '8881', '/gaia/proxy/*', 'POST'),
|
||||
('p', '8881', '/gaia/proxy/*', 'PUT'),
|
||||
('p', '8881', '/gaia/proxy/*', 'PATCH'),
|
||||
('p', '8881', '/gaia/proxy/*', 'DELETE');
|
||||
|
||||
|
||||
-- --------------- 8. 角色-菜单关联 sys_authority_menus (让角色 888 拥有「模型管理」菜单) ---------------
|
||||
INSERT INTO sys_authority_menus (sys_authority_authority_id, sys_base_menu_id) VALUES (888, 42);
|
||||
|
||||
|
||||
-- ============================================================
|
||||
-- 已有库修复:若曾用具体路径 /gaia/proxy/v1/chat/completions,改为通配符 /gaia/proxy/*
|
||||
-- ============================================================
|
||||
-- sys_apis:将具体路径统一改为通配符(与 source/system/api.go、Casbin 一致)
|
||||
UPDATE sys_apis SET path = '/gaia/proxy/*', description = '中转API(第三方)-POST', updated_at = NOW()
|
||||
WHERE path = '/gaia/proxy/v1/chat/completions' AND method = 'POST';
|
||||
-- 若有其他具体子路径也一并改为通配符(可选)
|
||||
UPDATE sys_apis SET path = '/gaia/proxy/*', updated_at = NOW()
|
||||
WHERE path LIKE '/gaia/proxy/%' AND path != '/gaia/proxy/*';
|
||||
-- casbin_rule:若存在具体路径策略可删除,保留通配符策略即可(通常初始化已是 /gaia/proxy/*,可不执行)
|
||||
-- DELETE FROM casbin_rule WHERE ptype = 'p' AND v1 LIKE '/gaia/proxy/%' AND v1 != '/gaia/proxy/*';
|
||||
|
||||
+3
-1
@@ -63,12 +63,14 @@ const nextConfig: NextConfig = {
|
||||
},
|
||||
]
|
||||
},
|
||||
// dev 时把 /console/api 和 /api 代理到 5001
|
||||
// dev 时把 /console/api 和 /api 代理到 5001,/admin 代理到 8888
|
||||
...(isDev && {
|
||||
async rewrites() {
|
||||
return [
|
||||
{ source: '/console/api/:path*', destination: 'http://localhost:5001/console/api/:path*' },
|
||||
{ source: '/api/:path*', destination: 'http://localhost:5001/api/:path*' },
|
||||
{ source: '/admin', destination: 'http://localhost:8888/' },
|
||||
{ source: '/admin/:path*', destination: 'http://localhost:8888/:path*' },
|
||||
]
|
||||
},
|
||||
}),
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import Cookies from 'js-cookie'
|
||||
import { CSRF_COOKIE_NAME } from '@/config'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
// Admin server 使用独立的 JWT 认证,需要从 admin_token 获取
|
||||
const getAdminToken = () => {
|
||||
// 优先使用 admin_token,如果没有则尝试使用 console_token
|
||||
return localStorage.getItem('admin_token') || localStorage.getItem('console_token')
|
||||
return Cookies.get(CSRF_COOKIE_NAME())
|
||||
}
|
||||
|
||||
type batchProcessing = {
|
||||
|
||||
Reference in New Issue
Block a user