fix: 上下文,聊天计费,额度

This commit is contained in:
npc0-hue
2026-01-21 18:10:18 +08:00
parent 0cf63b0f08
commit 9ed0d7c891
111 changed files with 1954 additions and 729 deletions
+38
View File
@@ -712,3 +712,41 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
# OA Oauth2(二开新增配置)
OAUTH2_CLIENT_URL=
OAUTH2_CLIENT_ID=
OAUTH2_CLIENT_SECRET=
OAUTH2_TOKEN_URL=
OAUTH2_USER_INFO_URL=
# 改成从数据库获取应用模版(二开新增配置)
HOSTED_FETCH_APP_TEMPLATES_MODE=db
# 新增数据库相关配置(二开新增配置)
SQLALCHEMY_POOL_SIZE=100
SQLALCHEMY_MAX_OVERFLOW=10
SQLALCHEMY_POOL_RECYCLE=3600
SQLALCHEMY_POOL_PRE_PING=false
# 注册账号统一邮箱后缀(二开新增配置)
EMAIL_DOMAIN=
# 默认邮箱账号密码登录(二开新增配置)
ENABLE_EMAIL_PASSWORD_LOGIN=True
# 管理后台相关配置,后台超级管理员权限组id(二开新增配置)
ADMIN_GROUP_ID=888
# 中->美汇率
RMB_TO_USD_RATE=7.26
# 默认语言
DEFAULT_LANGUAGE=zh-Hans
# Bedrock Proxy
BEDROCK_PROXY=
# 初始额度
ACCOUNT_TOTAL_QUOTA=15
+1
View File
@@ -2112,6 +2112,7 @@ def migrate_oss(
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
# extend: start 管理二开db扩展
@click.group("extend_db", help="管理二开扩展表的数据库迁移")
def extend_db():
+7
View File
@@ -67,6 +67,13 @@ class ExtendInfo(BaseSettings):
default=None,
)
# Extend: 记忆上下文功能
DEFAULT_NUMBER_CONTEXT: Optional[int] = Field(
description="Default number of context retention (记忆窗口默认值)",
default=5,
)
# Extend: 记忆上下文功能
class ExtendConfig(ExtendInfo):
pass
+3 -3
View File
@@ -83,7 +83,7 @@ from .auth import (
login,
oauth,
oauth_server,
register_extend,# 二开部分: 新增用户(调用dify注册接口)
register_extend, # 二开部分: 新增用户(调用dify注册接口)
)
# Import billing controllers
@@ -124,7 +124,7 @@ from .tag import tags
# Import workspace controllers
from .workspace import (
account,
account_extend,# 二开部分:新增account_extend
account_extend, # 二开部分:新增account_extend
agent_providers,
endpoint,
load_balancing_config,
@@ -141,9 +141,9 @@ api.add_namespace(console_ns)
__all__ = [
"account",
"account_extend", # 二开部分:新增account_extend
"activate",
"admin",
"account_extend",# 二开部分:新增account_extend
"advanced_prompt_template",
"agent",
"agent_providers",
+1 -1
View File
@@ -12,9 +12,9 @@ from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.api_token_money_extend import ApiTokenMoneyExtend # 二开部分 - 密钥额度限制
from models.dataset import Dataset
from models.model import ApiToken, App
from models.api_token_money_extend import ApiTokenMoneyExtend # 二开部分 - 密钥额度限制
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
+20
View File
@@ -46,6 +46,25 @@ class AppListQuery(BaseModel):
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
# extend: start 二开部分:新增的查询参数
@field_validator("mode", mode="before")
@classmethod
def validate_mode(cls, value: Any) -> str:
"""
Be tolerant for query params.
If client passes an unexpected value, fall back to 'all' instead of raising 422.
"""
if value is None:
return "all"
if isinstance(value, str):
v = value.strip()
allowed = {"completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"}
return v if v in allowed else "all"
return "all"
# extend: start 二开部分:新增的查询参数
@field_validator("tag_ids", mode="before")
@classmethod
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
@@ -443,6 +462,7 @@ class AppPagination(ResponseModel):
total: int
has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more"))
data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data"))
recommended_apps: list[str] # extend: recommended apps
class AppExportResponse(ResponseModel):
+34
View File
@@ -69,7 +69,41 @@ class AppSyncApi(Resource):
return "", 200
# Extend: start messages context handling
class MessageContextApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
"""Message Context"""
from flask import request
conversation_id = request.args.get("conversation_id")
if not conversation_id:
from werkzeug.exceptions import BadRequest
raise BadRequest("conversation_id is required")
app_service = RecommendedAppService()
return app_service.message_context(conversation_id)
@setup_required
@login_required
@account_initialization_required
def delete(self):
"""Message Context"""
from flask import request
message_id = request.args.get("message_id")
conversation_id = request.args.get("conversation_id")
if not message_id or not conversation_id:
from werkzeug.exceptions import BadRequest
raise BadRequest("message_id and conversation_id are required")
app_service = RecommendedAppService()
return app_service.delete_message_context(conversation_id, message_id)
# Extend: stop messages context handling
# ----------------start sync app------------------------
api.add_resource(AppSyncApi, "/apps/<uuid:app_id>/sync")
api.add_resource(InstalledSyncAppApi, "/installed/apps")
api.add_resource(MessageContextApi, "/message/context")
# ---------------- stop sync app ------------------------
+2 -2
View File
@@ -17,10 +17,10 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
# Extend money_extend
from controllers.console.money_extend import money_limit
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
@@ -4,6 +4,8 @@ from typing import cast
from flask import request
from flask_restx import Resource, fields
# Extend: 记忆上下文功能
from configs import dify_config
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
@@ -12,11 +14,15 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, AppModelConfig
from models.model_extend import AppExtend
from services.app_model_config_service import AppModelConfigService
# Extend: 记忆上下文功能
@console_ns.route("/apps/<uuid:app_id>/model-config")
class ModelConfigResource(Resource):
@@ -67,6 +73,23 @@ class ModelConfigResource(Resource):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
# Extend: 记忆上下文功能 - Start
config = request.json
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.mode == AppMode.CHAT.value or app_model.is_agent:
retention_number = int(config.get("retention_number", dify_config.DEFAULT_NUMBER_CONTEXT))
# 循环移除相关键
redis_client.delete(f"retention_number_{app_model.id}")
app_extend = db.session.query(AppExtend).filter(AppExtend.app_id == app_model.id).first()
# appExtend is not None
if app_extend is None:
db.session.add(AppExtend(app_id=app_model.id, retention_number=retention_number))
else:
db.session.query(AppExtend).filter(AppExtend.app_id == app_model.id).update(
{AppExtend.retention_number: retention_number}
)
db.session.commit()
# Extend: 记忆上下文功能 - Stop
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
# get original app model config
original_app_model_config = (
-11
View File
@@ -6,17 +6,6 @@ from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_account_with_tenant
from models import App, AppMode
from models.account import Account
def _load_app_model(app_id: str) -> Optional[App]:
assert isinstance(current_user, Account)
app_model = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
return app_model
P = ParamSpec("P")
R = TypeVar("R")
+1 -1
View File
@@ -50,7 +50,7 @@ def get_oauth_providers():
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
)
oauth2 = OaOAuth(client_id='', client_secret='', redirect_uri='') # Extend: oauth2
oauth2 = OaOAuth(client_id='', client_secret='', redirect_uri='') # Extend: oauth2
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth, "oauth2": oauth2}
return OAUTH_PROVIDERS
@@ -33,10 +33,10 @@ from libs.login import current_user
from models import Account
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
from services.app_generate_service_extend import (
AppGenerateServiceExtend, # Extend: App Center - Recommended list sorted by usage frequency
)
from services.app_task_service import AppTaskService
from services.errors.llm import InvokeRateLimitError
from .. import console_ns
@@ -1,4 +1,5 @@
from flask_restful import Resource, reqparse # type: ignore
from flask_restx import Resource
from pydantic import BaseModel
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
@@ -6,8 +7,8 @@ from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from libs.login import current_user, login_required
from models.account import TenantAccountRole
from libs.login import current_account_with_tenant, login_required
from models import TenantAccountRole
from services.model_load_balancing_service import ModelLoadBalancingService
+5 -1
View File
@@ -6,7 +6,11 @@ from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required, is_admin_or_owner_required # extend: 非admin或者owner返回 Forbidden
from controllers.console.wraps import ( # extend: 非admin或者owner返回 Forbidden
account_initialization_required,
is_admin_or_owner_required,
setup_required,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -11,7 +11,7 @@ from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model
from models.model import App
from models.model import ApiToken, App # extend - 密钥额度限制,新增ApiToken
from services.annotation_service import AppAnnotationService
@@ -111,7 +111,7 @@ class AnnotationListApi(Resource):
)
@validate_app_token
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""List annotations for the application."""
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
+4 -4
View File
@@ -5,7 +5,7 @@ from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode
from models.model import ApiToken, App, AppMode
from services.app_service import AppService
@@ -23,7 +23,7 @@ class AppParameterApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Retrieve app parameters.
Returns the input form parameters and configuration for the application.
@@ -60,7 +60,7 @@ class AppMetaApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Get app metadata.
Returns metadata about the application including configuration and settings.
@@ -80,7 +80,7 @@ class AppInfoApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Get app information.
Returns basic information about the application including name, description, tags, and mode.
+1 -1
View File
@@ -22,7 +22,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import ApiToken, App, EndUser # extend: 二开部分 密钥额度限制,新增api_token
from models.model import ApiToken, App, EndUser # extend: 二开部分 密钥额度限制,新增api_token
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -30,9 +30,9 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken
from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken
from services.app_generate_service import AppGenerateService
from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken
from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken
from services.app_task_service import AppTaskService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
+2 -2
View File
@@ -15,7 +15,7 @@ from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from fields.file_fields import FileResponse
from models import ApiToken, App, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错
from models import ApiToken, App, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错
from services.file_service import FileService
register_schema_models(service_api_ns, FileResponse)
@@ -36,7 +36,7 @@ class FileApi(Resource):
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore
@service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__])
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
"""Upload a file for use in conversations.
Accepts a single file upload via multipart/form-data.
+5 -5
View File
@@ -15,7 +15,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken
from services.errors.message import (
FirstMessageNotExistsError,
MessageNotExistsError,
@@ -58,7 +58,7 @@ class MessageListApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
def get(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
"""List messages in a conversation.
Retrieves messages with pagination support using first_id.
@@ -102,7 +102,7 @@ class MessageFeedbackApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
def post(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
"""Submit feedback for a message.
Allows users to rate messages as like/dislike and provide optional feedback content.
@@ -137,7 +137,7 @@ class AppGetFeedbacksApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Get all feedbacks for the application.
Returns paginated list of all feedback submitted for messages in this app.
@@ -162,7 +162,7 @@ class MessageSuggestedApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
def get(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
"""Get suggested follow-up questions for a message.
Returns AI-generated follow-up questions based on the message content.
+2 -2
View File
@@ -6,7 +6,7 @@ from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.account import TenantStatus
from models.model import App, Site
from models.model import ApiToken, App, Site # extend - 密钥额度限制,新增ApiToken
@service_api_ns.route("/site")
@@ -23,7 +23,7 @@ class AppSiteApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Retrieve app site info.
Returns the site configuration for the application including theme, icons, and text.
+8 -4
View File
@@ -34,7 +34,7 @@ from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
from libs.helper import TimestampField
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,ApiToken
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,ApiToken
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
@@ -97,7 +97,7 @@ class WorkflowRunDetailApi(Resource):
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
def get(self, app_model: App, workflow_run_id: str):
def get(self, app_model: App, workflow_run_id: str, api_token: ApiToken): # extend - 密钥额度限制,新增api_token参数
"""Get a workflow task running detail.
Returns detailed information about a specific workflow run.
@@ -134,7 +134,7 @@ class WorkflowRunApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken, workflow_id: str): # extend: 密钥额度限制
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend: 密钥额度限制
"""Execute a workflow.
Runs a workflow with the provided inputs and returns the results.
@@ -195,7 +195,7 @@ class WorkflowRunByIdApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken, workflow_id: str): # extend: 密钥额度限制
"""Run specific workflow by ID.
Executes a specific workflow version identified by its ID.
@@ -215,6 +215,10 @@ class WorkflowRunByIdApi(Resource):
args["external_trace_id"] = external_trace_id
streaming = payload.response_mode == "streaming"
# ------------------- 二开部分Begin - 密钥额度限制 -------------------
args["api_token"] = api_token
# ------------------- 二开部分End - 密钥额度限制 -------------------
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
+33 -22
View File
@@ -14,25 +14,26 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
# extend: start 额度限制,API调用计费,新增TenantAccountRole
from controllers.service_api.app.error_extend import (
AccountNoMoneyErrorExtend,
ApiTokenDayNoMoneyErrorExtend,
ApiTokenMonthNoMoneyErrorExtend,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.account_money_extend import AccountMoneyExtend
from models.api_token_money_extend import ApiTokenMoneyExtend
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from models.model_extend import EndUserAccountJoinsExtend
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
# extend: stop 额度限制,API调用计费,新增TenantAccountRole
@@ -80,17 +81,27 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")
# ---------------------extend: 二开部分Begin 额度限制,API调用计费 ---------------------
# TODO 需要写入缓存,读缓存
account_money = (
db.session.query(AccountMoneyExtend)
.filter(AccountMoneyExtend.account_id == ta.account_id)
.first()
)
if account_money and account_money.used_quota >= account_money.total_quota:
raise AccountNoMoneyErrorExtend()
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
# TODO 需要写入缓存,读缓存
account_money = (
db.session.query(AccountMoneyExtend)
.filter(AccountMoneyExtend.account_id == ta.account_id)
.first()
)
if account_money and account_money.used_quota >= account_money.total_quota:
raise AccountNoMoneyErrorExtend()
else:
raise Unauthorized("Tenant does not exist.")
# 密钥额度判断
kwargs["api_token"] = api_token # API token消息数据传递下去
api_token_money = (
@@ -382,6 +393,7 @@ def validate_and_get_api_token(scope: str | None = None):
return api_token
# ---------------------二开部分Begin 额度限制,API调用计费 ---------------------
def create_or_update_end_user_account_join_extend(end_user_id, account_id, app_id: str) -> EndUserAccountJoinsExtend:
# 插入节点账号id和用户账号id关联关系,以方便扣钱查询
@@ -402,7 +414,6 @@ def create_or_update_end_user_account_join_extend(end_user_id, account_id, app_i
# ---------------------二开部分End 额度限制,API调用计费 ---------------------
class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]
+8 -4
View File
@@ -1,5 +1,6 @@
import logging
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
@@ -37,14 +38,16 @@ logger = logging.getLogger(__name__)
# extend: 您必须登录才能访问您的帐户扩展功能
from flask import request
from extensions.ext_database import db
from libs.passport import PassportService
from models.account_money_extend import AccountMoneyExtend
from services.app_generate_service_extend import AppGenerateServiceExtend
from controllers.web.error_extend import (
AccountNoMoneyErrorExtend,
WebAuthRequiredErrorExtend,
)
from extensions.ext_database import db
from libs.passport import PassportService
from models.account_money_extend import AccountMoneyExtend
from services.app_generate_service_extend import AppGenerateServiceExtend
def is_end_login(end_user):
user_info = None
@@ -61,6 +64,7 @@ def is_end_login(end_user):
# no login
return user_info
# 额度限制
def is_money_limit(end_user) -> bool:
try:
-3
View File
@@ -4,9 +4,6 @@ from datetime import UTC, datetime, timedelta
from flask import make_response, request
from flask_restx import Resource
from sqlalchemy import func, select
from flask import request
from flask_restx import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
+2
View File
@@ -36,8 +36,10 @@ from controllers.web.error_extend import (
WebAuthRequiredErrorExtend,
)
from services.app_generate_service_extend import AppGenerateServiceExtend
# extend: stop 您必须登录才能访问您的帐户扩展功能
@web_ns.route("/workflows/run")
class WorkflowRunApi(WebApiResource):
@web_ns.doc("Run Workflow")
-2
View File
@@ -5,8 +5,6 @@ from typing import Union, cast
from sqlalchemy import select
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload, cast # extend: 二开部分 - 密钥额度限制,新增cast
from typing import Any, Literal, Union, cast, overload # extend: 二开部分 - 密钥额度限制,新增cast
from flask import Flask, current_app
from pydantic import ValidationError
@@ -38,7 +38,16 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom, ApiToken # extend: 二开部分 - 密钥额度限制,新增ApiToken
from models import ( # extend: 二开部分 - 密钥额度限制,新增ApiToken
Account,
ApiToken,
App,
Conversation,
EndUser,
Message,
Workflow,
WorkflowNodeExecutionTriggeredFrom,
)
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -126,7 +135,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
extras["app_token_id"] = api_token.id
# ------------------- 二开部分End - 密钥额度限制 -------------------
# get conversation
conversation = None
conversation_id = args.get("conversation_id")
@@ -202,6 +202,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
user_from=user_from, # 二开部分 - 用于计费
)
workflow_entry.graph_engine.layer(persistence_layer)
@@ -70,7 +70,14 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import AppMode, Account, Conversation, EndUser, Message, MessageFile # extend: 二开部分End - 密钥额度限制,新增AppMode
from models import ( # extend: 二开部分End - 密钥额度限制,新增AppMode
Account,
AppMode,
Conversation,
EndUser,
Message,
MessageFile,
)
from models.api_token_money_extend import ApiTokenMessageJoinsExtend # extend: 二开部分End - 密钥额度限制
from models.enums import CreatorUserRole
from models.workflow import Workflow
@@ -579,37 +586,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
with self._database_session() as session:
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {event.error}"))
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
def _handle_workflow_failed_event(
self,
event: QueueWorkflowFailedEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED,
error_message=event.error,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
yield workflow_finish_resp
yield self._base_task_pipeline.error_to_stream_response(err)
@@ -76,8 +76,12 @@ class AgentChatAppRunner(AppRunner):
files=list(files),
query=query,
memory=memory,
control_registers=False, # Extend: messages_context_handling
)
# Extend: messages_context_handling
self.add_messages_context(prompt_messages, app_config.app_id, conversation.id, message.id)
# moderation
try:
# process sensitive_word_avoidance
+33
View File
@@ -29,7 +29,14 @@ from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
# extend: start messages_context_handling
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import App, AppMode, Message, MessageAnnotation
from models.model_extend import AppExtend, MessageContextExtend
# extend: stop messages_context_handling
if TYPE_CHECKING:
from core.file.models import File
@@ -72,6 +79,29 @@ class AppRunner:
):
model_config.parameters[parameter_rule.name] = max_tokens
# Extend: start messages_context_handling
def add_messages_context(self, prompt_messages, app_id, conversation_id, message_id):
key = "retention_number_{}".format(app_id)
retention_number = redis_client.get(key)
if retention_number is None:
app_extend: AppExtend = (
db.session.query(AppExtend).filter(AppExtend.app_id == app_id).first()
)
if app_extend is None:
return
retention_number = int(app_extend.retention_number)
redis_client.set(key, app_extend.retention_number)
else:
retention_number = int(retention_number)
if (len(prompt_messages) + 2) / 2 > retention_number:
# 插入替换
db.session.add(MessageContextExtend(
conversation_id=conversation_id,
message_id=message_id,
))
db.session.commit()
# Extend: stop messages_context_handling
def organize_prompt_messages(
self,
app_record: App,
@@ -84,6 +114,7 @@ class AppRunner:
memory: TokenBufferMemory | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
control_registers: bool = True, # Extend: messages context handling
) -> tuple[list[PromptMessage], list[str] | None]:
"""
Organize prompt messages
@@ -96,6 +127,7 @@ class AppRunner:
:param query: query
:param memory: memory
:param image_detail_config: the image quality config
:param control_registers: is messages context # Extend: messages context handling
:return:
"""
# get prompt without memory and context
@@ -113,6 +145,7 @@ class AppRunner:
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
control_registers=control_registers, # Extend: messages context handling
)
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
+1 -1
View File
@@ -24,9 +24,9 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models import Account
from models.api_token_money_extend import ApiTokenMessageJoinsExtend # 二开部分End - 密钥额度限制
from models.model import App, EndUser
from services.conversation_service import ConversationService
from models.api_token_money_extend import ApiTokenMessageJoinsExtend # 二开部分End - 密钥额度限制
logger = logging.getLogger(__name__)
+15
View File
@@ -224,6 +224,21 @@ class ChatAppRunner(AppRunner):
user=application_generate_entity.user_id,
)
# Extend: Start messages context handling
messages_list, _ = self.organize_prompt_messages(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
context=context,
memory=memory
)
# Extend: Stop messages_context_handling
self.add_messages_context(messages_list, app_config.app_id, conversation.id, message.id)
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
@@ -187,6 +187,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
user_from=user_from, # 二开部分 - 用于计费
)
workflow_entry.graph_engine.layer(persistence_layer)
+4 -2
View File
@@ -3,8 +3,9 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
# extend: 二开部分 - 密钥额度限制,新增cast
from typing import Any, Literal, Union, overload, cast
from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -36,8 +37,9 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
# extend: 二开部分 - 密钥额度限制,新增ApiToken
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom, ApiToken
from models import Account, ApiToken, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
+1
View File
@@ -145,6 +145,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
user_from=user_from, # 二开部分 - 用于计费
)
workflow_entry.graph_engine.layer(persistence_layer)
@@ -57,9 +57,11 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models import Account
from models.enums import CreatorUserRole
# extend: 二开部分End - 密钥额度限制
from models.api_token_money_extend import ApiTokenMessageJoinsExtend
from models.enums import CreatorUserRole
# extend: 二开部分End - 密钥额度限制,新增AppMode
from models.model import AppMode, EndUser
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom
@@ -2,8 +2,6 @@ import logging
from sqlalchemy import select
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
-2
View File
@@ -1,7 +1,5 @@
from sqlalchemy import select
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
+42 -1
View File
@@ -19,6 +19,9 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
# Extend: start messages context handling
from models.model_extend import MessageContextExtend
from models.workflow import Workflow
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
@@ -34,6 +37,31 @@ class TokenBufferMemory:
self.model_instance = model_instance
self._workflow_run_repo: APIWorkflowRunRepository | None = None
# Extend: start messages context handling
def messages_context_handling(
self,
conversation_id: str,
prompt_messages: tuple[list[AssistantPromptMessage]],
) -> tuple[list[AssistantPromptMessage]]:
# check if there is a segmentation context
message_context = db.session.query(MessageContextExtend).filter(
MessageContextExtend.conversation_id == conversation_id).order_by(
MessageContextExtend.created_at.desc()).all()
# Is there a split
if not message_context:
return prompt_messages
# for
messages = []
for v in prompt_messages:
messages.append(v)
if v.name is not None and len(v.name) > 0:
for i in message_context:
if v.name == i.message_id:
messages = []
v.name = None
return messages
# Extend: stop messages context handling
@property
def workflow_run_repo(self) -> APIWorkflowRunRepository:
if self._workflow_run_repo is None:
@@ -115,12 +143,14 @@ class TokenBufferMemory:
return AssistantPromptMessage(content=prompt_message_contents)
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: int | None = None
self, max_token_limit: int = 2000, message_limit: int | None = None,
control_registers: bool = True, # Extend: messages context handling
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
:param message_limit: message limit
:param control_registers:
"""
app_record = self.conversation.app
@@ -188,6 +218,17 @@ class TokenBufferMemory:
else:
prompt_messages.append(AssistantPromptMessage(content=message.answer))
# Extend Contextual dividing line
prompt_messages.append(AssistantPromptMessage(name=message.id, content=message.answer))
# Extend: start messages context handling
if control_registers:
prompt_messages = self.messages_context_handling(
prompt_messages=prompt_messages,
conversation_id=self.conversation.id,
)
# Extend: stop messages context handling
if not prompt_messages:
return []
+5 -1
View File
@@ -15,9 +15,11 @@ class PromptTransform:
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity,
control_registers: bool = True, # Extend: messages context handling
) -> list[PromptMessage]:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
# Extend: messages context handling
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens, control_registers)
prompt_messages.extend(histories)
return prompt_messages
@@ -74,6 +76,7 @@ class PromptTransform:
def _get_history_messages_list_from_memory(
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
, control_registers: bool = True, # Extend: messages context handling
) -> list[PromptMessage]:
"""Get memory messages."""
return list(
@@ -86,5 +89,6 @@ class PromptTransform:
and memory_config.window.size > 0
)
else None,
control_registers=control_registers, # Extend: messages context handling
)
)
@@ -50,6 +50,7 @@ class SimplePromptTransform(PromptTransform):
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
control_registers: bool = True, # Extend: messages context handling
) -> tuple[list[PromptMessage], list[str] | None]:
inputs = {key: str(value) for key, value in inputs.items()}
@@ -66,6 +67,7 @@ class SimplePromptTransform(PromptTransform):
model_config=model_config,
image_detail_config=image_detail_config,
context_files=context_files,
control_registers=control_registers, # Extend: messages context handling
)
else:
prompt_messages, stops = self._get_completion_model_prompt_messages(
@@ -191,6 +193,7 @@ class SimplePromptTransform(PromptTransform):
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
context_files: list["File"] | None = None,
control_registers: bool = True, # Extend: messages context handling
) -> tuple[list[PromptMessage], list[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -217,6 +220,7 @@ class SimplePromptTransform(PromptTransform):
),
prompt_messages=prompt_messages,
model_config=model_config,
control_registers=control_registers, # Extend: messages context handling
)
if query:
@@ -12,12 +12,6 @@ from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlPermissionError,
)
from core.rag.extractor.watercrawl.exceptions import (
WaterCrawlAuthenticationError,
WaterCrawlBadRequestError,
WaterCrawlPermissionError,
)
class BaseAPIClient:
def __init__(self, api_key, base_url):
+50 -13
View File
@@ -122,6 +122,9 @@ class ToolManager:
"""
get the plugin provider
"""
# extend: 获取插件提供程序
from core.plugin.impl.exc import PluginNotFoundError
# check if context is set
try:
@@ -141,19 +144,53 @@ class ToolManager:
return plugin_tool_providers[provider]
manager = PluginToolManager()
provider_entity = manager.fetch_tool_provider(tenant_id, provider)
if not provider_entity:
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
# extend: start 获取插件提供程序
max_retries = 2
last_error = None
controller = PluginToolProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
for attempt in range(max_retries):
try:
provider_entity = manager.fetch_tool_provider(tenant_id, provider)
if not provider_entity:
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
plugin_tool_providers[provider] = controller
return controller
controller = PluginToolProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
plugin_tool_providers[provider] = controller
return controller
except PluginNotFoundError as e:
last_error = e
# Clear cache and retry once more
if attempt < max_retries - 1:
logger.warning(
f"Plugin {provider} not found on attempt {attempt + 1}, clearing cache and retrying. "
f"Error: {str(e)}"
)
# Remove from cache if exists
plugin_tool_providers.pop(provider, None)
# Small delay before retry
time.sleep(0.5)
else:
logger.error(
f"Plugin {provider} not found after {max_retries} attempts. "
f"Last error: {str(e)}"
)
raise ToolProviderNotFoundError(f"plugin provider {provider} not found after retries: {str(e)}")
except Exception as e:
# For other errors, don't retry
logger.exception(f"Error fetching plugin provider {provider}: {str(e)}")
raise
# Should not reach here, but just in case
if last_error:
raise ToolProviderNotFoundError(f"plugin provider {provider} not found: {str(last_error)}")
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
# extend: stop 获取插件提供程序
@classmethod
def get_tool_runtime(
@@ -634,9 +671,9 @@ class ToolManager:
# MySQL: Use window function to achieve same result
sql = """
SELECT id FROM (
SELECT id,
SELECT id,
ROW_NUMBER() OVER (
PARTITION BY tenant_id, provider
PARTITION BY tenant_id, provider
ORDER BY is_default DESC, created_at DESC
) as rn
FROM tool_builtin_providers
@@ -15,6 +15,7 @@ from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
@@ -48,6 +49,14 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
# extend: start 二开部分 - 计费相关的用户信息
from models.enums import CreatorUserRole, UserFrom
from tasks.extend.update_account_money_when_workflow_node_execution_created_extend import (
update_account_money_when_workflow_node_execution_created_extend,
)
# extend: stop 二开部分 - 计费相关的用户信息
@dataclass(slots=True)
class PersistenceWorkflowInfo:
@@ -82,6 +91,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
trace_manager: TraceQueueManager | None = None,
user_from: UserFrom | None = None, # 二开部分 - 用于计费
) -> None:
super().__init__()
self._application_generate_entity = application_generate_entity
@@ -89,6 +99,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._trace_manager = trace_manager
self._user_from = user_from # 二开部分 - 用于计费
self._workflow_execution: WorkflowExecution | None = None
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
@@ -270,6 +281,26 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
domain_execution = self._get_node_execution(event.id)
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
# 二开部分Begin - 计费
# 异步任务计算费用并更新账户额度,将对象转换为字典传递
domain_execution_dict = jsonable_encoder(domain_execution)
# 添加用户信息到字典中
domain_execution_dict['created_by'] = self._application_generate_entity.user_id
if self._user_from == UserFrom.ACCOUNT:
domain_execution_dict['created_by_role'] = CreatorUserRole.ACCOUNT.value
elif self._user_from == UserFrom.END_USER:
domain_execution_dict['created_by_role'] = CreatorUserRole.END_USER.value
else:
domain_execution_dict['created_by_role'] = None
# 添加 workflow_run_id
if self._workflow_execution:
domain_execution_dict['workflow_run_id'] = self._workflow_execution.id_
update_account_money_when_workflow_node_execution_created_extend.delay(domain_execution_dict)
# 二开部分End - 计费
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
domain_execution = self._get_node_execution(event.id)
@@ -403,3 +434,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
+3 -2
View File
@@ -11,10 +11,11 @@ from core.variables.types import SegmentType
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
# Extend: Adding execution control logic
from core.workflow.nodes.code.control_extend import ExecutionControl
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from .exc import (
CodeNodeError,
@@ -38,6 +38,7 @@ from tasks.extend.update_account_money_when_workflow_node_execution_created_exte
# 二开部分End - 密钥额度限制
@dataclass
class CycleManagerWorkflowInfo:
workflow_id: str
@@ -55,12 +56,18 @@ class WorkflowCycleManager:
workflow_info: CycleManagerWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
user_id: Optional[str] = None, # extend:二开部分 - 计费相关的用户信息
created_by_role=None, # extend:二开部分 - 计费相关的用户信息
) -> None:
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
self._workflow_info = workflow_info
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
# extend:二开部分 - 计费相关的用户信息
self._user_id = user_id
self._created_by_role = created_by_role
# Initialize caches for workflow execution cycle
# These caches avoid redundant repository calls during a single workflow execution
+4 -4
View File
@@ -10,12 +10,12 @@ from .queue_credential_sync_when_tenant_created import handle as handle_queue_cr
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
from .update_app_dataset_join_when_app_model_config_updated import (
handle as handle_update_app_dataset_join_when_app_model_config_updated,
)
from .update_account_money_when_messaeg_created_extend import (
handle as handle_update_account_money_when_messaeg_created_extend,
) # 二开部分:新增限额判断
from .update_app_dataset_join_when_app_model_config_updated import (
handle as handle_update_app_dataset_join_when_app_model_config_updated,
)
from .update_app_dataset_join_when_app_published_workflow_updated import (
handle as handle_update_app_dataset_join_when_app_published_workflow_updated,
)
@@ -38,9 +38,9 @@ __all__ = [
"handle_sync_plugin_trigger_when_app_created",
"handle_sync_webhook_when_app_created",
"handle_sync_workflow_schedule_when_app_published",
"handle_update_account_money_when_messaeg_created_extend", # Extend messaeg_created_extend
"handle_update_app_dataset_join_when_app_model_config_updated",
"handle_update_app_dataset_join_when_app_published_workflow_updated",
"handle_update_app_triggers_when_app_published_workflow_updated",
"handle_update_provider_when_message_created",
"handle_update_account_money_when_messaeg_created_extend",# Extend messaeg_created_extend
]
+7 -1
View File
@@ -2,7 +2,7 @@ import json
import os
import threading
from flask import Response
from flask import Response, request # Extend: 新增request
from configs import dify_config
from dify_app import DifyApp
@@ -14,6 +14,12 @@ def init_app(app: DifyApp):
"""Add Version headers to the response."""
response.headers.add("X-Version", dify_config.project.version)
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
# Extend: Start New proxy authentication: Login and write JWT token to cookies
cookie = request.cookies.get("x-token")
token = request.headers.get("Authorization")
if token is not None and len(token) > 0 and token != cookie:
response.set_cookie("x-token", token[7:], httponly=True)
# Extend: Stop New proxy authentication: Login and write JWT token to cookies
return response
@app.route("/health")
+1 -1
View File
@@ -4,7 +4,7 @@ from dify_app import DifyApp
BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization", "Authorization-extend", "X-App-Code")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN), "Authorization-extend", "X-App-Code")
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN, "Authorization-extend", "X-App-Code")
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN, "Authorization-extend")
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, "Authorization-extend", "X-App-Code")
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
+24
View File
@@ -102,6 +102,7 @@ def init_app(app: DifyApp) -> Celery:
imports = [
"tasks.async_workflow_tasks", # trigger workers
"tasks.trigger_processing_tasks", # async trigger processing
"tasks.extend.update_account_money_when_workflow_node_execution_created_extend", # 二开部分 - workflow计费任务
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
@@ -175,6 +176,29 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
}
# ---------------------------- 二开部分 Begin ----------------------------
# 添加二开的定时任务imports
imports.append("schedule.update_account_used_quota_extend") # 每月重置账号额度
imports.append("schedule.update_api_token_daily_used_quota_task_extend") # 重置密钥日额度
imports.append("schedule.update_api_token_monthly_used_quota_task_extend") # 重置密钥月额度
# 每月1号00:00,重置账号额度
beat_schedule["update_account_used_quota"] = {
"task": "schedule.update_account_used_quota_extend.update_account_used_quota_extend",
"schedule": crontab(minute="0", hour="0", day_of_month="1"),
}
# 每天00:00,重置密钥日额度
beat_schedule["update_api_token_daily_used_quota_task_extend"] = {
"task": "schedule.update_api_token_daily_used_quota_task_extend.update_api_token_daily_used_quota_task_extend",
"schedule": crontab(minute="0", hour="0"),
}
# 每月1号00:00,重置密钥月额度
beat_schedule["update_api_token_monthly_used_quota_task_extend"] = {
"task": "schedule.update_api_token_monthly_used_quota_task_extend.update_api_token_monthly_used_quota_task_extend",
"schedule": crontab(minute="0", hour="0", day_of_month="1"),
}
# ---------------------------- 二开部分 End ----------------------------
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app
+1 -1
View File
@@ -56,7 +56,7 @@ def init_app(app: DifyApp):
migrate_oss,
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,,
install_rag_pipeline_plugins,
extend_db,
]
for cmd in cmds_to_register:
+1
View File
@@ -198,6 +198,7 @@ app_detail_fields_with_site = {
"access_mode": fields.String,
"tags": fields.List(fields.Nested(tag_fields)),
"site": fields.Nested(site_fields),
"retention_number": fields.Integer, # Extend: 记忆上下文功能
}
+1 -1
View File
@@ -215,7 +215,7 @@ class OaOAuth(OAuth):
return current
def get_authorization_url(self, invite_token: Optional[str] = None):
def get_authorization_url(self, invite_token: str | None = None):
auto2_conf = self.get_auto2_conf()
integration = auto2_conf.get('integration')
if integration is None:
+3
View File
@@ -22,6 +22,7 @@ logger = logging.getLogger('alembic.env')
# 将当前目录的父目录(api目录)添加到Python路径中,以便能够导入models模块
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# 获取当前运行的应用数据库引擎和URL
def get_engine():
try:
@@ -29,6 +30,7 @@ def get_engine():
except (KeyError, AttributeError):
return current_app.extensions['migrate'].db.engine
def get_engine_url():
try:
return get_engine().url.render_as_string(hide_password=False).replace(
@@ -36,6 +38,7 @@ def get_engine_url():
except AttributeError:
return str(get_engine().url).replace('%', '%%')
# 使用当前应用的数据库URL替换配置文件中的URL
config.set_main_option('sqlalchemy.url', get_engine_url())
@@ -0,0 +1,45 @@
"""add_app_extend
Revision ID: 012_app_extend
Revises: 011_system_integration_fields
Create Date: 2025-01-15 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.engine.reflection import Inspector
from models import types
# revision identifiers, used by Alembic.
revision = '012_app_extend'
down_revision = '011_system_integration_fields'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
tables = inspector.get_table_names()
if 'app_extend' not in tables:
op.create_table('app_extend',
sa.Column('id', types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', types.StringUUID(), nullable=False),
sa.Column('retention_number', sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint('id', name='app_extend_joins_pkey')
)
with op.batch_alter_table('app_extend', schema=None) as batch_op:
batch_op.create_index('app_extend_id_app_id_idx', ['app_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('app_extend', schema=None) as batch_op:
batch_op.drop_index('app_extend_id_app_id_idx')
op.drop_table('app_extend')
# ### end Alembic commands ###
@@ -0,0 +1,48 @@
"""add_message_context_extend
Revision ID: 013_message_context_extend
Revises: 012_app_extend
Create Date: 2025-01-15 13:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.engine.reflection import Inspector
from models import types
# revision identifiers, used by Alembic.
revision = '013_message_context_extend'
down_revision = '012_app_extend'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
tables = inspector.get_table_names()
if 'message_context_extend' not in tables:
op.create_table('message_context_extend',
sa.Column('id', types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('conversation_id', sa.String(length=36), nullable=True),
sa.Column('message_id', sa.String(length=36), nullable=False),
sa.PrimaryKeyConstraint('id', name='message_context_extend_joins_pkey')
)
with op.batch_alter_table('message_context_extend', schema=None) as batch_op:
batch_op.create_index('message_context_conversation_id_idx', ['conversation_id'], unique=False)
batch_op.create_index('message_context_created_at_idx', ['created_at'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_context_extend', schema=None) as batch_op:
batch_op.drop_index('message_context_created_at_idx')
batch_op.drop_index('message_context_conversation_id_idx')
op.drop_table('message_context_extend')
# ### end Alembic commands ###
@@ -1,62 +0,0 @@
"""add_account_money_extend_unique_constraint
Revision ID: 012_account_money_extend_unique
Revises: 011_system_integration_fields
Create Date: 2025-10-21 18:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.engine.reflection import Inspector
# revision identifiers, used by Alembic.
revision = '012_account_money_extend_unique'
down_revision = '011_system_integration_fields'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
tables = inspector.get_table_names()
if 'account_money_extend' in tables:
# 首先删除重复数据,只保留每个account_id中updated_at最大的记录
conn.execute(sa.text("""
DELETE FROM account_money_extend
WHERE id NOT IN (
SELECT DISTINCT ON (account_id) id
FROM account_money_extend
ORDER BY account_id, updated_at DESC
)
"""))
# 删除现有的普通索引
with op.batch_alter_table('account_money_extend', schema=None) as batch_op:
try:
batch_op.drop_index('idx_account_money_account_id')
except Exception:
# 如果索引不存在,忽略错误
pass
# 创建唯一约束
with op.batch_alter_table('account_money_extend', schema=None) as batch_op:
batch_op.create_unique_constraint('idx_account_money_account_id_unique', ['account_id'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
if 'account_money_extend' in tables:
with op.batch_alter_table('account_money_extend', schema=None) as batch_op:
try:
batch_op.drop_constraint('idx_account_money_account_id_unique', type_='unique')
except Exception:
# 如果约束不存在,忽略错误
pass
# 重新创建普通索引
batch_op.create_index('idx_account_money_account_id', ['account_id'], unique=False)
# ### end Alembic commands ###
+5 -1
View File
@@ -26,6 +26,9 @@ from .dataset import (
TidbAuthBinding,
Whitelist,
)
# extend: db
from .engine import db
from .enums import (
AppTriggerStatus,
AppTriggerType,
@@ -174,7 +177,7 @@ __all__ = [
"RecommendedApp",
"SavedMessage",
"Site",
"SystemIntegrationExtend", # Extend System Integration
"SystemIntegrationExtend", # Extend System Integration
"Tag",
"TagBinding",
"Tenant",
@@ -209,4 +212,5 @@ __all__ = [
"WorkflowToolProvider",
"WorkflowTriggerStatus",
"WorkflowType",
"db", # extend: db
]
+1
View File
@@ -577,6 +577,7 @@ class RecommendedAppsCategoryJoinExtend(db.Model):
recommended_id = db.Column(StringUUID, nullable=False)
category_id = db.Column(StringUUID, nullable=False)
class RecommendedApp(Base): # bug
__tablename__ = "recommended_apps"
__table_args__ = (
+31
View File
@@ -8,6 +8,7 @@ class EndUserAccountJoinsExtend(db.Model):
__table_args__ = (
db.PrimaryKeyConstraint("id", name="end_user_account_joins_pkey"),
db.Index("end_user_account_joins_account_id_idx", "account_id"),
db.Index("end_user_account_joins_end_user_id_idx", "end_user_id"), # 单独索引,用于计费查询优化
db.Index("end_user_account_joins_end_user_id_app_id_idx", "end_user_id", "app_id"),
)
@@ -17,3 +18,33 @@ class EndUserAccountJoinsExtend(db.Model):
app_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
# Extend: 记忆上下文功能
class AppExtend(db.Model):
__tablename__ = "app_extend"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_extend_joins_pkey"),
db.Index("app_extend_id_app_id_idx", "app_id"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
retention_number = db.Column(db.Integer, nullable=True)
# Extend: 记忆上下文功能
# Extend: 消息上下文分割功能
class MessageContextExtend(db.Model):
__tablename__ = "message_context_extend"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="message_context_extend_joins_pkey"),
db.Index("message_context_conversation_id_idx", "conversation_id"),
db.Index("message_context_created_at_idx", "created_at"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
conversation_id = db.Column(db.String(36), nullable=True)
message_id = db.Column(db.String(36), nullable=False)
# Extend: 消息上下文分割功能
+4 -4
View File
@@ -9,10 +9,10 @@ from .engine import db
class SystemIntegrationClassify:
SYSTEM_INTEGRATION_DINGTALK = 1 # 钉钉
SYSTEM_INTEGRATION_WEIXIN = 2 # 微信
SYSTEM_INTEGRATION_FEI_SU = 3 # 飞书
SYSTEM_INTEGRATION_OAUTH_TWO = 4 # OAuth2
SYSTEM_INTEGRATION_DINGTALK = 1 # 钉钉
SYSTEM_INTEGRATION_WEIXIN = 2 # 微信
SYSTEM_INTEGRATION_FEI_SU = 3 # 飞书
SYSTEM_INTEGRATION_OAUTH_TWO = 4 # OAuth2
class SystemIntegrationExtend(db.Model):
+8
View File
@@ -21,6 +21,7 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account
from models.model import App, AppMode, AppModelConfig, AppStatisticsExtend, RecommendedApp, Site
from models.model_extend import AppExtend # Extend: 记忆上下文功能
from models.tools import ApiToolProvider
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
@@ -259,6 +260,13 @@ class AppService:
return model_config
app = ModifiedApp(app)
# Extend: 记忆上下文功能 - Start
app_extend: AppExtend = db.session.query(AppExtend).filter(AppExtend.app_id == app.id).first()
if app_extend is not None:
app.retention_number = app_extend.retention_number
else:
app.retention_number = dify_config.DEFAULT_NUMBER_CONTEXT
# Extend: 记忆上下文功能 - Stop
return app
+5 -5
View File
@@ -20,7 +20,7 @@ from services.account_service import AccountService, RegisterService, TenantServ
from services.account_service_extend import TenantExtendService
logger = logging.getLogger(__name__)
DINGTALK_ACCOUNT_TOKEN = { "time": 0, "token": "" }
DINGTALK_ACCOUNT_TOKEN = {"time": 0, "token": ""}
class DingTalkService:
@@ -97,7 +97,7 @@ class DingTalkService:
dingTalkToken, err = cls.get_access_token()
responses = requests.post(
f'https://oapi.dingtalk.com/topapi/v2/user/get?access_token={dingTalkToken}',
json={ "userid": userid },
json={"userid": userid},
)
# Check the response status code
if responses.status_code != 200:
@@ -148,7 +148,7 @@ class DingTalkService:
return "", f"Failed to obtain token: {err}"
response = requests.get(
"https://api.dingtalk.com/v1.0/contact/users/me",
headers={ "x-acs-dingtalk-access-token": userToken },
headers={"x-acs-dingtalk-access-token": userToken},
)
# Check the response status code
if response.status_code != 200:
@@ -161,7 +161,7 @@ class DingTalkService:
dingTalkToken, err = cls.get_access_token()
unionIdResponse = requests.post(
f"https://oapi.dingtalk.com/topapi/user/getbyunionid?access_token={dingTalkToken}",
json={ "unionid": req["unionId"] }
json={"unionid": req["unionId"]}
)
# Check the response status code
if unionIdResponse.status_code != 200:
@@ -185,7 +185,7 @@ class DingTalkService:
return "", f"Failed to obtain token: {err}"
response = requests.post(
f"{host}/getuserinfo?access_token={token}",
json={ "code": code },
json={"code": code},
)
# Check the response status code
if response.status_code != 200:
+35 -26
View File
@@ -1,21 +1,23 @@
import json
# extend start: oauth2 and DingTalk third-party login
import re
from enum import StrEnum
from flask import has_app_context, has_request_context, request
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from enums.cloud_plan import CloudPlan
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
# extend start: oauth2 and DingTalk third-party login
import re
import json
from flask import request
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.system_extend import SystemIntegrationClassify, SystemIntegrationExtend
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
# extend stop: oauth2 and DingTalk third-party login
class SubscriptionModel(BaseModel):
plan: str = CloudPlan.SANDBOX
interval: str = ""
@@ -180,9 +182,9 @@ class SystemFeatureModel(BaseModel):
plugin_manager: PluginManagerModel = PluginManagerModel()
is_custom_auth2: str = "" # extend: Customizing AUTH2
is_custom_auth2_logout: str = "" # extend: Customizing AUTH2
ding_talk_client_id: str = "" # extend: DingTalk third-party login
ding_talk_corp_id: str = "" # extend: DingTalk sidebar login
ding_talk: bool = "" # extend: DingTalk sidebar login
ding_talk_client_id: str = "" # extend: DingTalk third-party login
ding_talk_corp_id: str = "" # extend: DingTalk sidebar login
ding_talk: bool = "" # extend: DingTalk sidebar login
class FeatureService:
@@ -216,9 +218,14 @@ class FeatureService:
def get_system_features(cls) -> SystemFeatureModel:
system_features = SystemFeatureModel()
# extend start: oauth2
api_host = request.host_url
# 通过nginx代理转发会导致 request.host_url 获取的是内网ip,这个时候使用.env的配置
if bool(re.search(r'^(?:[0-9]{1,3}\.){3}[0-9]{1,3}', request.host_url)):
# 检查是否有请求上下文(在 Celery worker 中可能没有)
if has_request_context():
api_host = request.host_url
# 通过nginx代理转发会导致 request.host_url 获取的是内网ip,这个时候使用.env的配置
if bool(re.search(r'^(?:[0-9]{1,3}\.){3}[0-9]{1,3}', request.host_url)):
api_host = dify_config.CONSOLE_WEB_URL
else:
# 没有请求上下文时(如 Celery worker),直接使用配置值
api_host = dify_config.CONSOLE_WEB_URL
redis_client.set("api_host", api_host)
# extend stop: oauth2
@@ -246,19 +253,21 @@ class FeatureService:
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
# extend start: DingTalk third-party login
for i in db.session.query(SystemIntegrationExtend).filter(SystemIntegrationExtend.status == True).all():
if i.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_DINGTALK:
system_features.ding_talk_client_id = i.app_key
system_features.ding_talk_corp_id = i.corp_id
system_features.ding_talk = i.status
# Extend: OAuth2 Start
elif i.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_OAUTH_TWO:
config = json.loads(i.config)
system_features.is_custom_auth2 = i.status
if "logout_url" in config.keys():
system_features.is_custom_auth2_logout = "{}{}".format(
config['server_url'], config['logout_url'])
# Extend: OAuth2 Stop
# 检查是否有应用上下文(访问 db.session 需要应用上下文)
if has_app_context():
for i in db.session.query(SystemIntegrationExtend).filter(SystemIntegrationExtend.status == True).all():
if i.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_DINGTALK:
system_features.ding_talk_client_id = i.app_key
system_features.ding_talk_corp_id = i.corp_id
system_features.ding_talk = i.status
# Extend: OAuth2 Start
elif i.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_OAUTH_TWO:
config = json.loads(i.config)
system_features.is_custom_auth2 = i.status
if "logout_url" in config.keys():
system_features.is_custom_auth2_logout = "{}{}".format(
config['server_url'], config['logout_url'])
# Extend: OAuth2 Stop
# extend stop: DingTalk third-party login
@classmethod
+1 -1
View File
@@ -8,7 +8,7 @@ from services.model_provider_service_extend import ModelProviderExtendService
class ModelExtendService:
@staticmethod
def sync_set_all_model_to_tenant(tenant_id: str) -> bool:
logging.info(f"开始同步所有模型到工作区: {tenant_id}")
logging.info("开始同步所有模型到工作区: %s", tenant_id)
model_provider_service_extend = ModelProviderExtendService()
# 同步供应商+模型名称的模型数据
provider_model_records = TenantExtendService.get_sync_all_model()
@@ -2,6 +2,7 @@ from sqlalchemy import select
from constants.languages import languages
from extensions.ext_database import db
# extend add category to categories
from models.model import App, RecommendedApp, RecommendedAppsCategoryJoinExtend, RecommendedCategoryExtend
from services.app_dsl_service import AppDslService
@@ -67,22 +68,34 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
if not site:
continue
recommended_app_result = {
"id": recommended_app.id,
"app": recommended_app.app,
"app_id": recommended_app.app_id,
"description": site.description,
"copyright": site.copyright,
"privacy_policy": site.privacy_policy,
"custom_disclaimer": site.custom_disclaimer,
"category": recommended_app.category,
"position": recommended_app.position,
"is_listed": recommended_app.is_listed,
}
recommended_apps_result.append(recommended_app_result)
categories.add(recommended_app.category)
config = app.app_model_config
if config is not None and config.pre_prompt is not None and len(config.pre_prompt) > 0:
description = config.pre_prompt
if recommended_app.id in recommended:
classList = recommended[recommended_app.id]
if len(classList) == 0:
classList.append("")
for classId in classList:
category = "未分类"
if classId in class_dick:
category = class_dick[classId]
recommended_app_result = {
"id": recommended_app.id,
"app": recommended_app.app,
"app_id": recommended_app.app_id,
"description": description,
"copyright": site.copyright,
"privacy_policy": site.privacy_policy,
"custom_disclaimer": site.custom_disclaimer,
"category": category,
"position": recommended_app.position,
"is_listed": recommended_app.is_listed,
}
recommended_apps_result.append(recommended_app_result)
categories.add(category) # add category to categories
categories = sorted(categories)
categories.append("未分类")
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
# -------------- extend stop: add category to categories ---------------
@@ -186,3 +186,26 @@ class RecommendedAppService:
return recommendedApp.id
except:
return ""
# Extend: start messages context handling
@classmethod
def message_context(cls, conversation_id: str):
from models.model_extend import MessageContextExtend
message_list = []
message_context = db.session.query(MessageContextExtend).filter(
MessageContextExtend.conversation_id == conversation_id).order_by(
MessageContextExtend.created_at.desc()).all()
for v in message_context:
message_list.append(v.message_id)
return message_list
@classmethod
def delete_message_context(cls, conversation_id, message_id: str):
from models.model_extend import MessageContextExtend
db.session.query(MessageContextExtend).filter(
MessageContextExtend.conversation_id == conversation_id,
MessageContextExtend.message_id == message_id,
).delete()
db.session.commit()
return 'ok'
# Extend: stop messages context handling
+1
View File
@@ -4,6 +4,7 @@ from configs import dify_config
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
from services.account_service import TenantService
# extend: 添加用户权限
from services.account_service_extend import TenantExtendService
from services.feature_service import FeatureService
@@ -1,26 +1,97 @@
import json
import logging
from decimal import Decimal
from typing import Optional
import click
from celery import shared_task
from sqlalchemy import exists
from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from core.workflow.nodes.enums import NodeType
from core.workflow.enums import NodeType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account
from models.account_money_extend import AccountMoneyExtend
from models.api_token_money_extend import ApiTokenMessageJoinsExtend, ApiTokenMoneyExtend
from models.enums import CreatorUserRole
from models.model_extend import EndUserAccountJoinsExtend
# 缓存键前缀和过期时间
PAYER_ID_CACHE_PREFIX = "billing:payer_id:"
PAYER_ID_CACHE_TTL = 3600 # 1小时缓存
def _get_payer_id_from_cache(end_user_id: str) -> Optional[str]:
"""从Redis缓存获取付费人ID"""
try:
cache_key = f"{PAYER_ID_CACHE_PREFIX}{end_user_id}"
cached_value = redis_client.get(cache_key)
if cached_value:
return cached_value.decode('utf-8') if isinstance(cached_value, bytes) else cached_value
except Exception as e:
logging.debug("缓存读取失败: %s", e)
return None
def _set_payer_id_to_cache(end_user_id: str, payer_id: str) -> None:
"""将付费人ID写入Redis缓存"""
try:
cache_key = f"{PAYER_ID_CACHE_PREFIX}{end_user_id}"
redis_client.setex(cache_key, PAYER_ID_CACHE_TTL, payer_id)
except Exception as e:
logging.debug("缓存写入失败: %s", e)
def _resolve_payer_id(created_by: str, created_by_role: Optional[str]) -> str:
"""
解析实际付费人ID
使用缓存+高效查询优化性能
"""
payer_id = created_by
if created_by_role != CreatorUserRole.END_USER.value:
return payer_id
# 先检查缓存
cached_payer_id = _get_payer_id_from_cache(created_by)
if cached_payer_id:
return cached_payer_id
# 使用 EXISTS 子查询检查是否是真实账户,比 SELECT 更高效
is_account = db.session.query(
exists().where(Account.id == created_by)
).scalar()
if is_account:
# 是真实账户,缓存并返回
_set_payer_id_to_cache(created_by, created_by)
return created_by
# 查询关联表获取真正的付费账户
# 只选择需要的字段,使用索引优化查询
end_user_account = (
db.session.query(EndUserAccountJoinsExtend.account_id)
.filter(EndUserAccountJoinsExtend.end_user_id == created_by)
.order_by(EndUserAccountJoinsExtend.created_at.desc())
.first()
)
if end_user_account:
payer_id = str(end_user_account.account_id)
# 缓存结果
_set_payer_id_to_cache(created_by, payer_id)
return payer_id
@shared_task(queue="extend_high", bind=True, max_retries=3)
def update_account_money_when_workflow_node_execution_created_extend(
self, workflow_node_execution_dict: dict):
"""
计算工作流节点执行的费用并更新账户额度
优化版本使用缓存减少数据库查询使用原子更新避免并发问题
:param workflow_node_execution_dict: 工作流节点执行字典
"""
@@ -35,7 +106,7 @@ def update_account_money_when_workflow_node_execution_created_extend(
node_id = workflow_node_execution_dict.get("id")
logging.info(click.style("工作流节点ID {}".format(node_id), fg="cyan"))
# 拿到费用 - 从 outputs 字段获取费用信息(参考原始代码)
# 拿到费用 - 从 outputs 字段获取费用信息
outputs = workflow_node_execution_dict.get("outputs", {})
# 如果 outputs 是字符串,则解析 JSON;如果已经是字典,则直接使用
@@ -55,36 +126,26 @@ def update_account_money_when_workflow_node_execution_created_extend(
logging.info(click.style("扣除费用: {}".format(price), fg="green"))
try:
# 当前是end_user,节点账号id
# 分两种情况
# web应用的请求,created_by记录的是登录账号的ID,可以拿这个ID来扣钱
# API调用,created_by记录的是节点登录账号ID,真正需要扣钱的在关联表EndUserAccountJoinsExtend,需要多做一层查询
created_by = workflow_node_execution_dict.get("created_by")
created_by_role = workflow_node_execution_dict.get("created_by_role")
payerId = created_by # 付钱的ID
if created_by_role == CreatorUserRole.END_USER.value:
account = db.session.query(Account).filter(Account.id == created_by).first()
if not account:
end_user_account_joins = (
db.session.query(EndUserAccountJoinsExtend)
.filter(EndUserAccountJoinsExtend.end_user_id == created_by)
.order_by(EndUserAccountJoinsExtend.created_at.desc())
.first()
)
if end_user_account_joins:
payerId = end_user_account_joins.account_id
# 使用优化后的方法获取付费人ID
payer_id = _resolve_payer_id(created_by, created_by_role)
logging.info(click.style("更新账号额度,账号ID {}".format(payer_id), fg="green"))
account_money = db.session.query(AccountMoneyExtend).filter(
AccountMoneyExtend.account_id == payerId).first()
logging.info(click.style("更新账号额度,账号ID {}".format(payerId), fg="green"))
if account_money:
db.session.query(AccountMoneyExtend).filter(AccountMoneyExtend.account_id == payerId).update(
{
"used_quota": float(account_money.used_quota) + price}
)
else:
# 使用原子更新,避免并发问题,并减少一次查询
# UPDATE ... SET used_quota = used_quota + price WHERE account_id = ?
rows_updated = db.session.query(AccountMoneyExtend).filter(
AccountMoneyExtend.account_id == payer_id
).update(
{"used_quota": AccountMoneyExtend.used_quota + price},
synchronize_session=False
)
if rows_updated == 0:
# 记录不存在,创建新记录
account_money_add = AccountMoneyExtend(
account_id=payerId,
account_id=payer_id,
used_quota=price,
total_quota=dify_config.ACCOUNT_TOTAL_QUOTA,
)
@@ -92,33 +153,38 @@ def update_account_money_when_workflow_node_execution_created_extend(
# 扣掉密钥的钱
workflow_run_id = workflow_node_execution_dict.get("workflow_run_id")
api_token_message = (
db.session.query(ApiTokenMessageJoinsExtend)
.filter(ApiTokenMessageJoinsExtend.record_id == workflow_run_id)
.first()
)
if api_token_message:
logging.info(click.style("更新密钥额度,密钥ID {}".format(
api_token_message.app_token_id), fg="green"))
db.session.query(ApiTokenMoneyExtend).filter(
ApiTokenMoneyExtend.app_token_id == api_token_message.app_token_id
).update(
{
"accumulated_quota": ApiTokenMoneyExtend.accumulated_quota + price,
"day_used_quota": ApiTokenMoneyExtend.day_used_quota + price,
"month_used_quota": ApiTokenMoneyExtend.month_used_quota + price,
},
if workflow_run_id:
# 只查询需要的字段
api_token_id_result = (
db.session.query(ApiTokenMessageJoinsExtend.app_token_id)
.filter(ApiTokenMessageJoinsExtend.record_id == workflow_run_id)
.first()
)
if api_token_id_result and api_token_id_result.app_token_id:
app_token_id = api_token_id_result.app_token_id
logging.info(click.style("更新密钥额度,密钥ID {}".format(app_token_id), fg="green"))
db.session.query(ApiTokenMoneyExtend).filter(
ApiTokenMoneyExtend.app_token_id == app_token_id
).update(
{
"accumulated_quota": ApiTokenMoneyExtend.accumulated_quota + price,
"day_used_quota": ApiTokenMoneyExtend.day_used_quota + price,
"month_used_quota": ApiTokenMoneyExtend.month_used_quota + price,
},
synchronize_session=False
)
db.session.commit()
except SQLAlchemyError as e:
db.session.rollback()
logging.exception(
click.style(f"工作流节点ID {format(node_id)},扣除费用:"
f"{format(price)} 数据库异常,60秒后进行重试,", fg="red")
)
raise self.retry(exc=e, countdown=60) # Retry after 60 seconds
except Exception as e:
db.session.rollback()
logging.exception(
click.style(f"工作流节点ID {format(node_id)},扣除费用:"
f"{format(price)} 异常报错,60秒后进行重试,", fg="red")
@@ -1,10 +1,11 @@
import time
from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.system_variable import SystemVariable
@@ -3,11 +3,11 @@ import uuid
from datetime import UTC, datetime
import pytest
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
from pydantic import ValidationError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)