mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-12 18:11:42 +08:00
fix: 钉钉和快捷登录兼容
This commit is contained in:
@@ -3,6 +3,11 @@ from flask_restx import Resource
|
||||
|
||||
from controllers.console.app.error_extend import DingTalkNotExist
|
||||
from controllers.console.wraps import setup_required
|
||||
from libs.token import (
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from services.ding_talk_extend import DingTalkService
|
||||
|
||||
from .. import api
|
||||
@@ -17,10 +22,17 @@ class DingTalk(Resource):
|
||||
code = request.args.get("code", "")
|
||||
if not (0 < len(code) < 500):
|
||||
raise DingTalkNotExist
|
||||
token, err = DingTalkService.get_user_info(code)
|
||||
token_pair, redirect_url, err = DingTalkService.get_user_info(code)
|
||||
if len(err) > 0:
|
||||
raise DingTalkNotExist(err)
|
||||
return redirect(token)
|
||||
if token_pair is None:
|
||||
raise DingTalkNotExist("Failed to get token pair")
|
||||
|
||||
response = redirect(redirect_url)
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
return response
|
||||
|
||||
|
||||
class DingTalkThirdParty(Resource):
|
||||
@@ -32,10 +44,17 @@ class DingTalkThirdParty(Resource):
|
||||
code = request.args.get("authCode", "")
|
||||
if not (0 < len(code) < 500):
|
||||
raise DingTalkNotExist
|
||||
token, err = DingTalkService.user_third_party(code)
|
||||
token_pair, redirect_url, err = DingTalkService.user_third_party(code)
|
||||
if len(err) > 0:
|
||||
raise DingTalkNotExist(err)
|
||||
return redirect(token)
|
||||
if token_pair is None:
|
||||
raise DingTalkNotExist("Failed to get token pair")
|
||||
|
||||
response = redirect(redirect_url)
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(DingTalk, "/ding-talk/login")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Optional # Extend: OAuto third-party login
|
||||
|
||||
import httpx
|
||||
from flask import current_app, redirect, request
|
||||
@@ -13,7 +14,7 @@ from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OaOAuth, OAuthUserInfo # Extend: OAuto third-party login
|
||||
from libs.token import (
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
|
||||
@@ -3,6 +3,7 @@ import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from configs import dify_config # Extend OAuto third-party login
|
||||
from extensions.ext_database import db # Extend OAuto third-party login
|
||||
|
||||
@@ -1337,6 +1337,15 @@ class RegisterService:
|
||||
account.status = status or AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
# extend begin:初始化用户额度数据
|
||||
account_money_add = AccountMoneyExtend(
|
||||
account_id=account.id,
|
||||
used_quota=0,
|
||||
total_quota=dify_config.ACCOUNT_TOTAL_QUOTA,
|
||||
)
|
||||
db.session.add(account_money_add)
|
||||
# extend end:初始化用户额度数据
|
||||
|
||||
if open_id is not None and provider is not None:
|
||||
AccountService.link_account_integrate(provider, open_id, account)
|
||||
|
||||
|
||||
@@ -36,6 +36,194 @@ class DingTalkService:
|
||||
config.region_id = "central"
|
||||
return dingtalkoauth2_1_0Client(config)
|
||||
|
||||
@classmethod
|
||||
def extract_data(cls, dictionary: dict, path: str):
|
||||
"""
|
||||
从字典中提取指定路径的数据
|
||||
支持点号分隔的路径和数组索引
|
||||
|
||||
Args:
|
||||
dictionary (dict): 源字典
|
||||
path (str): 以点分隔的路径,如 "data.email" 或 "data[0].userName"
|
||||
|
||||
Returns:
|
||||
提取的数据,如果路径不存在返回None
|
||||
"""
|
||||
if not path:
|
||||
return None
|
||||
|
||||
import re
|
||||
|
||||
# 处理路径中的数组索引,如 data[0].userName -> data.[0].userName
|
||||
path = re.sub(r'\[(\d+)\]', r'.[\1]', path)
|
||||
parts = path.split('.')
|
||||
current = dictionary
|
||||
|
||||
for part in parts:
|
||||
if not part:
|
||||
continue
|
||||
|
||||
# 处理数组索引
|
||||
array_match = re.match(r'\[(\d+)\]', part)
|
||||
if array_match:
|
||||
index = int(array_match.group(1))
|
||||
if isinstance(current, list) and 0 <= index < len(current):
|
||||
current = current[index]
|
||||
else:
|
||||
return None
|
||||
elif isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return None
|
||||
|
||||
return current
|
||||
|
||||
@classmethod
|
||||
def get_email_from_third_party_api(cls, userid: str, integration: SystemIntegrationExtend) -> str:
|
||||
"""
|
||||
通过第三方API获取用户邮箱
|
||||
|
||||
Args:
|
||||
userid: 钉钉用户ID
|
||||
integration: 集成配置对象
|
||||
|
||||
Returns:
|
||||
邮箱地址,获取失败返回空字符串
|
||||
"""
|
||||
try:
|
||||
# 解析config字段
|
||||
if not integration.config:
|
||||
return ""
|
||||
|
||||
config_data = json.loads(integration.config)
|
||||
email_api_config = config_data.get("email_api", {})
|
||||
|
||||
# 检查是否启用
|
||||
if not email_api_config.get("enabled", False):
|
||||
return ""
|
||||
|
||||
# 获取配置参数
|
||||
api_url = email_api_config.get("url", "")
|
||||
method = email_api_config.get("method", "GET").upper()
|
||||
param_field = email_api_config.get("request_param_field", "userId")
|
||||
email_field = email_api_config.get("response_email_field", "data[0].userName")
|
||||
body_type = email_api_config.get("body_type", "raw")
|
||||
headers = email_api_config.get("headers", {})
|
||||
authorization = email_api_config.get("authorization", {})
|
||||
body_data = email_api_config.get("body_data", {})
|
||||
|
||||
if not api_url:
|
||||
logger.warning("Third-party email API URL is not configured")
|
||||
return ""
|
||||
|
||||
# 准备请求头
|
||||
request_headers = dict(headers) if headers else {}
|
||||
|
||||
# 处理Authorization
|
||||
auth = None
|
||||
auth_type = authorization.get("type", "none")
|
||||
if auth_type == "bearer":
|
||||
token = authorization.get("token", "")
|
||||
if token:
|
||||
request_headers["Authorization"] = f"Bearer {token}"
|
||||
elif auth_type == "basic":
|
||||
username = authorization.get("username", "")
|
||||
password = authorization.get("password", "")
|
||||
if username and password:
|
||||
from requests.auth import HTTPBasicAuth
|
||||
auth = HTTPBasicAuth(username, password)
|
||||
|
||||
# 构建请求数据
|
||||
request_data = {}
|
||||
|
||||
# 处理Body数据(仅POST/PUT/DELETE)
|
||||
if method in ["POST", "PUT", "DELETE"]:
|
||||
if body_type == "form-data":
|
||||
# form-data: 合并body_data中的form_data
|
||||
form_data_items = body_data.get("form_data", [])
|
||||
for item in form_data_items:
|
||||
if isinstance(item, dict) and "key" in item and "value" in item:
|
||||
key = item.get("key", "").strip()
|
||||
value = item.get("value", "").strip()
|
||||
if key:
|
||||
request_data[key] = value
|
||||
# 确保主请求字段的值始终是userid(覆盖body_data中的值)
|
||||
request_data[param_field] = userid
|
||||
# form-data使用data参数
|
||||
response = requests.request(
|
||||
method, api_url, data=request_data,
|
||||
headers=request_headers, auth=auth, timeout=10
|
||||
)
|
||||
elif body_type == "x-www-form-urlencoded":
|
||||
# x-www-form-urlencoded: 合并body_data中的urlencoded
|
||||
urlencoded_items = body_data.get("urlencoded", [])
|
||||
for item in urlencoded_items:
|
||||
if isinstance(item, dict) and "key" in item and "value" in item:
|
||||
key = item.get("key", "").strip()
|
||||
value = item.get("value", "").strip()
|
||||
if key:
|
||||
request_data[key] = value
|
||||
# 确保主请求字段的值始终是userid(覆盖body_data中的值)
|
||||
request_data[param_field] = userid
|
||||
# 确保Content-Type正确
|
||||
if "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
response = requests.request(
|
||||
method, api_url, data=request_data,
|
||||
headers=request_headers, auth=auth, timeout=10
|
||||
)
|
||||
else: # raw (JSON)
|
||||
# raw: 合并body_data中的raw JSON
|
||||
raw_json = body_data.get("raw", "")
|
||||
if raw_json:
|
||||
try:
|
||||
raw_data = json.loads(raw_json)
|
||||
if isinstance(raw_data, dict):
|
||||
request_data.update(raw_data)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse raw JSON body: %s", raw_json)
|
||||
# 确保主请求字段的值始终是userid(覆盖raw JSON中的值)
|
||||
request_data[param_field] = userid
|
||||
# 确保Content-Type正确
|
||||
if "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
response = requests.request(
|
||||
method, api_url, json=request_data,
|
||||
headers=request_headers, auth=auth, timeout=10
|
||||
)
|
||||
else: # GET请求
|
||||
# GET请求:所有数据作为URL参数
|
||||
response = requests.get(
|
||||
api_url, params=request_data,
|
||||
headers=request_headers, auth=auth, timeout=10
|
||||
)
|
||||
|
||||
# 检查响应
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Third-party email API returned status code: {response.status_code}")
|
||||
return ""
|
||||
|
||||
# 解析响应
|
||||
response_data = response.json()
|
||||
email = cls.extract_data(response_data, email_field)
|
||||
|
||||
if email and isinstance(email, str) and "@" in email:
|
||||
logger.info("Successfully retrieved email from third-party API for userid: %s", userid)
|
||||
return email
|
||||
else:
|
||||
logger.warning("Failed to extract valid email from response using path: %s", email_field)
|
||||
return ""
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to parse email API config: %s", e)
|
||||
return ""
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error("Failed to call third-party email API: %s", e)
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error in get_email_from_third_party_api: %s", e)
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def get_user_token(cls, code: str) -> (str, str):
|
||||
# get token
|
||||
@@ -94,6 +282,13 @@ class DingTalkService:
|
||||
|
||||
@classmethod
|
||||
def auto_create_user(cls, userid: str) -> (str, str):
|
||||
# 获取集成配置
|
||||
integration: SystemIntegrationExtend = (
|
||||
db.session.query(SystemIntegrationExtend).filter(
|
||||
SystemIntegrationExtend.status == True,
|
||||
SystemIntegrationExtend.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_DINGTALK).first()
|
||||
)
|
||||
|
||||
dingTalkToken, err = cls.get_access_token()
|
||||
responses = requests.post(
|
||||
f'https://oapi.dingtalk.com/topapi/v2/user/get?access_token={dingTalkToken}',
|
||||
@@ -107,9 +302,21 @@ class DingTalkService:
|
||||
return "", "Request for user information failed: " + userid + " " + json.dumps(reqs)
|
||||
# Check if the user exists
|
||||
username = reqs["result"]['name']
|
||||
email = f"{''.join(lazy_pinyin(username))}@{dify_config.EMAIL_DOMAIN}"
|
||||
if "email" in reqs["result"] and len(reqs["result"]["email"]):
|
||||
|
||||
# 优先尝试从第三方API获取邮箱
|
||||
email = ""
|
||||
if integration:
|
||||
email = cls.get_email_from_third_party_api(userid, integration)
|
||||
|
||||
# 降级处理:使用钉钉返回的邮箱
|
||||
if not email and "email" in reqs["result"] and len(reqs["result"]["email"]):
|
||||
email = reqs["result"]["email"]
|
||||
|
||||
# 最终降级:使用拼音生成邮箱
|
||||
if not email:
|
||||
email = f"{''.join(lazy_pinyin(username))}@{dify_config.EMAIL_DOMAIN}"
|
||||
logger.info("Using pinyin-generated email for user %s: %s", userid, email)
|
||||
|
||||
account: Account = (
|
||||
db.session.query(Account).filter(Account.email == email).first()
|
||||
)
|
||||
@@ -141,22 +348,26 @@ class DingTalkService:
|
||||
return token, ""
|
||||
|
||||
@classmethod
|
||||
def user_third_party(cls, code: str) -> (str, str):
|
||||
def user_third_party(cls, code: str):
|
||||
"""
|
||||
第三方钉钉登录
|
||||
返回: (token_pair, redirect_url, error)
|
||||
"""
|
||||
userToken, err = cls.get_user_token(code)
|
||||
|
||||
if err != "":
|
||||
return "", f"Failed to obtain token: {err}"
|
||||
return None, "", f"Failed to obtain token: {err}"
|
||||
response = requests.get(
|
||||
"https://api.dingtalk.com/v1.0/contact/users/me",
|
||||
headers={"x-acs-dingtalk-access-token": userToken},
|
||||
)
|
||||
# Check the response status code
|
||||
if response.status_code != 200:
|
||||
return "", f"Request failed, status code: {response.status_code}, msg: {response.text}"
|
||||
return None, "", f"Request failed, status code: {response.status_code}, msg: {response.text}"
|
||||
# Print the response content
|
||||
req = response.json()
|
||||
if "statusCode" in req.keys() and req["statusCode"] != 200:
|
||||
return "", f"Request failed, msg: {req.message}"
|
||||
return None, "", f"Request failed, msg: {req.message}"
|
||||
# 提取userid
|
||||
dingTalkToken, err = cls.get_access_token()
|
||||
unionIdResponse = requests.post(
|
||||
@@ -165,37 +376,43 @@ class DingTalkService:
|
||||
)
|
||||
# Check the response status code
|
||||
if unionIdResponse.status_code != 200:
|
||||
return "", f"unionIdResponse failed, status code: {unionIdResponse.status_code}, msg: {unionIdResponse.text}"
|
||||
return None, "", f"unionIdResponse failed, status code: {unionIdResponse.status_code}, msg: {unionIdResponse.text}"
|
||||
# Print the response content
|
||||
unionIdReq = unionIdResponse.json()
|
||||
if unionIdReq["errcode"] != 0:
|
||||
return "", f"Request failed, msg: {unionIdReq['errmsg']}"
|
||||
return None, "", f"Request failed, msg: {unionIdReq['errmsg']}"
|
||||
|
||||
token, err = cls.auto_create_user(unionIdReq["result"]["userid"])
|
||||
token_pair, err = cls.auto_create_user(unionIdReq["result"]["userid"])
|
||||
if len(err) > 0:
|
||||
return "", "Request failed: " + err
|
||||
|
||||
return f"{dify_config.CONSOLE_WEB_URL}/explore/apps-center-extend?console_token={token.access_token}&&refresh_token={token.refresh_token}", ""
|
||||
return None, "", "Request failed: " + err
|
||||
|
||||
redirect_url = f"{dify_config.CONSOLE_WEB_URL}/explore/apps-center-extend"
|
||||
return token_pair, redirect_url, ""
|
||||
|
||||
@classmethod
|
||||
def get_user_info(cls, code: str) -> (str, str):
|
||||
def get_user_info(cls, code: str):
|
||||
"""
|
||||
获取用户信息并登录
|
||||
返回: (token_pair, redirect_url, error)
|
||||
"""
|
||||
host = "https://oapi.dingtalk.com/topapi/v2/user"
|
||||
token, err = cls.get_access_token()
|
||||
if err != "":
|
||||
return "", f"Failed to obtain token: {err}"
|
||||
return None, "", f"Failed to obtain token: {err}"
|
||||
response = requests.post(
|
||||
f"{host}/getuserinfo?access_token={token}",
|
||||
json={"code": code},
|
||||
)
|
||||
# Check the response status code
|
||||
if response.status_code != 200:
|
||||
return "", f"Request failed, status code: {response.status_code}"
|
||||
return None, "", f"Request failed, status code: {response.status_code}"
|
||||
# Print the response content
|
||||
req = response.json()
|
||||
if req["errcode"] != 0:
|
||||
return "", "Request failed: " + req["errmsg"]
|
||||
token, err = cls.auto_create_user(req["result"]["userid"])
|
||||
return None, "", "Request failed: " + req["errmsg"]
|
||||
token_pair, err = cls.auto_create_user(req["result"]["userid"])
|
||||
if len(err) != 0:
|
||||
return "", "Request failed: " + err
|
||||
return None, "", "Request failed: " + err
|
||||
|
||||
return f"{dify_config.CONSOLE_WEB_URL}/explore/apps-center-extend?console_token={token.access_token}&&refresh_token={token.refresh_token}", ""
|
||||
redirect_url = f"{dify_config.CONSOLE_WEB_URL}/explore/apps-center-extend"
|
||||
return token_pair, redirect_url, ""
|
||||
|
||||
@@ -32,8 +32,11 @@ class RecommendedAppService:
|
||||
classList = app.tags
|
||||
description = app.description
|
||||
config = app.app_model_config
|
||||
# Extend: start Handle apps without tags
|
||||
if len(classList) == 0:
|
||||
classList.append(Tag(name="未分类"))
|
||||
# Create a simple object with name attribute for "未分类" category
|
||||
classList.append(type('Tag', (), {'name': '未分类'})())
|
||||
# Extend: stop Handle apps without tags
|
||||
if (
|
||||
len(description) == 0
|
||||
and config is not None
|
||||
@@ -60,6 +63,7 @@ class RecommendedAppService:
|
||||
"icon_background": app.icon_background,
|
||||
},
|
||||
"app_id": installed_app.app_id,
|
||||
"installed_id": installed_app.id,
|
||||
"description": description,
|
||||
"copyright": "",
|
||||
"privacy_policy": "",
|
||||
|
||||
Reference in New Issue
Block a user