mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-04 10:14:00 +08:00
fix: 规范化处理代码
This commit is contained in:
@@ -135,6 +135,7 @@ func (e *SystemIntegratedService) OAuth2CodeLogin(
|
|||||||
return nil, fmt.Errorf("无法从 OAuth2 用户信息中获取邮箱或用户唯一标识")
|
return nil, fmt.Errorf("无法从 OAuth2 用户信息中获取邮箱或用户唯一标识")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("OAuth2CodeLogin", email, username)
|
||||||
sysUser, err := e.findUserByEmailOrPhone(email, userID)
|
sysUser, err := e.findUserByEmailOrPhone(email, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -187,7 +188,8 @@ func (e *SystemIntegratedService) DingTalkTestCallback(code string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DingTalkCodeLogin 钉钉 code 换用户并登录(扫码/OAuth2 回调带 code)
|
// DingTalkCodeLogin 钉钉 code 换用户并登录(扫码/OAuth2 回调带 code)
|
||||||
func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLoginReq) (*response.GaiaLoginResult, error) {
|
func (e *SystemIntegratedService) DingTalkCodeLogin(
|
||||||
|
req request.GaiaDingTalkLoginReq) (*response.GaiaLoginResult, error) {
|
||||||
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
|
integrate := e.getIntegratedConfigRaw(gaia.SystemIntegrationDingTalk)
|
||||||
if !integrate.Status {
|
if !integrate.Status {
|
||||||
return nil, fmt.Errorf("钉钉登录未启用")
|
return nil, fmt.Errorf("钉钉登录未启用")
|
||||||
@@ -277,6 +279,7 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
|
|||||||
if emailConfig.Enabled && dingId != "" {
|
if emailConfig.Enabled && dingId != "" {
|
||||||
emailList, err = e.callEmailApi(dingId, emailConfig)
|
emailList, err = e.callEmailApi(dingId, emailConfig)
|
||||||
if err == nil && len(emailList) > 0 {
|
if err == nil && len(emailList) > 0 {
|
||||||
|
fmt.Println("钉钉 code 换用户并登录(扫码/OAuth2 回调带 code)", emailList)
|
||||||
sysUser, findErr := e.findUserByEmail(emailList)
|
sysUser, findErr := e.findUserByEmail(emailList)
|
||||||
if findErr != nil {
|
if findErr != nil {
|
||||||
return nil, findErr
|
return nil, findErr
|
||||||
@@ -301,6 +304,7 @@ func (e *SystemIntegratedService) DingTalkCodeLogin(req request.GaiaDingTalkLogi
|
|||||||
return nil, fmt.Errorf("钉钉未返回邮箱")
|
return nil, fmt.Errorf("钉钉未返回邮箱")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("钉钉 code 换用户并登录第三方邮箱 API 获取失败", email)
|
||||||
sysUser, err := e.findUserByEmail([]string{email})
|
sysUser, err := e.findUserByEmail([]string{email})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -404,7 +408,8 @@ func (e *SystemIntegratedService) findUserByEmail(mailList []string) (*system.Sy
|
|||||||
}
|
}
|
||||||
|
|
||||||
// findUserByEmailOrPhone 按邮箱或用户唯一标识(如手机号)查找用户,优先邮箱
|
// findUserByEmailOrPhone 按邮箱或用户唯一标识(如手机号)查找用户,优先邮箱
|
||||||
func (e *SystemIntegratedService) findUserByEmailOrPhone(mail, userID string) (u *system.SysUser, err error) {
|
func (e *SystemIntegratedService) findUserByEmailOrPhone(
|
||||||
|
mail, userID string) (u *system.SysUser, err error) {
|
||||||
if mail != "" {
|
if mail != "" {
|
||||||
if u, err = e.findUserByEmail([]string{mail}); err == nil {
|
if u, err = e.findUserByEmail([]string{mail}); err == nil {
|
||||||
return u, nil
|
return u, nil
|
||||||
|
|||||||
@@ -398,7 +398,8 @@ func (s *ModelProviderService) getAvailableModelsFromProviderModelCredentials(pr
|
|||||||
Distinct("model_name").
|
Distinct("model_name").
|
||||||
Pluck("model_name", &modelNames).Error
|
Pluck("model_name", &modelNames).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
global.GVA_LOG.Warn("从 provider_model_credentials 拉取模型列表失败", zap.String("provider", providerName), zap.Error(err))
|
global.GVA_LOG.Warn("从 provider_model_credentials 拉取模型列表失败", zap.String(
|
||||||
|
"provider", providerName), zap.Error(err))
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
list := make([]gaiaResponse.ModelInfo, 0, len(modelNames))
|
list := make([]gaiaResponse.ModelInfo, 0, len(modelNames))
|
||||||
@@ -456,7 +457,8 @@ func (s *ModelProviderService) GetAvailableModelsFromDify(providerName string) (
|
|||||||
// 兼容两种响应格式:
|
// 兼容两种响应格式:
|
||||||
// 1) OpenAI: { "data": [ { "id": "..." }, ... ] }
|
// 1) OpenAI: { "data": [ { "id": "..." }, ... ] }
|
||||||
// 2) 通义: { "success": true, "output": { "models": [ { "model": "...", "name": "..." }, ... ] } }
|
// 2) 通义: { "success": true, "output": { "models": [ { "model": "...", "name": "..." }, ... ] } }
|
||||||
func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client, baseURL, apiKey string) ([]gaiaResponse.ModelInfo, error) {
|
func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client, baseURL, apiKey string) (
|
||||||
|
[]gaiaResponse.ModelInfo, error) {
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/v1/models"
|
url := strings.TrimSuffix(baseURL, "/") + "/v1/models"
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -471,7 +473,8 @@ func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client,
|
|||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
global.GVA_LOG.Warn("拉取模型列表接口非 200", zap.String("url", url), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
|
global.GVA_LOG.Warn("拉取模型列表接口非 200", zap.String("url", url), zap.Int(
|
||||||
|
"status", resp.StatusCode), zap.String("body", string(body)))
|
||||||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -510,7 +513,8 @@ func (s *ModelProviderService) fetchOpenAICompatibleModels(client *http.Client,
|
|||||||
|
|
||||||
// fetchGeminiModels 调用 Google Gemini GET /v1beta/models?key=API_KEY,解析 models[],支持分页。
|
// fetchGeminiModels 调用 Google Gemini GET /v1beta/models?key=API_KEY,解析 models[],支持分页。
|
||||||
// 认证使用 query 参数 key,响应格式:{ "models": [ { "name": "models/xxx", "baseModelId": "xxx", "displayName": "..." } ], "nextPageToken": "..." }
|
// 认证使用 query 参数 key,响应格式:{ "models": [ { "name": "models/xxx", "baseModelId": "xxx", "displayName": "..." } ], "nextPageToken": "..." }
|
||||||
func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, apiKey string) ([]gaiaResponse.ModelInfo, error) {
|
func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, apiKey string) (
|
||||||
|
[]gaiaResponse.ModelInfo, error) {
|
||||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||||
all := make([]gaiaResponse.ModelInfo, 0)
|
all := make([]gaiaResponse.ModelInfo, 0)
|
||||||
pageToken := ""
|
pageToken := ""
|
||||||
@@ -532,7 +536,9 @@ func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, a
|
|||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
global.GVA_LOG.Warn("拉取 Gemini 模型列表非 200", zap.String("url", baseURL+"/v1beta/models"), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
|
global.GVA_LOG.Warn("拉取 Gemini 模型列表非 200", zap.String(
|
||||||
|
"url", baseURL+"/v1beta/models"), zap.Int("status", resp.StatusCode), zap.String(
|
||||||
|
"body", string(body)))
|
||||||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -568,7 +574,8 @@ func (s *ModelProviderService) fetchGeminiModels(client *http.Client, baseURL, a
|
|||||||
|
|
||||||
// fetchAzureOpenAIModels 调用 Azure OpenAI GET {endpoint}/openai/models?api-version={version},解析 data[]。
|
// fetchAzureOpenAIModels 调用 Azure OpenAI GET {endpoint}/openai/models?api-version={version},解析 data[]。
|
||||||
// 认证使用 api-key 请求头,响应格式:{ "data": [ { "id": "...", "object": "model" } ] }
|
// 认证使用 api-key 请求头,响应格式:{ "data": [ { "id": "...", "object": "model" } ] }
|
||||||
func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseURL, apiKey, apiVersion string) ([]gaiaResponse.ModelInfo, error) {
|
func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseURL, apiKey, apiVersion string) (
|
||||||
|
[]gaiaResponse.ModelInfo, error) {
|
||||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||||
if apiVersion == "" {
|
if apiVersion == "" {
|
||||||
apiVersion = "2024-08-01-preview" // 默认 API 版本
|
apiVersion = "2024-08-01-preview" // 默认 API 版本
|
||||||
@@ -590,7 +597,8 @@ func (s *ModelProviderService) fetchAzureOpenAIModels(client *http.Client, baseU
|
|||||||
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
global.GVA_LOG.Warn("拉取 Azure OpenAI 模型列表非 200", zap.String("url", url), zap.Int("status", resp.StatusCode), zap.String("body", string(body)))
|
global.GVA_LOG.Warn("拉取 Azure OpenAI 模型列表非 200", zap.String("url", url), zap.Int(
|
||||||
|
"status", resp.StatusCode), zap.String("body", string(body)))
|
||||||
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
return nil, fmt.Errorf("接口返回 %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -920,7 +928,7 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// 记录日志
|
// 记录日志
|
||||||
log := gaia.ModelProxyLog{
|
global.GVA_DB.Create(&gaia.ModelProxyLog{
|
||||||
UserId: userID,
|
UserId: userID,
|
||||||
ProviderName: providerName,
|
ProviderName: providerName,
|
||||||
ModelName: req.Model,
|
ModelName: req.Model,
|
||||||
@@ -929,8 +937,7 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
|
|||||||
Status: status,
|
Status: status,
|
||||||
ErrorMessage: errorMsg,
|
ErrorMessage: errorMsg,
|
||||||
CreatedAt: startTime,
|
CreatedAt: startTime,
|
||||||
}
|
})
|
||||||
global.GVA_DB.Create(&log)
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 处理流式响应
|
// 处理流式响应
|
||||||
@@ -938,7 +945,7 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
|
|||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if _, err := writer.Write([]byte(line + "\n")); err != nil {
|
if _, err = writer.Write([]byte(line + "\n")); err != nil {
|
||||||
status = "error"
|
status = "error"
|
||||||
errorMsg = err.Error()
|
errorMsg = err.Error()
|
||||||
return err
|
return err
|
||||||
@@ -948,14 +955,14 @@ func (s *ModelProviderService) ProxyChat(userID string, req gaiaRequest.ChatRequ
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil {
|
if err = scanner.Err(); err != nil {
|
||||||
status = "error"
|
status = "error"
|
||||||
errorMsg = err.Error()
|
errorMsg = err.Error()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 非流式响应
|
// 非流式响应
|
||||||
if _, err := io.Copy(writer, resp.Body); err != nil {
|
if _, err = io.Copy(writer, resp.Body); err != nil {
|
||||||
status = "error"
|
status = "error"
|
||||||
errorMsg = err.Error()
|
errorMsg = err.Error()
|
||||||
return err
|
return err
|
||||||
|
|||||||
Reference in New Issue
Block a user