mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-04 10:14:00 +08:00
fix: web app
This commit is contained in:
@@ -750,3 +750,4 @@ BEDROCK_PROXY=
|
||||
|
||||
# 初始额度
|
||||
ACCOUNT_TOTAL_QUOTA=15
|
||||
ENTERPRISE_ENABLED=True
|
||||
|
||||
@@ -45,20 +45,33 @@ from controllers.web.error_extend import (
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from models.account_money_extend import AccountMoneyExtend
|
||||
from services.account_service import AccountService
|
||||
from services.app_generate_service_extend import AppGenerateServiceExtend
|
||||
|
||||
|
||||
def is_end_login(end_user):
|
||||
user_info = None
|
||||
try:
|
||||
auth_token = request.headers.get("Authorization-extend")
|
||||
# 从 cookie 中读取 access_token
|
||||
auth_token = extract_access_token(request)
|
||||
if not auth_token:
|
||||
return None
|
||||
|
||||
# 验证 access_token
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_info = AccountService.load_logged_in_account(account_id=decoded.get("user_id"))
|
||||
user_id = decoded.get("user_id")
|
||||
|
||||
# 加载 Console 用户信息
|
||||
user_info = AccountService.load_logged_in_account(account_id=user_id)
|
||||
|
||||
# 绑定 end_user 与 Console 用户
|
||||
if user_info is not None:
|
||||
if end_user.external_user_id is None:
|
||||
end_user.external_user_id = decoded.get("user_id")
|
||||
except:
|
||||
end_user.external_user_id = user_id
|
||||
db.session.commit() # 提交绑定关系
|
||||
except Exception:
|
||||
logging.exception("load_logged_in_account error")
|
||||
pass
|
||||
# no login
|
||||
@@ -152,6 +165,11 @@ class CompletionApi(WebApiResource):
|
||||
streaming = payload.response_mode == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
# extend 获取 Console 用户 ID,直接作为 from_account_id 传递
|
||||
user_info = is_end_login(end_user)
|
||||
if user_info:
|
||||
args["account_id"] = user_info.id
|
||||
|
||||
try:
|
||||
AppGenerateServiceExtend.calculate_cumulative_usage(
|
||||
app_model=app_model,
|
||||
@@ -251,6 +269,11 @@ class ChatApi(WebApiResource):
|
||||
streaming = payload.response_mode == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
# 获取 Console 用户 ID,直接作为 from_account_id 传递
|
||||
user_info = is_end_login(end_user)
|
||||
if user_info:
|
||||
args["account_id"] = user_info.id
|
||||
|
||||
try:
|
||||
AppGenerateServiceExtend.calculate_cumulative_usage(
|
||||
app_model=app_model,
|
||||
|
||||
@@ -23,6 +23,7 @@ from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
clear_webapp_access_token_from_cookie,
|
||||
extract_access_token,
|
||||
extract_webapp_access_token,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
@@ -88,11 +89,23 @@ class LoginStatusApi(Resource):
|
||||
def get(self):
|
||||
app_code = request.args.get("app_code")
|
||||
user_id = request.args.get("user_id")
|
||||
|
||||
# 检查 Console 用户的 access_token cookie
|
||||
console_token = extract_access_token(request)
|
||||
console_user_logged_in = False
|
||||
if console_token:
|
||||
try:
|
||||
PassportService().verify(console_token)
|
||||
console_user_logged_in = True
|
||||
except Exception:
|
||||
console_user_logged_in = False
|
||||
|
||||
token = extract_webapp_access_token(request)
|
||||
if not app_code:
|
||||
return {
|
||||
"logged_in": bool(token),
|
||||
"app_logged_in": False,
|
||||
"console_logged_in": console_user_logged_in,
|
||||
}
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
is_public = not dify_config.ENTERPRISE_ENABLED or not WebAppAuthService.is_app_require_permission_check(
|
||||
@@ -118,6 +131,7 @@ class LoginStatusApi(Resource):
|
||||
return {
|
||||
"logged_in": user_logged_in,
|
||||
"app_logged_in": app_logged_in,
|
||||
"console_logged_in": console_user_logged_in,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -86,6 +86,11 @@ class WorkflowRunApi(WebApiResource):
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# extend: 获取 Console 用户 ID,直接作为 from_account_id 传递
|
||||
user_info = is_end_login(end_user)
|
||||
if user_info:
|
||||
args["account_id"] = user_info.id
|
||||
|
||||
try:
|
||||
AppGenerateServiceExtend.calculate_cumulative_usage(
|
||||
app_model=app_model,
|
||||
|
||||
@@ -135,6 +135,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
extras["app_token_id"] = api_token.id
|
||||
# ------------------- 二开部分End - 密钥额度限制 -------------------
|
||||
|
||||
# extend: 如果 args 中有 account_id(Web App 登录用户),将其放入 extras
|
||||
if args.get("account_id"):
|
||||
extras["account_id"] = args.get("account_id")
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
conversation_id = args.get("conversation_id")
|
||||
|
||||
@@ -98,6 +98,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
|
||||
# extend: 如果 args 中有 account_id(Web App 登录用户),将其放入 extras
|
||||
if args.get("account_id"):
|
||||
extras["account_id"] = args.get("account_id")
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
|
||||
@@ -90,6 +90,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
|
||||
# extend: 如果 args 中有 account_id(Web App 登录用户),将其放入 extras
|
||||
if args.get("account_id"):
|
||||
extras["account_id"] = args.get("account_id")
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
|
||||
@@ -130,6 +130,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||
)
|
||||
|
||||
# extend: 如果 args 中有 account_id(Web App 登录用户),将其放入 extras
|
||||
extras = {}
|
||||
if args.get("account_id"):
|
||||
extras["account_id"] = args.get("account_id")
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = CompletionAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
@@ -144,7 +149,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
extras={},
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
|
||||
@@ -128,6 +128,9 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
from_source = "api"
|
||||
end_user_id = application_generate_entity.user_id
|
||||
# 如果 extras 中提供了 account_id(Web App 登录用户),优先使用
|
||||
if application_generate_entity.extras and application_generate_entity.extras.get("account_id"):
|
||||
account_id = application_generate_entity.extras.get("account_id")
|
||||
else:
|
||||
from_source = "console"
|
||||
account_id = application_generate_entity.user_id
|
||||
|
||||
@@ -166,6 +166,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
extras["app_token_id"] = api_token.id
|
||||
# ------------------- 二开部分End - 密钥额度限制
|
||||
|
||||
# extend: 如果 args 中有 account_id(Web App 登录用户),将其放入 extras
|
||||
if args.get("account_id"):
|
||||
extras["account_id"] = args.get("account_id")
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
|
||||
+9
-3
@@ -957,9 +957,15 @@ class Conversation(Base):
|
||||
):
|
||||
return end_user.session_id
|
||||
elif end_user is not None:
|
||||
user: Account = db.session.query(Account).filter(Account.id == end_user.external_user_id).first()
|
||||
if user:
|
||||
return user.name
|
||||
# 验证 external_user_id 是否为有效的 UUID
|
||||
try:
|
||||
uuid.UUID(end_user.external_user_id)
|
||||
user: Account = db.session.query(Account).filter(Account.id == end_user.external_user_id).first()
|
||||
if user:
|
||||
return user.name
|
||||
except (ValueError, TypeError):
|
||||
# 如果不是有效的 UUID,返回 session_id
|
||||
return end_user.session_id
|
||||
elif self.from_account_id:
|
||||
user: Account = db.session.query(Account).filter(Account.id == self.from_account_id).first()
|
||||
if user:
|
||||
|
||||
+15
-8
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
@@ -1174,14 +1175,20 @@ class WorkflowAppLog(TypeBase):
|
||||
from models.model import EndUser
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.created_by).first()
|
||||
if end_user is not None and end_user.external_user_id is not None and len(end_user.external_user_id) > 0:
|
||||
user: Account = db.session.query(Account).filter(Account.id == end_user.external_user_id).first()
|
||||
if user:
|
||||
return {
|
||||
"id": user.id,
|
||||
"type": user.status,
|
||||
"is_anonymous": "true",
|
||||
"session_id": user.name,
|
||||
}
|
||||
# 验证 external_user_id 是否为有效的 UUID
|
||||
try:
|
||||
uuid.UUID(end_user.external_user_id)
|
||||
user: Account = db.session.query(Account).filter(Account.id == end_user.external_user_id).first()
|
||||
if user:
|
||||
return {
|
||||
"id": user.id,
|
||||
"type": user.status,
|
||||
"is_anonymous": "true",
|
||||
"session_id": user.name,
|
||||
}
|
||||
except (ValueError, TypeError):
|
||||
# 如果不是有效的 UUID,跳过查询
|
||||
pass
|
||||
return end_user
|
||||
elif len(self.created_by) > 0:
|
||||
user: Account = db.session.query(Account).filter(Account.id == self.created_by).first()
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect } from 'react'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share'
|
||||
import { webAppLogout } from '@/service/webapp-auth'
|
||||
import { checkConsoleLoginStatus, webAppLogout } from '@/service/webapp-auth'
|
||||
|
||||
const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
const { t } = useTranslation()
|
||||
@@ -22,6 +22,7 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
const { isFetching: isFetchingAppInfo, data: appInfo, error: appInfoError } = useGetWebAppInfo()
|
||||
const { isFetching: isFetchingAppMeta, data: appMeta, error: appMetaError } = useGetWebAppMeta()
|
||||
const { data: userCanAccessApp, error: useCanAccessAppError } = useGetUserCanAccessApp({ appId: appInfo?.app_id, isInstalledApp: false })
|
||||
const [isCheckingAuth, setIsCheckingAuth] = useState(true)
|
||||
|
||||
useEffect(() => {
|
||||
if (appInfo)
|
||||
@@ -36,6 +37,22 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
const router = useRouter()
|
||||
const pathname = usePathname()
|
||||
const searchParams = useSearchParams()
|
||||
|
||||
// 检查 Console 用户登录状态
|
||||
useEffect(() => {
|
||||
const checkConsoleAuth = async () => {
|
||||
setIsCheckingAuth(true)
|
||||
const isConsoleLoggedIn = await checkConsoleLoginStatus()
|
||||
if (!isConsoleLoggedIn) {
|
||||
// 未登录,保存当前 URL 到 localStorage,然后跳转到 Console 登录页面
|
||||
localStorage.setItem('redirect_url', pathname)
|
||||
router.replace('/signin')
|
||||
}
|
||||
setIsCheckingAuth(false)
|
||||
}
|
||||
|
||||
checkConsoleAuth()
|
||||
}, [pathname, router])
|
||||
const getSigninUrl = useCallback(() => {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.delete('message')
|
||||
@@ -85,7 +102,7 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
</div>
|
||||
)
|
||||
}
|
||||
if (isFetchingAppInfo || isFetchingAppParams || isFetchingAppMeta) {
|
||||
if (isCheckingAuth || isFetchingAppInfo || isFetchingAppParams || isFetchingAppMeta) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<Loading />
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
import type { FC } from 'react'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import * as React from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useCallback, useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { webAppLogout } from '@/service/webapp-auth'
|
||||
import { checkConsoleLoginStatus, webAppLogout } from '@/service/webapp-auth'
|
||||
import ExternalMemberSsoAuth from './components/external-member-sso-auth'
|
||||
import NormalForm from './normalForm'
|
||||
|
||||
@@ -18,8 +19,26 @@ const WebSSOForm: FC = () => {
|
||||
const webAppAccessMode = useWebAppStore(s => s.webAppAccessMode)
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
const [isCheckingAuth, setIsCheckingAuth] = useState(true)
|
||||
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
|
||||
// 检查 Console 用户登录状态
|
||||
useEffect(() => {
|
||||
const checkAuth = async () => {
|
||||
setIsCheckingAuth(true)
|
||||
const isConsoleLoggedIn = await checkConsoleLoginStatus()
|
||||
if (!isConsoleLoggedIn) {
|
||||
// 未登录,保存 redirect_url 到 localStorage,然后跳转到 Console 登录页面
|
||||
if (redirectUrl)
|
||||
localStorage.setItem('redirect_url', redirectUrl)
|
||||
router.replace('/signin')
|
||||
}
|
||||
setIsCheckingAuth(false)
|
||||
}
|
||||
|
||||
checkAuth()
|
||||
}, [router, redirectUrl])
|
||||
|
||||
const getSigninUrl = useCallback(() => {
|
||||
const params = new URLSearchParams()
|
||||
@@ -34,6 +53,14 @@ const WebSSOForm: FC = () => {
|
||||
router.replace(url)
|
||||
}, [getSigninUrl, router, webAppLogout, shareCode])
|
||||
|
||||
if (isCheckingAuth) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<Loading />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
if (!redirectUrl) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
|
||||
@@ -316,10 +316,9 @@ const Chat: FC<ChatProps> = ({
|
||||
}
|
||||
// Extend: stop messages context handling
|
||||
return (
|
||||
<>
|
||||
<Fragment key={item.id}>
|
||||
<Answer
|
||||
appData={appData}
|
||||
key={item.id}
|
||||
item={item}
|
||||
question={chatList[index - 1]?.content}
|
||||
index={index}
|
||||
@@ -347,7 +346,7 @@ const Chat: FC<ChatProps> = ({
|
||||
)
|
||||
}
|
||||
{/* Extend: stop messages context handling */}
|
||||
</>
|
||||
</Fragment>
|
||||
)
|
||||
}
|
||||
return (
|
||||
|
||||
@@ -21,6 +21,19 @@ function getItemWithExpiry(key: string): string | null {
|
||||
}
|
||||
|
||||
export const resolvePostLoginRedirect = (searchParams: ReadonlyURLSearchParams) => {
|
||||
// WebApp/Console: 优先走 localStorage 里的 redirect_url(登录成功后必须清理)
|
||||
const localRedirectUrl = localStorage.getItem('redirect_url')
|
||||
if (localRedirectUrl) {
|
||||
localStorage.removeItem('redirect_url')
|
||||
try {
|
||||
return decodeURIComponent(localRedirectUrl)
|
||||
}
|
||||
catch (e) {
|
||||
console.error('Failed to decode redirect URL from localStorage:', e)
|
||||
return localRedirectUrl
|
||||
}
|
||||
}
|
||||
|
||||
const redirectUrl = searchParams.get(REDIRECT_URL_KEY)
|
||||
if (redirectUrl) {
|
||||
try {
|
||||
|
||||
@@ -446,12 +446,6 @@ export const ssePost = async (
|
||||
}),
|
||||
} as RequestInit, fetchOptions)
|
||||
|
||||
// ----------------- start You must log in to access your account extend ---------------
|
||||
const token = localStorage.getItem(CSRF_COOKIE_NAME()) || ''
|
||||
if ((url === 'chat-messages' || url === 'completion-messages' || url === 'workflows/run') && token.length > 0)
|
||||
(options.headers as Headers).set('Authorization-extend', `${token}`)
|
||||
// ----------------- stop You must log in to access your account extend ---------------
|
||||
|
||||
const contentType = (options.headers as Headers).get('Content-Type')
|
||||
if (!contentType)
|
||||
(options.headers as Headers).set('Content-Type', ContentType.json)
|
||||
|
||||
@@ -28,6 +28,7 @@ export function clearWebAppPassport(shareCode: string) {
|
||||
type isWebAppLogin = {
|
||||
logged_in: boolean
|
||||
app_logged_in: boolean
|
||||
console_logged_in?: boolean
|
||||
}
|
||||
|
||||
export async function webAppLoginStatus(shareCode: string, userId?: string) {
|
||||
@@ -43,6 +44,17 @@ export async function webAppLoginStatus(shareCode: string, userId?: string) {
|
||||
}
|
||||
}
|
||||
|
||||
export async function checkConsoleLoginStatus() {
|
||||
try {
|
||||
const { console_logged_in } = await getPublic<isWebAppLogin>('/login/status')
|
||||
return console_logged_in || false
|
||||
}
|
||||
catch (error) {
|
||||
console.error('Failed to check console login status:', error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export async function webAppLogout(shareCode: string) {
|
||||
clearWebAppAccessToken()
|
||||
clearWebAppPassport(shareCode)
|
||||
|
||||
Reference in New Issue
Block a user