Files
dify-plus/api/controllers/console/auth/oauth.py
T
npc0-hue c507fc2675 feat:-皮肤处理和oauth快捷登录
- 同步至应用模板
- 取消同步至应用模板
- 计费唯一索引
2025-10-28 11:53:05 +08:00

234 lines
9.8 KiB
Python

import logging
from typing import Optional
import requests
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
from configs import dify_config
from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OaOAuth, OAuthUserInfo
from models import Account
from models.account import AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService
from .. import api
logger = logging.getLogger(__name__)
def get_oauth_providers():
    """Build the provider-name -> OAuth client mapping used by the console login.

    The GitHub and Google clients are created only when both their client id
    and client secret are set in ``dify_config``; otherwise the entry is
    ``None``. The extended ``oauth2`` client is always created, with empty
    constructor credentials.

    Returns:
        A dict with the keys ``"github"``, ``"google"`` and ``"oauth2"``.
    """
    with current_app.app_context():
        github_client = (
            GitHubOAuth(
                client_id=dify_config.GITHUB_CLIENT_ID,
                client_secret=dify_config.GITHUB_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github",
            )
            if dify_config.GITHUB_CLIENT_ID and dify_config.GITHUB_CLIENT_SECRET
            else None
        )
        google_client = (
            GoogleOAuth(
                client_id=dify_config.GOOGLE_CLIENT_ID,
                client_secret=dify_config.GOOGLE_CLIENT_SECRET,
                redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
            )
            if dify_config.GOOGLE_CLIENT_ID and dify_config.GOOGLE_CLIENT_SECRET
            else None
        )
        # Extend: generic oauth2 provider, constructed with empty credentials.
        generic_client = OaOAuth(client_id='', client_secret='', redirect_uri='')
        return {"github": github_client, "google": google_client, "oauth2": generic_client}
class OAuthLogin(Resource):
    """Console endpoint that starts an OAuth login for a named provider."""

    def get(self, provider: str):
        """Redirect the browser to the provider's authorization URL.

        Args:
            provider: Key into the mapping returned by ``get_oauth_providers``.

        Returns:
            A redirect to the authorization URL, or ``({"error": ...}, 400)``
            when the provider is unknown or not configured.
        """
        # An empty ``invite_token`` query value is normalised to None.
        invite_token = request.args.get("invite_token") or None
        providers = get_oauth_providers()
        with current_app.app_context():
            client = providers.get(provider)
        if client:
            return redirect(client.get_authorization_url(invite_token=invite_token))
        return {"error": "Invalid provider"}, 400
class OAuthCallback(Resource):
    """Console endpoint the OAuth provider redirects back to after authorization."""

    def get(self, provider: str):
        """Finish the OAuth login for ``provider`` and redirect to the console web UI.

        Query parameters read: ``code`` (authorization code), ``state`` (used as
        the invite token) and, only when ``code`` is absent, ``access_token``.

        Args:
            provider: Key into the mapping returned by ``get_oauth_providers``.

        Returns:
            A redirect to ``CONSOLE_WEB_URL`` (with tokens on success, or a
            ``message`` on account/workspace problems), or an
            ``({"error": ...}, status)`` tuple when the provider is unknown, no
            code/token was supplied, or the exchange with the provider failed.
        """
        OAUTH_PROVIDERS = get_oauth_providers()
        with current_app.app_context():
            oauth_provider = OAUTH_PROVIDERS.get(provider)
        if not oauth_provider:
            return {"error": "Invalid provider"}, 400

        # Extend: Casdoor compatibility -- id_token is forwarded to the web UI.
        id_token = None
        code = request.args.get("code")
        state = request.args.get("state")

        # Fallback: some providers may return tokens directly in the query
        # (implicit/hybrid flow).
        token_from_query: Optional[str] = None
        if not code:
            token_from_query = request.args.get("access_token")
            if not token_from_query:
                return {"error": "Missing authorization code"}, 400
            logger.warning(
                "oauth.callback_no_code_but_token",
                extra={
                    "provider": provider,
                    # The query string carries the access token in this flow, so
                    # only the URL without its query string is logged.
                    "full_url": request.base_url,
                    "note": "Using access_token from query as fallback. Prefer Authorization Code flow.",
                },
            )

        # The OAuth ``state`` value carries the invite token, when there is one.
        invite_token = state or None

        try:
            if token_from_query is not None:
                token = token_from_query
            else:
                # Extend: Start Casdoor compatibility
                response_json = oauth_provider.get_access_token(code)
                token = response_json.get("access_token")
                if not token:
                    return {"error": f"Error in OAuth: {response_json}"}, 502
                id_token = response_json.get("id_token")
                # Extend: Stop Casdoor compatibility

            # Check that the token is usable before calling the provider.
            if not token or token.strip() == "":
                logger.error("OAuth2 access token is empty for provider %s", provider)
                return {"error": "OAuth2 access token is empty or invalid"}, 400

            user_info = oauth_provider.get_user_info(token)
        except requests.RequestException as e:
            # A requests ``Response`` is falsy for every 4xx/5xx status (its
            # truth value is ``Response.ok``), so compare against None instead
            # of testing truthiness; otherwise HTTP error bodies and status
            # codes are never reported.
            failed_response = e.response
            if failed_response is not None:
                error_text = failed_response.text
                error_status = failed_response.status_code
            else:
                error_text = str(e)
                error_status = "Unknown"
            logger.exception(
                "An error occurred during the OAuth process with %s (Status: %s): %s",
                provider,
                error_status,
                error_text,
            )
            # Give the caller a more specific error message.
            if error_status == 401:
                return {"error": f"OAuth2 authentication failed (401 Unauthorized): {error_text}"}, 400
            elif error_status == 400:
                return {"error": f"OAuth2 bad request (400): {error_text}"}, 400
            else:
                return {"error": f"OAuth process failed (Status: {error_status}): {error_text}"}, 400
        except ValueError as e:
            logger.exception("OAuth2 configuration error for provider %s: %s", provider, str(e))
            return {"error": f"OAuth2 configuration error: {str(e)}"}, 400

        # Invitation flow: the invited e-mail must match the OAuth identity.
        if invite_token and RegisterService.is_valid_invite_token(invite_token):
            invitation = RegisterService._get_invitation_by_token(token=invite_token)
            if invitation:
                invitation_email = invitation.get("email", None)
                if invitation_email != user_info.email:
                    return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")

            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")

        try:
            account = _generate_account(provider, user_info)
        except AccountNotFoundError:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
        except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
            return redirect(
                f"{dify_config.CONSOLE_WEB_URL}/signin"
                "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
            )
        except AccountRegisterError as e:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={e.description}")

        # Check account status
        if account.status == AccountStatus.BANNED.value:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")

        # A pending account is activated by its first successful OAuth login.
        if account.status == AccountStatus.PENDING.value:
            account.status = AccountStatus.ACTIVE.value
            account.initialized_at = naive_utc_now()
            db.session.commit()

        try:
            TenantService.create_owner_tenant_if_not_exist(account)
        except Unauthorized:
            return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
        except WorkSpaceNotAllowedCreateError:
            return redirect(
                f"{dify_config.CONSOLE_WEB_URL}/signin"
                "?message=Workspace not found, please contact system admin to invite you to join in a workspace."
            )

        token_pair = AccountService.login(
            account=account,
            ip_address=extract_remote_ip(request),
        )

        # Extend: Casdoor compatibility -- also pass the provider's id_token.
        # NOTE(review): when no id_token was obtained this renders as the literal
        # "id_token=None"; confirm the web UI expects that before changing it.
        return redirect(
            f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}"
            f"&refresh_token={token_pair.refresh_token}&id_token={id_token}"
        )
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
    """Find the account for an OAuth identity, falling back to its e-mail.

    Args:
        provider: OAuth provider name passed to ``Account.get_by_openid``.
        user_info: Identity returned by the provider; ``id`` and ``email`` are read.

    Returns:
        The account found by open id if there is one, otherwise the account
        whose ``email`` equals ``user_info.email``, otherwise ``None``.
    """
    linked: Optional[Account] = Account.get_by_openid(provider, user_info.id)
    if linked:
        return linked

    # No account linked to this open id yet: look it up by e-mail instead.
    with Session(db.engine) as session:
        by_email = select(Account).filter_by(email=user_info.email)
        return session.execute(by_email).scalar_one_or_none()
def _generate_account(provider: str, user_info: OAuthUserInfo):
    """Return the account for an OAuth identity, registering it when needed.

    An existing account that has joined no workspace gets a new one it owns. An
    unknown identity is registered and given an interface language taken from
    the request's ``Accept-Language`` header. In both cases the account is then
    linked to the provider identity.

    Args:
        provider: OAuth provider name.
        user_info: Identity returned by the provider (``id``, ``name``, ``email``).

    Returns:
        The existing or newly registered account.

    Raises:
        WorkSpaceNotAllowedCreateError: the account has no workspace and
            workspace creation is disabled.
        AccountNotFoundError: no account exists and registration is disabled.
    """
    # Get account by openid or email.
    account = _get_account_by_openid_or_email(provider, user_info)

    if account:
        # Existing account: make sure it belongs to at least one workspace.
        if not TenantService.get_join_tenants(account):
            if not FeatureService.get_system_features().is_allow_create_workspace:
                raise WorkSpaceNotAllowedCreateError()
            workspace = TenantService.create_tenant(f"{account.name}'s Workspace")
            TenantService.create_tenant_member(workspace, account, role="owner")
            account.current_tenant = workspace
            tenant_was_created.send(workspace)
    else:
        # Unknown identity: register it, if registration is allowed.
        if not FeatureService.get_system_features().is_allow_register:
            raise AccountNotFoundError()
        display_name = user_info.name or "Dify"
        account = RegisterService.register(
            email=user_info.email,
            name=display_name,
            password=None,
            open_id=user_info.id,
            provider=provider,
        )

        # Set interface language, defaulting to the first supported language.
        preferred_lang = request.accept_languages.best_match(languages)
        account.interface_language = (
            preferred_lang if preferred_lang and preferred_lang in languages else languages[0]
        )
        db.session.commit()

    # Link account
    AccountService.link_account_integrate(provider, user_info.id, account)
    return account
# Register the OAuth endpoints on the console API. The "/oauth/authorize/<provider>"
# path is the one the GitHub/Google redirect URIs built in get_oauth_providers end with.
api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")