fix: 钉钉和快捷登录兼容

This commit is contained in:
npc0-hue
2026-01-26 00:25:06 +08:00
parent 8284f9c3c8
commit 4807f03e0a
24 changed files with 1359 additions and 339 deletions
@@ -3,6 +3,11 @@ from flask_restx import Resource
from controllers.console.app.error_extend import DingTalkNotExist
from controllers.console.wraps import setup_required
from libs.token import (
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.ding_talk_extend import DingTalkService
from .. import api
@@ -17,10 +22,17 @@ class DingTalk(Resource):
code = request.args.get("code", "")
if not (0 < len(code) < 500):
raise DingTalkNotExist
token, err = DingTalkService.get_user_info(code)
token_pair, redirect_url, err = DingTalkService.get_user_info(code)
if len(err) > 0:
raise DingTalkNotExist(err)
return redirect(token)
if token_pair is None:
raise DingTalkNotExist("Failed to get token pair")
response = redirect(redirect_url)
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
class DingTalkThirdParty(Resource):
@@ -32,10 +44,17 @@ class DingTalkThirdParty(Resource):
code = request.args.get("authCode", "")
if not (0 < len(code) < 500):
raise DingTalkNotExist
token, err = DingTalkService.user_third_party(code)
token_pair, redirect_url, err = DingTalkService.user_third_party(code)
if len(err) > 0:
raise DingTalkNotExist(err)
return redirect(token)
if token_pair is None:
raise DingTalkNotExist("Failed to get token pair")
response = redirect(redirect_url)
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
api.add_resource(DingTalk, "/ding-talk/login")
+2 -1
View File
@@ -1,4 +1,5 @@
import logging
from typing import Optional # Extend: OAuto third-party login
import httpx
from flask import current_app, redirect, request
@@ -13,7 +14,7 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from libs.oauth import GitHubOAuth, GoogleOAuth, OaOAuth, OAuthUserInfo # Extend: OAuto third-party login
from libs.token import (
set_access_token_to_cookie,
set_csrf_token_to_cookie,
+1
View File
@@ -3,6 +3,7 @@ import urllib.parse
from dataclasses import dataclass
import httpx
import requests
from configs import dify_config # Extend OAuto third-party login
from extensions.ext_database import db # Extend OAuto third-party login
+9
View File
@@ -1337,6 +1337,15 @@ class RegisterService:
account.status = status or AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()
# extend begin:初始化用户额度数据
account_money_add = AccountMoneyExtend(
account_id=account.id,
used_quota=0,
total_quota=dify_config.ACCOUNT_TOTAL_QUOTA,
)
db.session.add(account_money_add)
# extend end:初始化用户额度数据
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
+236 -19
View File
@@ -36,6 +36,194 @@ class DingTalkService:
config.region_id = "central"
return dingtalkoauth2_1_0Client(config)
@classmethod
def extract_data(cls, dictionary: dict, path: str):
"""
从字典中提取指定路径的数据
支持点号分隔的路径和数组索引
Args:
dictionary (dict): 源字典
path (str): 以点分隔的路径 "data.email" "data[0].userName"
Returns:
提取的数据如果路径不存在返回None
"""
if not path:
return None
import re
# 处理路径中的数组索引,如 data[0].userName -> data.[0].userName
path = re.sub(r'\[(\d+)\]', r'.[\1]', path)
parts = path.split('.')
current = dictionary
for part in parts:
if not part:
continue
# 处理数组索引
array_match = re.match(r'\[(\d+)\]', part)
if array_match:
index = int(array_match.group(1))
if isinstance(current, list) and 0 <= index < len(current):
current = current[index]
else:
return None
elif isinstance(current, dict) and part in current:
current = current[part]
else:
return None
return current
@classmethod
def get_email_from_third_party_api(cls, userid: str, integration: SystemIntegrationExtend) -> str:
"""
通过第三方API获取用户邮箱
Args:
userid: 钉钉用户ID
integration: 集成配置对象
Returns:
邮箱地址获取失败返回空字符串
"""
try:
# 解析config字段
if not integration.config:
return ""
config_data = json.loads(integration.config)
email_api_config = config_data.get("email_api", {})
# 检查是否启用
if not email_api_config.get("enabled", False):
return ""
# 获取配置参数
api_url = email_api_config.get("url", "")
method = email_api_config.get("method", "GET").upper()
param_field = email_api_config.get("request_param_field", "userId")
email_field = email_api_config.get("response_email_field", "data[0].userName")
body_type = email_api_config.get("body_type", "raw")
headers = email_api_config.get("headers", {})
authorization = email_api_config.get("authorization", {})
body_data = email_api_config.get("body_data", {})
if not api_url:
logger.warning("Third-party email API URL is not configured")
return ""
# 准备请求头
request_headers = dict(headers) if headers else {}
# 处理Authorization
auth = None
auth_type = authorization.get("type", "none")
if auth_type == "bearer":
token = authorization.get("token", "")
if token:
request_headers["Authorization"] = f"Bearer {token}"
elif auth_type == "basic":
username = authorization.get("username", "")
password = authorization.get("password", "")
if username and password:
from requests.auth import HTTPBasicAuth
auth = HTTPBasicAuth(username, password)
# 构建请求数据
request_data = {}
# 处理Body数据(仅POST/PUT/DELETE
if method in ["POST", "PUT", "DELETE"]:
if body_type == "form-data":
# form-data: 合并body_data中的form_data
form_data_items = body_data.get("form_data", [])
for item in form_data_items:
if isinstance(item, dict) and "key" in item and "value" in item:
key = item.get("key", "").strip()
value = item.get("value", "").strip()
if key:
request_data[key] = value
# 确保主请求字段的值始终是userid(覆盖body_data中的值)
request_data[param_field] = userid
# form-data使用data参数
response = requests.request(
method, api_url, data=request_data,
headers=request_headers, auth=auth, timeout=10
)
elif body_type == "x-www-form-urlencoded":
# x-www-form-urlencoded: 合并body_data中的urlencoded
urlencoded_items = body_data.get("urlencoded", [])
for item in urlencoded_items:
if isinstance(item, dict) and "key" in item and "value" in item:
key = item.get("key", "").strip()
value = item.get("value", "").strip()
if key:
request_data[key] = value
# 确保主请求字段的值始终是userid(覆盖body_data中的值)
request_data[param_field] = userid
# 确保Content-Type正确
if "Content-Type" not in request_headers:
request_headers["Content-Type"] = "application/x-www-form-urlencoded"
response = requests.request(
method, api_url, data=request_data,
headers=request_headers, auth=auth, timeout=10
)
else: # raw (JSON)
# raw: 合并body_data中的raw JSON
raw_json = body_data.get("raw", "")
if raw_json:
try:
raw_data = json.loads(raw_json)
if isinstance(raw_data, dict):
request_data.update(raw_data)
except json.JSONDecodeError:
logger.warning("Failed to parse raw JSON body: %s", raw_json)
# 确保主请求字段的值始终是userid(覆盖raw JSON中的值)
request_data[param_field] = userid
# 确保Content-Type正确
if "Content-Type" not in request_headers:
request_headers["Content-Type"] = "application/json"
response = requests.request(
method, api_url, json=request_data,
headers=request_headers, auth=auth, timeout=10
)
else: # GET请求
# GET请求:所有数据作为URL参数
response = requests.get(
api_url, params=request_data,
headers=request_headers, auth=auth, timeout=10
)
# 检查响应
if response.status_code != 200:
logger.error(f"Third-party email API returned status code: {response.status_code}")
return ""
# 解析响应
response_data = response.json()
email = cls.extract_data(response_data, email_field)
if email and isinstance(email, str) and "@" in email:
logger.info("Successfully retrieved email from third-party API for userid: %s", userid)
return email
else:
logger.warning("Failed to extract valid email from response using path: %s", email_field)
return ""
except json.JSONDecodeError as e:
logger.error("Failed to parse email API config: %s", e)
return ""
except requests.exceptions.RequestException as e:
logger.error("Failed to call third-party email API: %s", e)
return ""
except Exception as e:
logger.error("Unexpected error in get_email_from_third_party_api: %s", e)
return ""
@classmethod
def get_user_token(cls, code: str) -> (str, str):
# get token
@@ -94,6 +282,13 @@ class DingTalkService:
@classmethod
def auto_create_user(cls, userid: str) -> (str, str):
# 获取集成配置
integration: SystemIntegrationExtend = (
db.session.query(SystemIntegrationExtend).filter(
SystemIntegrationExtend.status == True,
SystemIntegrationExtend.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_DINGTALK).first()
)
dingTalkToken, err = cls.get_access_token()
responses = requests.post(
f'https://oapi.dingtalk.com/topapi/v2/user/get?access_token={dingTalkToken}',
@@ -107,9 +302,21 @@ class DingTalkService:
return "", "Request for user information failed: " + userid + " " + json.dumps(reqs)
# Check if the user exists
username = reqs["result"]['name']
email = f"{''.join(lazy_pinyin(username))}@{dify_config.EMAIL_DOMAIN}"
if "email" in reqs["result"] and len(reqs["result"]["email"]):
# 优先尝试从第三方API获取邮箱
email = ""
if integration:
email = cls.get_email_from_third_party_api(userid, integration)
# 降级处理:使用钉钉返回的邮箱
if not email and "email" in reqs["result"] and len(reqs["result"]["email"]):
email = reqs["result"]["email"]
# 最终降级:使用拼音生成邮箱
if not email:
email = f"{''.join(lazy_pinyin(username))}@{dify_config.EMAIL_DOMAIN}"
logger.info("Using pinyin-generated email for user %s: %s", userid, email)
account: Account = (
db.session.query(Account).filter(Account.email == email).first()
)
@@ -141,22 +348,26 @@ class DingTalkService:
return token, ""
@classmethod
def user_third_party(cls, code: str) -> (str, str):
def user_third_party(cls, code: str):
"""
第三方钉钉登录
返回: (token_pair, redirect_url, error)
"""
userToken, err = cls.get_user_token(code)
if err != "":
return "", f"Failed to obtain token: {err}"
return None, "", f"Failed to obtain token: {err}"
response = requests.get(
"https://api.dingtalk.com/v1.0/contact/users/me",
headers={"x-acs-dingtalk-access-token": userToken},
)
# Check the response status code
if response.status_code != 200:
return "", f"Request failed, status code: {response.status_code}, msg: {response.text}"
return None, "", f"Request failed, status code: {response.status_code}, msg: {response.text}"
# Print the response content
req = response.json()
if "statusCode" in req.keys() and req["statusCode"] != 200:
return "", f"Request failed, msg: {req.message}"
return None, "", f"Request failed, msg: {req.message}"
# 提取userid
dingTalkToken, err = cls.get_access_token()
unionIdResponse = requests.post(
@@ -165,37 +376,43 @@ class DingTalkService:
)
# Check the response status code
if unionIdResponse.status_code != 200:
return "", f"unionIdResponse failed, status code: {unionIdResponse.status_code}, msg: {unionIdResponse.text}"
return None, "", f"unionIdResponse failed, status code: {unionIdResponse.status_code}, msg: {unionIdResponse.text}"
# Print the response content
unionIdReq = unionIdResponse.json()
if unionIdReq["errcode"] != 0:
return "", f"Request failed, msg: {unionIdReq['errmsg']}"
return None, "", f"Request failed, msg: {unionIdReq['errmsg']}"
token, err = cls.auto_create_user(unionIdReq["result"]["userid"])
token_pair, err = cls.auto_create_user(unionIdReq["result"]["userid"])
if len(err) > 0:
return "", "Request failed: " + err
return f"{dify_config.CONSOLE_WEB_URL}/explore/apps-center-extend?console_token={token.access_token}&&refresh_token={token.refresh_token}", ""
return None, "", "Request failed: " + err
redirect_url = f"{dify_config.CONSOLE_WEB_URL}/explore/apps-center-extend"
return token_pair, redirect_url, ""
@classmethod
def get_user_info(cls, code: str) -> (str, str):
def get_user_info(cls, code: str):
"""
获取用户信息并登录
返回: (token_pair, redirect_url, error)
"""
host = "https://oapi.dingtalk.com/topapi/v2/user"
token, err = cls.get_access_token()
if err != "":
return "", f"Failed to obtain token: {err}"
return None, "", f"Failed to obtain token: {err}"
response = requests.post(
f"{host}/getuserinfo?access_token={token}",
json={"code": code},
)
# Check the response status code
if response.status_code != 200:
return "", f"Request failed, status code: {response.status_code}"
return None, "", f"Request failed, status code: {response.status_code}"
# Print the response content
req = response.json()
if req["errcode"] != 0:
return "", "Request failed: " + req["errmsg"]
token, err = cls.auto_create_user(req["result"]["userid"])
return None, "", "Request failed: " + req["errmsg"]
token_pair, err = cls.auto_create_user(req["result"]["userid"])
if len(err) != 0:
return "", "Request failed: " + err
return None, "", "Request failed: " + err
return f"{dify_config.CONSOLE_WEB_URL}/explore/apps-center-extend?console_token={token.access_token}&&refresh_token={token.refresh_token}", ""
redirect_url = f"{dify_config.CONSOLE_WEB_URL}/explore/apps-center-extend"
return token_pair, redirect_url, ""
@@ -32,8 +32,11 @@ class RecommendedAppService:
classList = app.tags
description = app.description
config = app.app_model_config
# Extend: start Handle apps without tags
if len(classList) == 0:
classList.append(Tag(name="未分类"))
# Create a simple object with name attribute for "未分类" category
classList.append(type('Tag', (), {'name': '未分类'})())
# Extend: stop Handle apps without tags
if (
len(description) == 0
and config is not None
@@ -60,6 +63,7 @@ class RecommendedAppService:
"icon_background": app.icon_background,
},
"app_id": installed_app.app_id,
"installed_id": installed_app.id,
"description": description,
"copyright": "",
"privacy_policy": "",