mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-14 20:41:21 +08:00
fix: 上下文,聊天计费,额度
This commit is contained in:
@@ -712,3 +712,41 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||
|
||||
|
||||
# OA Oauth2(二开新增配置)
|
||||
OAUTH2_CLIENT_URL=
|
||||
OAUTH2_CLIENT_ID=
|
||||
OAUTH2_CLIENT_SECRET=
|
||||
OAUTH2_TOKEN_URL=
|
||||
OAUTH2_USER_INFO_URL=
|
||||
|
||||
# 改成从数据库获取应用模版(二开新增配置)
|
||||
HOSTED_FETCH_APP_TEMPLATES_MODE=db
|
||||
|
||||
# 新增数据库相关配置(二开新增配置)
|
||||
SQLALCHEMY_POOL_SIZE=100
|
||||
SQLALCHEMY_MAX_OVERFLOW=10
|
||||
SQLALCHEMY_POOL_RECYCLE=3600
|
||||
SQLALCHEMY_POOL_PRE_PING=false
|
||||
|
||||
# 注册账号统一邮箱后缀(二开新增配置)
|
||||
EMAIL_DOMAIN=
|
||||
|
||||
# 默认邮箱账号密码登录(二开新增配置)
|
||||
ENABLE_EMAIL_PASSWORD_LOGIN=True
|
||||
|
||||
# 管理后台相关配置,后台超级管理员权限组id(二开新增配置)
|
||||
ADMIN_GROUP_ID=888
|
||||
|
||||
# 中->美汇率
|
||||
RMB_TO_USD_RATE=7.26
|
||||
|
||||
# 默认语言
|
||||
DEFAULT_LANGUAGE=zh-Hans
|
||||
|
||||
# Bedrock Proxy
|
||||
BEDROCK_PROXY=
|
||||
|
||||
# 初始额度
|
||||
ACCOUNT_TOTAL_QUOTA=15
|
||||
|
||||
@@ -2112,6 +2112,7 @@ def migrate_oss(
|
||||
db.session.rollback()
|
||||
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
|
||||
|
||||
|
||||
# extend: start 管理二开db扩展
|
||||
@click.group("extend_db", help="管理二开扩展表的数据库迁移")
|
||||
def extend_db():
|
||||
|
||||
@@ -67,6 +67,13 @@ class ExtendInfo(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
# Extend: 记忆上下文功能
|
||||
DEFAULT_NUMBER_CONTEXT: Optional[int] = Field(
|
||||
description="Default number of context retention (记忆窗口默认值)",
|
||||
default=5,
|
||||
)
|
||||
# Extend: 记忆上下文功能
|
||||
|
||||
|
||||
class ExtendConfig(ExtendInfo):
|
||||
pass
|
||||
|
||||
@@ -83,7 +83,7 @@ from .auth import (
|
||||
login,
|
||||
oauth,
|
||||
oauth_server,
|
||||
register_extend,# 二开部分: 新增用户(调用dify注册接口)
|
||||
register_extend, # 二开部分: 新增用户(调用dify注册接口)
|
||||
)
|
||||
|
||||
# Import billing controllers
|
||||
@@ -124,7 +124,7 @@ from .tag import tags
|
||||
# Import workspace controllers
|
||||
from .workspace import (
|
||||
account,
|
||||
account_extend,# 二开部分:新增account_extend
|
||||
account_extend, # 二开部分:新增account_extend
|
||||
agent_providers,
|
||||
endpoint,
|
||||
load_balancing_config,
|
||||
@@ -141,9 +141,9 @@ api.add_namespace(console_ns)
|
||||
|
||||
__all__ = [
|
||||
"account",
|
||||
"account_extend", # 二开部分:新增account_extend
|
||||
"activate",
|
||||
"admin",
|
||||
"account_extend",# 二开部分:新增account_extend
|
||||
"advanced_prompt_template",
|
||||
"agent",
|
||||
"agent_providers",
|
||||
|
||||
@@ -12,9 +12,9 @@ from werkzeug.exceptions import Forbidden
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.api_token_money_extend import ApiTokenMoneyExtend # 二开部分 - 密钥额度限制
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
from models.api_token_money_extend import ApiTokenMoneyExtend # 二开部分 - 密钥额度限制
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
@@ -46,6 +46,25 @@ class AppListQuery(BaseModel):
|
||||
tag_ids: list[str] | None = Field(default=None, description="Comma-separated tag IDs")
|
||||
is_created_by_me: bool | None = Field(default=None, description="Filter by creator")
|
||||
|
||||
# extend: start 二开部分:新增的查询参数
|
||||
@field_validator("mode", mode="before")
|
||||
@classmethod
|
||||
def validate_mode(cls, value: Any) -> str:
|
||||
"""
|
||||
Be tolerant for query params.
|
||||
If client passes an unexpected value, fall back to 'all' instead of raising 422.
|
||||
"""
|
||||
if value is None:
|
||||
return "all"
|
||||
|
||||
if isinstance(value, str):
|
||||
v = value.strip()
|
||||
allowed = {"completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"}
|
||||
return v if v in allowed else "all"
|
||||
|
||||
return "all"
|
||||
# extend: start 二开部分:新增的查询参数
|
||||
|
||||
@field_validator("tag_ids", mode="before")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: str | list[str] | None) -> list[str] | None:
|
||||
@@ -443,6 +462,7 @@ class AppPagination(ResponseModel):
|
||||
total: int
|
||||
has_more: bool = Field(validation_alias=AliasChoices("has_next", "has_more"))
|
||||
data: list[AppPartial] = Field(validation_alias=AliasChoices("items", "data"))
|
||||
recommended_apps: list[str] # extend: recommended apps
|
||||
|
||||
|
||||
class AppExportResponse(ResponseModel):
|
||||
|
||||
@@ -69,7 +69,41 @@ class AppSyncApi(Resource):
|
||||
return "", 200
|
||||
|
||||
|
||||
# Extend: start messages context handling
|
||||
class MessageContextApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
"""Message Context"""
|
||||
from flask import request
|
||||
conversation_id = request.args.get("conversation_id")
|
||||
if not conversation_id:
|
||||
from werkzeug.exceptions import BadRequest
|
||||
raise BadRequest("conversation_id is required")
|
||||
app_service = RecommendedAppService()
|
||||
|
||||
return app_service.message_context(conversation_id)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self):
|
||||
"""Message Context"""
|
||||
from flask import request
|
||||
message_id = request.args.get("message_id")
|
||||
conversation_id = request.args.get("conversation_id")
|
||||
if not message_id or not conversation_id:
|
||||
from werkzeug.exceptions import BadRequest
|
||||
raise BadRequest("message_id and conversation_id are required")
|
||||
app_service = RecommendedAppService()
|
||||
|
||||
return app_service.delete_message_context(conversation_id, message_id)
|
||||
# Extend: stop messages context handling
|
||||
|
||||
|
||||
# ----------------start sync app------------------------
|
||||
api.add_resource(AppSyncApi, "/apps/<uuid:app_id>/sync")
|
||||
api.add_resource(InstalledSyncAppApi, "/installed/apps")
|
||||
api.add_resource(MessageContextApi, "/message/context")
|
||||
# ---------------- stop sync app ------------------------
|
||||
|
||||
@@ -17,10 +17,10 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
|
||||
# Extend money_extend
|
||||
from controllers.console.money_extend import money_limit
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import cast
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
# Extend: 记忆上下文功能
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
@@ -12,11 +14,15 @@ from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, AppModelConfig
|
||||
from models.model_extend import AppExtend
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
# Extend: 记忆上下文功能
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/model-config")
|
||||
class ModelConfigResource(Resource):
|
||||
@@ -67,6 +73,23 @@ class ModelConfigResource(Resource):
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
# Extend: 记忆上下文功能 - Start
|
||||
config = request.json
|
||||
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.mode == AppMode.CHAT.value or app_model.is_agent:
|
||||
retention_number = int(config.get("retention_number", dify_config.DEFAULT_NUMBER_CONTEXT))
|
||||
# 循环移除相关键
|
||||
redis_client.delete(f"retention_number_{app_model.id}")
|
||||
app_extend = db.session.query(AppExtend).filter(AppExtend.app_id == app_model.id).first()
|
||||
# appExtend is not None
|
||||
if app_extend is None:
|
||||
db.session.add(AppExtend(app_id=app_model.id, retention_number=retention_number))
|
||||
else:
|
||||
db.session.query(AppExtend).filter(AppExtend.app_id == app_model.id).update(
|
||||
{AppExtend.retention_number: retention_number}
|
||||
)
|
||||
db.session.commit()
|
||||
# Extend: 记忆上下文功能 - Stop
|
||||
|
||||
if app_model.mode == AppMode.AGENT_CHAT or app_model.is_agent:
|
||||
# get original app model config
|
||||
original_app_model_config = (
|
||||
|
||||
@@ -6,17 +6,6 @@ from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> Optional[App]:
|
||||
assert isinstance(current_user, Account)
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
@@ -50,7 +50,7 @@ def get_oauth_providers():
|
||||
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google",
|
||||
)
|
||||
|
||||
oauth2 = OaOAuth(client_id='', client_secret='', redirect_uri='') # Extend: oauth2
|
||||
oauth2 = OaOAuth(client_id='', client_secret='', redirect_uri='') # Extend: oauth2
|
||||
|
||||
OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth, "oauth2": oauth2}
|
||||
return OAUTH_PROVIDERS
|
||||
|
||||
@@ -33,10 +33,10 @@ from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.app_generate_service_extend import (
|
||||
AppGenerateServiceExtend, # Extend: App Center - Recommended list sorted by usage frequency
|
||||
)
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -6,8 +7,8 @@ from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from libs.login import current_user, login_required
|
||||
from models.account import TenantAccountRole
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import TenantAccountRole
|
||||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,11 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required, is_admin_or_owner_required # extend: 非admin或者owner返回 Forbidden
|
||||
from controllers.console.wraps import ( # extend: 非admin或者owner返回 Forbidden
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
@@ -11,7 +11,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from models.model import App
|
||||
from models.model import ApiToken, App # extend - 密钥额度限制,新增ApiToken
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ class AnnotationListApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""List annotations for the application."""
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
|
||||
@@ -5,7 +5,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from models.model import App, AppMode
|
||||
from models.model import ApiToken, App, AppMode
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class AppParameterApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Retrieve app parameters.
|
||||
|
||||
Returns the input form parameters and configuration for the application.
|
||||
@@ -60,7 +60,7 @@ class AppMetaApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Get app metadata.
|
||||
|
||||
Returns metadata about the application including configuration and settings.
|
||||
@@ -80,7 +80,7 @@ class AppInfoApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Get app information.
|
||||
|
||||
Returns basic information about the application including name, description, tags, and mode.
|
||||
|
||||
@@ -22,7 +22,7 @@ from controllers.service_api.app.error import (
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import ApiToken, App, EndUser # extend: 二开部分 密钥额度限制,新增api_token
|
||||
from models.model import ApiToken, App, EndUser # extend: 二开部分 密钥额度限制,新增api_token
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
|
||||
@@ -30,9 +30,9 @@ from core.errors.error import (
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken
|
||||
from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
@@ -15,7 +15,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from models import ApiToken, App, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
from models import ApiToken, App, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
from services.file_service import FileService
|
||||
|
||||
register_schema_models(service_api_ns, FileResponse)
|
||||
@@ -36,7 +36,7 @@ class FileApi(Resource):
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore
|
||||
@service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__])
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
"""Upload a file for use in conversations.
|
||||
|
||||
Accepts a single file upload via multipart/form-data.
|
||||
|
||||
@@ -15,7 +15,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
MessageNotExistsError,
|
||||
@@ -58,7 +58,7 @@ class MessageListApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
def get(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
"""List messages in a conversation.
|
||||
|
||||
Retrieves messages with pagination support using first_id.
|
||||
@@ -102,7 +102,7 @@ class MessageFeedbackApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
def post(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
"""Submit feedback for a message.
|
||||
|
||||
Allows users to rate messages as like/dislike and provide optional feedback content.
|
||||
@@ -137,7 +137,7 @@ class AppGetFeedbacksApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Get all feedbacks for the application.
|
||||
|
||||
Returns paginated list of all feedback submitted for messages in this app.
|
||||
@@ -162,7 +162,7 @@ class MessageSuggestedApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
def get(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
"""Get suggested follow-up questions for a message.
|
||||
|
||||
Returns AI-generated follow-up questions based on the message content.
|
||||
|
||||
@@ -6,7 +6,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from models.model import ApiToken, App, Site # extend - 密钥额度限制,新增ApiToken
|
||||
|
||||
|
||||
@service_api_ns.route("/site")
|
||||
@@ -23,7 +23,7 @@ class AppSiteApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Retrieve app site info.
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
|
||||
@@ -34,7 +34,7 @@ from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,ApiToken
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,ApiToken
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
@@ -97,7 +97,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
|
||||
def get(self, app_model: App, workflow_run_id: str):
|
||||
def get(self, app_model: App, workflow_run_id: str, api_token: ApiToken): # extend - 密钥额度限制,新增api_token参数
|
||||
"""Get a workflow task running detail.
|
||||
|
||||
Returns detailed information about a specific workflow run.
|
||||
@@ -134,7 +134,7 @@ class WorkflowRunApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken, workflow_id: str): # extend: 密钥额度限制
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend: 密钥额度限制
|
||||
"""Execute a workflow.
|
||||
|
||||
Runs a workflow with the provided inputs and returns the results.
|
||||
@@ -195,7 +195,7 @@ class WorkflowRunByIdApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken, workflow_id: str): # extend: 密钥额度限制
|
||||
"""Run specific workflow by ID.
|
||||
|
||||
Executes a specific workflow version identified by its ID.
|
||||
@@ -215,6 +215,10 @@ class WorkflowRunByIdApi(Resource):
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
# ------------------- 二开部分Begin - 密钥额度限制 -------------------
|
||||
args["api_token"] = api_token
|
||||
# ------------------- 二开部分End - 密钥额度限制 -------------------
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||
|
||||
@@ -14,25 +14,26 @@ from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
# extend: start 额度限制,API调用计费,新增TenantAccountRole
|
||||
from controllers.service_api.app.error_extend import (
|
||||
AccountNoMoneyErrorExtend,
|
||||
ApiTokenDayNoMoneyErrorExtend,
|
||||
ApiTokenMonthNoMoneyErrorExtend,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.account_money_extend import AccountMoneyExtend
|
||||
from models.api_token_money_extend import ApiTokenMoneyExtend
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App
|
||||
from models.model_extend import EndUserAccountJoinsExtend
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
# extend: stop 额度限制,API调用计费,新增TenantAccountRole
|
||||
|
||||
|
||||
@@ -80,17 +81,27 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("The workspace's status is archived.")
|
||||
|
||||
|
||||
# ---------------------extend: 二开部分Begin 额度限制,API调用计费 ---------------------
|
||||
# TODO 需要写入缓存,读缓存
|
||||
account_money = (
|
||||
db.session.query(AccountMoneyExtend)
|
||||
.filter(AccountMoneyExtend.account_id == ta.account_id)
|
||||
.first()
|
||||
)
|
||||
if account_money and account_money.used_quota >= account_money.total_quota:
|
||||
raise AccountNoMoneyErrorExtend()
|
||||
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.filter(Tenant.id == api_token.tenant_id)
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(TenantAccountJoin.role.in_(["owner"]))
|
||||
.filter(Tenant.status == TenantStatus.NORMAL)
|
||||
.one_or_none()
|
||||
) # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
# TODO 需要写入缓存,读缓存
|
||||
account_money = (
|
||||
db.session.query(AccountMoneyExtend)
|
||||
.filter(AccountMoneyExtend.account_id == ta.account_id)
|
||||
.first()
|
||||
)
|
||||
if account_money and account_money.used_quota >= account_money.total_quota:
|
||||
raise AccountNoMoneyErrorExtend()
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
# 密钥额度判断
|
||||
kwargs["api_token"] = api_token # API token消息数据传递下去
|
||||
api_token_money = (
|
||||
@@ -382,6 +393,7 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
|
||||
return api_token
|
||||
|
||||
|
||||
# ---------------------二开部分Begin 额度限制,API调用计费 ---------------------
|
||||
def create_or_update_end_user_account_join_extend(end_user_id, account_id, app_id: str) -> EndUserAccountJoinsExtend:
|
||||
# 插入节点账号id和用户账号id关联关系,以方便扣钱查询
|
||||
@@ -402,7 +414,6 @@ def create_or_update_end_user_account_join_extend(end_user_id, account_id, app_i
|
||||
# ---------------------二开部分End 额度限制,API调用计费 ---------------------
|
||||
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
method_decorators = [validate_dataset_token]
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
@@ -37,14 +38,16 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# extend: 您必须登录才能访问您的帐户扩展功能
|
||||
from flask import request
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.account_money_extend import AccountMoneyExtend
|
||||
from services.app_generate_service_extend import AppGenerateServiceExtend
|
||||
|
||||
from controllers.web.error_extend import (
|
||||
AccountNoMoneyErrorExtend,
|
||||
WebAuthRequiredErrorExtend,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from models.account_money_extend import AccountMoneyExtend
|
||||
from services.app_generate_service_extend import AppGenerateServiceExtend
|
||||
|
||||
|
||||
def is_end_login(end_user):
|
||||
user_info = None
|
||||
@@ -61,6 +64,7 @@ def is_end_login(end_user):
|
||||
# no login
|
||||
return user_info
|
||||
|
||||
|
||||
# 额度限制
|
||||
def is_money_limit(end_user) -> bool:
|
||||
try:
|
||||
|
||||
@@ -4,9 +4,6 @@ from datetime import UTC, datetime, timedelta
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import func, select
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
@@ -36,8 +36,10 @@ from controllers.web.error_extend import (
|
||||
WebAuthRequiredErrorExtend,
|
||||
)
|
||||
from services.app_generate_service_extend import AppGenerateServiceExtend
|
||||
|
||||
# extend: stop 您必须登录才能访问您的帐户扩展功能
|
||||
|
||||
|
||||
@web_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(WebApiResource):
|
||||
@web_ns.doc("Run Workflow")
|
||||
|
||||
@@ -5,8 +5,6 @@ from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload, cast # extend: 二开部分 - 密钥额度限制,新增cast
|
||||
from typing import Any, Literal, Union, cast, overload # extend: 二开部分 - 密钥额度限制,新增cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -38,7 +38,16 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom, ApiToken # extend: 二开部分 - 密钥额度限制,新增ApiToken
|
||||
from models import ( # extend: 二开部分 - 密钥额度限制,新增ApiToken
|
||||
Account,
|
||||
ApiToken,
|
||||
App,
|
||||
Conversation,
|
||||
EndUser,
|
||||
Message,
|
||||
Workflow,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
)
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.workflow_draft_variable_service import (
|
||||
@@ -126,7 +135,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
extras["app_token_id"] = api_token.id
|
||||
# ------------------- 二开部分End - 密钥额度限制 -------------------
|
||||
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
conversation_id = args.get("conversation_id")
|
||||
|
||||
@@ -202,6 +202,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
user_from=user_from, # 二开部分 - 用于计费
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
|
||||
@@ -70,7 +70,14 @@ from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import AppMode, Account, Conversation, EndUser, Message, MessageFile # extend: 二开部分End - 密钥额度限制,新增AppMode
|
||||
from models import ( # extend: 二开部分End - 密钥额度限制,新增AppMode
|
||||
Account,
|
||||
AppMode,
|
||||
Conversation,
|
||||
EndUser,
|
||||
Message,
|
||||
MessageFile,
|
||||
)
|
||||
from models.api_token_money_extend import ApiTokenMessageJoinsExtend # extend: 二开部分End - 密钥额度限制
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow
|
||||
@@ -579,37 +586,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
with self._database_session() as session:
|
||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {event.error}"))
|
||||
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
|
||||
def _handle_workflow_failed_event(
|
||||
self,
|
||||
event: QueueWorkflowFailedEvent,
|
||||
*,
|
||||
graph_runtime_state: Optional[GraphRuntimeState] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow failed events."""
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
|
||||
|
||||
with self._database_session() as session:
|
||||
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
total_tokens=validated_state.total_tokens,
|
||||
total_steps=validated_state.node_run_steps,
|
||||
status=WorkflowExecutionStatus.FAILED,
|
||||
error_message=event.error,
|
||||
conversation_id=self._conversation_id,
|
||||
trace_manager=trace_manager,
|
||||
exceptions_count=event.exceptions_count,
|
||||
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
|
||||
)
|
||||
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_execution=workflow_execution,
|
||||
)
|
||||
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
|
||||
err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
|
||||
|
||||
yield workflow_finish_resp
|
||||
yield self._base_task_pipeline.error_to_stream_response(err)
|
||||
|
||||
@@ -76,8 +76,12 @@ class AgentChatAppRunner(AppRunner):
|
||||
files=list(files),
|
||||
query=query,
|
||||
memory=memory,
|
||||
control_registers=False, # Extend: messages_context_handling
|
||||
)
|
||||
|
||||
# Extend: messages_context_handling
|
||||
self.add_messages_context(prompt_messages, app_config.app_id, conversation.id, message.id)
|
||||
|
||||
# moderation
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
|
||||
@@ -29,7 +29,14 @@ from core.moderation.input_moderation import InputModeration
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
|
||||
# extend: start messages_context_handling
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import App, AppMode, Message, MessageAnnotation
|
||||
from models.model_extend import AppExtend, MessageContextExtend
|
||||
|
||||
# extend: stop messages_context_handling
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
@@ -72,6 +79,29 @@ class AppRunner:
|
||||
):
|
||||
model_config.parameters[parameter_rule.name] = max_tokens
|
||||
|
||||
# Extend: start messages_context_handling
|
||||
def add_messages_context(self, prompt_messages, app_id, conversation_id, message_id):
|
||||
key = "retention_number_{}".format(app_id)
|
||||
retention_number = redis_client.get(key)
|
||||
if retention_number is None:
|
||||
app_extend: AppExtend = (
|
||||
db.session.query(AppExtend).filter(AppExtend.app_id == app_id).first()
|
||||
)
|
||||
if app_extend is None:
|
||||
return
|
||||
retention_number = int(app_extend.retention_number)
|
||||
redis_client.set(key, app_extend.retention_number)
|
||||
else:
|
||||
retention_number = int(retention_number)
|
||||
if (len(prompt_messages) + 2) / 2 > retention_number:
|
||||
# 插入替换
|
||||
db.session.add(MessageContextExtend(
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
))
|
||||
db.session.commit()
|
||||
# Extend: stop messages_context_handling
|
||||
|
||||
def organize_prompt_messages(
|
||||
self,
|
||||
app_record: App,
|
||||
@@ -84,6 +114,7 @@ class AppRunner:
|
||||
memory: TokenBufferMemory | None = None,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
context_files: list["File"] | None = None,
|
||||
control_registers: bool = True, # Extend: messages context handling
|
||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||
"""
|
||||
Organize prompt messages
|
||||
@@ -96,6 +127,7 @@ class AppRunner:
|
||||
:param query: query
|
||||
:param memory: memory
|
||||
:param image_detail_config: the image quality config
|
||||
:param control_registers: is messages context # Extend: messages context handling
|
||||
:return:
|
||||
"""
|
||||
# get prompt without memory and context
|
||||
@@ -113,6 +145,7 @@ class AppRunner:
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
context_files=context_files,
|
||||
control_registers=control_registers, # Extend: messages context handling
|
||||
)
|
||||
else:
|
||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||
|
||||
@@ -24,9 +24,9 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account
|
||||
from models.api_token_money_extend import ApiTokenMessageJoinsExtend # 二开部分End - 密钥额度限制
|
||||
from models.model import App, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
from models.api_token_money_extend import ApiTokenMessageJoinsExtend # 二开部分End - 密钥额度限制
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -224,6 +224,21 @@ class ChatAppRunner(AppRunner):
|
||||
user=application_generate_entity.user_id,
|
||||
)
|
||||
|
||||
# Extend: Start messages context handling
|
||||
messages_list, _ = self.organize_prompt_messages(
|
||||
app_record=app_record,
|
||||
model_config=application_generate_entity.model_conf,
|
||||
prompt_template_entity=app_config.prompt_template,
|
||||
inputs=inputs,
|
||||
files=files,
|
||||
query=query,
|
||||
context=context,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# Extend: Stop messages_context_handling
|
||||
self.add_messages_context(messages_list, app_config.app_id, conversation.id, message.id)
|
||||
|
||||
# handle invoke result
|
||||
self._handle_invoke_result(
|
||||
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
|
||||
|
||||
@@ -187,6 +187,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
user_from=user_from, # 二开部分 - 用于计费
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
|
||||
@@ -3,8 +3,9 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
|
||||
# extend: 二开部分 - 密钥额度限制,新增cast
|
||||
from typing import Any, Literal, Union, overload, cast
|
||||
from typing import Any, Literal, Union, cast, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -36,8 +37,9 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
# extend: 二开部分 - 密钥额度限制,新增ApiToken
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom, ApiToken
|
||||
from models import Account, ApiToken, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
|
||||
@@ -145,6 +145,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository=self._workflow_execution_repository,
|
||||
workflow_node_execution_repository=self._workflow_node_execution_repository,
|
||||
trace_manager=self.application_generate_entity.trace_manager,
|
||||
user_from=user_from, # 二开部分 - 用于计费
|
||||
)
|
||||
|
||||
workflow_entry.graph_engine.layer(persistence_layer)
|
||||
|
||||
@@ -57,9 +57,11 @@ from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
# extend: 二开部分End - 密钥额度限制
|
||||
from models.api_token_money_extend import ApiTokenMessageJoinsExtend
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
# extend: 二开部分End - 密钥额度限制,新增AppMode
|
||||
from models.model import AppMode, EndUser
|
||||
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom
|
||||
|
||||
@@ -2,8 +2,6 @@ import logging
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from extensions.ext_database import db
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from sqlalchemy import select
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||
from core.external_data_tool.base import ExternalDataTool
|
||||
from core.helper import encrypter
|
||||
|
||||
@@ -19,6 +19,9 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import AppMode, Conversation, Message, MessageFile
|
||||
|
||||
# Extend: start messages context handling
|
||||
from models.model_extend import MessageContextExtend
|
||||
from models.workflow import Workflow
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
@@ -34,6 +37,31 @@ class TokenBufferMemory:
|
||||
self.model_instance = model_instance
|
||||
self._workflow_run_repo: APIWorkflowRunRepository | None = None
|
||||
|
||||
# Extend: start messages context handling
|
||||
def messages_context_handling(
|
||||
self,
|
||||
conversation_id: str,
|
||||
prompt_messages: tuple[list[AssistantPromptMessage]],
|
||||
) -> tuple[list[AssistantPromptMessage]]:
|
||||
# check if there is a segmentation context
|
||||
message_context = db.session.query(MessageContextExtend).filter(
|
||||
MessageContextExtend.conversation_id == conversation_id).order_by(
|
||||
MessageContextExtend.created_at.desc()).all()
|
||||
# Is there a split
|
||||
if not message_context:
|
||||
return prompt_messages
|
||||
# for
|
||||
messages = []
|
||||
for v in prompt_messages:
|
||||
messages.append(v)
|
||||
if v.name is not None and len(v.name) > 0:
|
||||
for i in message_context:
|
||||
if v.name == i.message_id:
|
||||
messages = []
|
||||
v.name = None
|
||||
return messages
|
||||
# Extend: stop messages context handling
|
||||
|
||||
@property
|
||||
def workflow_run_repo(self) -> APIWorkflowRunRepository:
|
||||
if self._workflow_run_repo is None:
|
||||
@@ -115,12 +143,14 @@ class TokenBufferMemory:
|
||||
return AssistantPromptMessage(content=prompt_message_contents)
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: int | None = None
|
||||
self, max_token_limit: int = 2000, message_limit: int | None = None,
|
||||
control_registers: bool = True, # Extend: messages context handling
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
:param control_registers:
|
||||
"""
|
||||
app_record = self.conversation.app
|
||||
|
||||
@@ -188,6 +218,17 @@ class TokenBufferMemory:
|
||||
else:
|
||||
prompt_messages.append(AssistantPromptMessage(content=message.answer))
|
||||
|
||||
# Extend Contextual dividing line
|
||||
prompt_messages.append(AssistantPromptMessage(name=message.id, content=message.answer))
|
||||
|
||||
# Extend: start messages context handling
|
||||
if control_registers:
|
||||
prompt_messages = self.messages_context_handling(
|
||||
prompt_messages=prompt_messages,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
# Extend: stop messages context handling
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
|
||||
@@ -15,9 +15,11 @@ class PromptTransform:
|
||||
memory_config: MemoryConfig,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
control_registers: bool = True, # Extend: messages context handling
|
||||
) -> list[PromptMessage]:
|
||||
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
|
||||
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
|
||||
# Extend: messages context handling
|
||||
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens, control_registers)
|
||||
prompt_messages.extend(histories)
|
||||
|
||||
return prompt_messages
|
||||
@@ -74,6 +76,7 @@ class PromptTransform:
|
||||
|
||||
def _get_history_messages_list_from_memory(
|
||||
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
|
||||
, control_registers: bool = True, # Extend: messages context handling
|
||||
) -> list[PromptMessage]:
|
||||
"""Get memory messages."""
|
||||
return list(
|
||||
@@ -86,5 +89,6 @@ class PromptTransform:
|
||||
and memory_config.window.size > 0
|
||||
)
|
||||
else None,
|
||||
control_registers=control_registers, # Extend: messages context handling
|
||||
)
|
||||
)
|
||||
|
||||
@@ -50,6 +50,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
context_files: list["File"] | None = None,
|
||||
control_registers: bool = True, # Extend: messages context handling
|
||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
|
||||
@@ -66,6 +67,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
model_config=model_config,
|
||||
image_detail_config=image_detail_config,
|
||||
context_files=context_files,
|
||||
control_registers=control_registers, # Extend: messages context handling
|
||||
)
|
||||
else:
|
||||
prompt_messages, stops = self._get_completion_model_prompt_messages(
|
||||
@@ -191,6 +193,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
context_files: list["File"] | None = None,
|
||||
control_registers: bool = True, # Extend: messages context handling
|
||||
) -> tuple[list[PromptMessage], list[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
|
||||
@@ -217,6 +220,7 @@ class SimplePromptTransform(PromptTransform):
|
||||
),
|
||||
prompt_messages=prompt_messages,
|
||||
model_config=model_config,
|
||||
control_registers=control_registers, # Extend: messages context handling
|
||||
)
|
||||
|
||||
if query:
|
||||
|
||||
@@ -12,12 +12,6 @@ from core.rag.extractor.watercrawl.exceptions import (
|
||||
WaterCrawlPermissionError,
|
||||
)
|
||||
|
||||
from core.rag.extractor.watercrawl.exceptions import (
|
||||
WaterCrawlAuthenticationError,
|
||||
WaterCrawlBadRequestError,
|
||||
WaterCrawlPermissionError,
|
||||
)
|
||||
|
||||
|
||||
class BaseAPIClient:
|
||||
def __init__(self, api_key, base_url):
|
||||
|
||||
@@ -122,6 +122,9 @@ class ToolManager:
|
||||
"""
|
||||
get the plugin provider
|
||||
"""
|
||||
# extend: 获取插件提供程序
|
||||
from core.plugin.impl.exc import PluginNotFoundError
|
||||
|
||||
# check if context is set
|
||||
|
||||
try:
|
||||
@@ -141,19 +144,53 @@ class ToolManager:
|
||||
return plugin_tool_providers[provider]
|
||||
|
||||
manager = PluginToolManager()
|
||||
provider_entity = manager.fetch_tool_provider(tenant_id, provider)
|
||||
if not provider_entity:
|
||||
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
|
||||
# extend: start 获取插件提供程序
|
||||
max_retries = 2
|
||||
last_error = None
|
||||
|
||||
controller = PluginToolProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
provider_entity = manager.fetch_tool_provider(tenant_id, provider)
|
||||
if not provider_entity:
|
||||
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
|
||||
|
||||
plugin_tool_providers[provider] = controller
|
||||
return controller
|
||||
controller = PluginToolProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
plugin_tool_providers[provider] = controller
|
||||
return controller
|
||||
except PluginNotFoundError as e:
|
||||
last_error = e
|
||||
# Clear cache and retry once more
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
f"Plugin {provider} not found on attempt {attempt + 1}, clearing cache and retrying. "
|
||||
f"Error: {str(e)}"
|
||||
)
|
||||
# Remove from cache if exists
|
||||
plugin_tool_providers.pop(provider, None)
|
||||
# Small delay before retry
|
||||
time.sleep(0.5)
|
||||
else:
|
||||
logger.error(
|
||||
f"Plugin {provider} not found after {max_retries} attempts. "
|
||||
f"Last error: {str(e)}"
|
||||
)
|
||||
raise ToolProviderNotFoundError(f"plugin provider {provider} not found after retries: {str(e)}")
|
||||
except Exception as e:
|
||||
# For other errors, don't retry
|
||||
logger.exception(f"Error fetching plugin provider {provider}: {str(e)}")
|
||||
raise
|
||||
|
||||
# Should not reach here, but just in case
|
||||
if last_error:
|
||||
raise ToolProviderNotFoundError(f"plugin provider {provider} not found: {str(last_error)}")
|
||||
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
|
||||
# extend: stop 获取插件提供程序
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(
|
||||
@@ -634,9 +671,9 @@ class ToolManager:
|
||||
# MySQL: Use window function to achieve same result
|
||||
sql = """
|
||||
SELECT id FROM (
|
||||
SELECT id,
|
||||
SELECT id,
|
||||
ROW_NUMBER() OVER (
|
||||
PARTITION BY tenant_id, provider
|
||||
PARTITION BY tenant_id, provider
|
||||
ORDER BY is_default DESC, created_at DESC
|
||||
) as rn
|
||||
FROM tool_builtin_providers
|
||||
|
||||
@@ -15,6 +15,7 @@ from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
@@ -48,6 +49,14 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
# extend: start 二开部分 - 计费相关的用户信息
|
||||
from models.enums import CreatorUserRole, UserFrom
|
||||
from tasks.extend.update_account_money_when_workflow_node_execution_created_extend import (
|
||||
update_account_money_when_workflow_node_execution_created_extend,
|
||||
)
|
||||
|
||||
# extend: stop 二开部分 - 计费相关的用户信息
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PersistenceWorkflowInfo:
|
||||
@@ -82,6 +91,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
user_from: UserFrom | None = None, # 二开部分 - 用于计费
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._application_generate_entity = application_generate_entity
|
||||
@@ -89,6 +99,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._trace_manager = trace_manager
|
||||
self._user_from = user_from # 二开部分 - 用于计费
|
||||
|
||||
self._workflow_execution: WorkflowExecution | None = None
|
||||
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
|
||||
@@ -270,6 +281,26 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
|
||||
# 二开部分Begin - 计费
|
||||
# 异步任务计算费用并更新账户额度,将对象转换为字典传递
|
||||
domain_execution_dict = jsonable_encoder(domain_execution)
|
||||
|
||||
# 添加用户信息到字典中
|
||||
domain_execution_dict['created_by'] = self._application_generate_entity.user_id
|
||||
if self._user_from == UserFrom.ACCOUNT:
|
||||
domain_execution_dict['created_by_role'] = CreatorUserRole.ACCOUNT.value
|
||||
elif self._user_from == UserFrom.END_USER:
|
||||
domain_execution_dict['created_by_role'] = CreatorUserRole.END_USER.value
|
||||
else:
|
||||
domain_execution_dict['created_by_role'] = None
|
||||
|
||||
# 添加 workflow_run_id
|
||||
if self._workflow_execution:
|
||||
domain_execution_dict['workflow_run_id'] = self._workflow_execution.id_
|
||||
|
||||
update_account_money_when_workflow_node_execution_created_extend.delay(domain_execution_dict)
|
||||
# 二开部分End - 计费
|
||||
|
||||
def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
domain_execution = self._get_node_execution(event.id)
|
||||
@@ -403,3 +434,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
@@ -11,10 +11,11 @@ from core.variables.types import SegmentType
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
|
||||
# Extend: Adding execution control logic
|
||||
from core.workflow.nodes.code.control_extend import ExecutionControl
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
|
||||
@@ -38,6 +38,7 @@ from tasks.extend.update_account_money_when_workflow_node_execution_created_exte
|
||||
|
||||
# 二开部分End - 密钥额度限制
|
||||
|
||||
|
||||
@dataclass
|
||||
class CycleManagerWorkflowInfo:
|
||||
workflow_id: str
|
||||
@@ -55,12 +56,18 @@ class WorkflowCycleManager:
|
||||
workflow_info: CycleManagerWorkflowInfo,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
user_id: Optional[str] = None, # extend:二开部分 - 计费相关的用户信息
|
||||
created_by_role=None, # extend:二开部分 - 计费相关的用户信息
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
self._workflow_info = workflow_info
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
|
||||
# extend:二开部分 - 计费相关的用户信息
|
||||
self._user_id = user_id
|
||||
self._created_by_role = created_by_role
|
||||
|
||||
# Initialize caches for workflow execution cycle
|
||||
# These caches avoid redundant repository calls during a single workflow execution
|
||||
|
||||
@@ -10,12 +10,12 @@ from .queue_credential_sync_when_tenant_created import handle as handle_queue_cr
|
||||
from .sync_plugin_trigger_when_app_created import handle as handle_sync_plugin_trigger_when_app_created
|
||||
from .sync_webhook_when_app_created import handle as handle_sync_webhook_when_app_created
|
||||
from .sync_workflow_schedule_when_app_published import handle as handle_sync_workflow_schedule_when_app_published
|
||||
from .update_app_dataset_join_when_app_model_config_updated import (
|
||||
handle as handle_update_app_dataset_join_when_app_model_config_updated,
|
||||
)
|
||||
from .update_account_money_when_messaeg_created_extend import (
|
||||
handle as handle_update_account_money_when_messaeg_created_extend,
|
||||
) # 二开部分:新增限额判断
|
||||
from .update_app_dataset_join_when_app_model_config_updated import (
|
||||
handle as handle_update_app_dataset_join_when_app_model_config_updated,
|
||||
)
|
||||
from .update_app_dataset_join_when_app_published_workflow_updated import (
|
||||
handle as handle_update_app_dataset_join_when_app_published_workflow_updated,
|
||||
)
|
||||
@@ -38,9 +38,9 @@ __all__ = [
|
||||
"handle_sync_plugin_trigger_when_app_created",
|
||||
"handle_sync_webhook_when_app_created",
|
||||
"handle_sync_workflow_schedule_when_app_published",
|
||||
"handle_update_account_money_when_messaeg_created_extend", # Extend messaeg_created_extend
|
||||
"handle_update_app_dataset_join_when_app_model_config_updated",
|
||||
"handle_update_app_dataset_join_when_app_published_workflow_updated",
|
||||
"handle_update_app_triggers_when_app_published_workflow_updated",
|
||||
"handle_update_provider_when_message_created",
|
||||
"handle_update_account_money_when_messaeg_created_extend",# Extend messaeg_created_extend
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
import threading
|
||||
|
||||
from flask import Response
|
||||
from flask import Response, request # Extend: 新增request
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
@@ -14,6 +14,12 @@ def init_app(app: DifyApp):
|
||||
"""Add Version headers to the response."""
|
||||
response.headers.add("X-Version", dify_config.project.version)
|
||||
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
|
||||
# Extend: Start New proxy authentication: Login and write JWT token to cookies
|
||||
cookie = request.cookies.get("x-token")
|
||||
token = request.headers.get("Authorization")
|
||||
if token is not None and len(token) > 0 and token != cookie:
|
||||
response.set_cookie("x-token", token[7:], httponly=True)
|
||||
# Extend: Stop New proxy authentication: Login and write JWT token to cookies
|
||||
return response
|
||||
|
||||
@app.route("/health")
|
||||
|
||||
@@ -4,7 +4,7 @@ from dify_app import DifyApp
|
||||
|
||||
BASE_CORS_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, HEADER_NAME_PASSPORT)
|
||||
SERVICE_API_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, "Authorization", "Authorization-extend", "X-App-Code")
|
||||
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN), "Authorization-extend", "X-App-Code")
|
||||
AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF_TOKEN, "Authorization-extend", "X-App-Code")
|
||||
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN, "Authorization-extend")
|
||||
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE, "Authorization-extend", "X-App-Code")
|
||||
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
|
||||
|
||||
@@ -102,6 +102,7 @@ def init_app(app: DifyApp) -> Celery:
|
||||
imports = [
|
||||
"tasks.async_workflow_tasks", # trigger workers
|
||||
"tasks.trigger_processing_tasks", # async trigger processing
|
||||
"tasks.extend.update_account_money_when_workflow_node_execution_created_extend", # 二开部分 - workflow计费任务
|
||||
]
|
||||
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
|
||||
|
||||
@@ -175,6 +176,29 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
|
||||
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
|
||||
}
|
||||
# ---------------------------- 二开部分 Begin ----------------------------
|
||||
# 添加二开的定时任务imports
|
||||
imports.append("schedule.update_account_used_quota_extend") # 每月重置账号额度
|
||||
imports.append("schedule.update_api_token_daily_used_quota_task_extend") # 重置密钥日额度
|
||||
imports.append("schedule.update_api_token_monthly_used_quota_task_extend") # 重置密钥月额度
|
||||
|
||||
# 每月1号00:00,重置账号额度
|
||||
beat_schedule["update_account_used_quota"] = {
|
||||
"task": "schedule.update_account_used_quota_extend.update_account_used_quota_extend",
|
||||
"schedule": crontab(minute="0", hour="0", day_of_month="1"),
|
||||
}
|
||||
# 每天00:00,重置密钥日额度
|
||||
beat_schedule["update_api_token_daily_used_quota_task_extend"] = {
|
||||
"task": "schedule.update_api_token_daily_used_quota_task_extend.update_api_token_daily_used_quota_task_extend",
|
||||
"schedule": crontab(minute="0", hour="0"),
|
||||
}
|
||||
# 每月1号00:00,重置密钥月额度
|
||||
beat_schedule["update_api_token_monthly_used_quota_task_extend"] = {
|
||||
"task": "schedule.update_api_token_monthly_used_quota_task_extend.update_api_token_monthly_used_quota_task_extend",
|
||||
"schedule": crontab(minute="0", hour="0", day_of_month="1"),
|
||||
}
|
||||
# ---------------------------- 二开部分 End ----------------------------
|
||||
|
||||
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
|
||||
|
||||
return celery_app
|
||||
|
||||
@@ -56,7 +56,7 @@ def init_app(app: DifyApp):
|
||||
migrate_oss,
|
||||
setup_datasource_oauth_client,
|
||||
transform_datasource_credentials,
|
||||
install_rag_pipeline_plugins,,
|
||||
install_rag_pipeline_plugins,
|
||||
extend_db,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
|
||||
@@ -198,6 +198,7 @@ app_detail_fields_with_site = {
|
||||
"access_mode": fields.String,
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
"site": fields.Nested(site_fields),
|
||||
"retention_number": fields.Integer, # Extend: 记忆上下文功能
|
||||
}
|
||||
|
||||
|
||||
|
||||
+1
-1
@@ -215,7 +215,7 @@ class OaOAuth(OAuth):
|
||||
|
||||
return current
|
||||
|
||||
def get_authorization_url(self, invite_token: Optional[str] = None):
|
||||
def get_authorization_url(self, invite_token: str | None = None):
|
||||
auto2_conf = self.get_auto2_conf()
|
||||
integration = auto2_conf.get('integration')
|
||||
if integration is None:
|
||||
|
||||
@@ -22,6 +22,7 @@ logger = logging.getLogger('alembic.env')
|
||||
# 将当前目录的父目录(api目录)添加到Python路径中,以便能够导入models模块
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
|
||||
# 获取当前运行的应用数据库引擎和URL
|
||||
def get_engine():
|
||||
try:
|
||||
@@ -29,6 +30,7 @@ def get_engine():
|
||||
except (KeyError, AttributeError):
|
||||
return current_app.extensions['migrate'].db.engine
|
||||
|
||||
|
||||
def get_engine_url():
|
||||
try:
|
||||
return get_engine().url.render_as_string(hide_password=False).replace(
|
||||
@@ -36,6 +38,7 @@ def get_engine_url():
|
||||
except AttributeError:
|
||||
return str(get_engine().url).replace('%', '%%')
|
||||
|
||||
|
||||
# 使用当前应用的数据库URL替换配置文件中的URL
|
||||
config.set_main_option('sqlalchemy.url', get_engine_url())
|
||||
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
"""add_app_extend
|
||||
|
||||
Revision ID: 012_app_extend
|
||||
Revises: 011_system_integration_fields
|
||||
Create Date: 2025-01-15 12:00:00.000000
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
from models import types
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '012_app_extend'
|
||||
down_revision = '011_system_integration_fields'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
if 'app_extend' not in tables:
|
||||
op.create_table('app_extend',
|
||||
sa.Column('id', types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('app_id', types.StringUUID(), nullable=False),
|
||||
sa.Column('retention_number', sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id', name='app_extend_joins_pkey')
|
||||
)
|
||||
with op.batch_alter_table('app_extend', schema=None) as batch_op:
|
||||
batch_op.create_index('app_extend_id_app_id_idx', ['app_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('app_extend', schema=None) as batch_op:
|
||||
batch_op.drop_index('app_extend_id_app_id_idx')
|
||||
|
||||
op.drop_table('app_extend')
|
||||
# ### end Alembic commands ###
|
||||
@@ -0,0 +1,48 @@
|
||||
"""add_message_context_extend
|
||||
|
||||
Revision ID: 013_message_context_extend
|
||||
Revises: 012_app_extend
|
||||
Create Date: 2025-01-15 13:00:00.000000
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
from models import types
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '013_message_context_extend'
|
||||
down_revision = '012_app_extend'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
if 'message_context_extend' not in tables:
|
||||
op.create_table('message_context_extend',
|
||||
sa.Column('id', types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('conversation_id', sa.String(length=36), nullable=True),
|
||||
sa.Column('message_id', sa.String(length=36), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='message_context_extend_joins_pkey')
|
||||
)
|
||||
with op.batch_alter_table('message_context_extend', schema=None) as batch_op:
|
||||
batch_op.create_index('message_context_conversation_id_idx', ['conversation_id'], unique=False)
|
||||
batch_op.create_index('message_context_created_at_idx', ['created_at'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('message_context_extend', schema=None) as batch_op:
|
||||
batch_op.drop_index('message_context_created_at_idx')
|
||||
batch_op.drop_index('message_context_conversation_id_idx')
|
||||
|
||||
op.drop_table('message_context_extend')
|
||||
# ### end Alembic commands ###
|
||||
-62
@@ -1,62 +0,0 @@
|
||||
"""add_account_money_extend_unique_constraint
|
||||
|
||||
Revision ID: 012_account_money_extend_unique
|
||||
Revises: 011_system_integration_fields
|
||||
Create Date: 2025-10-21 18:00:00.000000
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '012_account_money_extend_unique'
|
||||
down_revision = '011_system_integration_fields'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
if 'account_money_extend' in tables:
|
||||
# 首先删除重复数据,只保留每个account_id中updated_at最大的记录
|
||||
conn.execute(sa.text("""
|
||||
DELETE FROM account_money_extend
|
||||
WHERE id NOT IN (
|
||||
SELECT DISTINCT ON (account_id) id
|
||||
FROM account_money_extend
|
||||
ORDER BY account_id, updated_at DESC
|
||||
)
|
||||
"""))
|
||||
|
||||
# 删除现有的普通索引
|
||||
with op.batch_alter_table('account_money_extend', schema=None) as batch_op:
|
||||
try:
|
||||
batch_op.drop_index('idx_account_money_account_id')
|
||||
except Exception:
|
||||
# 如果索引不存在,忽略错误
|
||||
pass
|
||||
|
||||
# 创建唯一约束
|
||||
with op.batch_alter_table('account_money_extend', schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint('idx_account_money_account_id_unique', ['account_id'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
if 'account_money_extend' in tables:
|
||||
with op.batch_alter_table('account_money_extend', schema=None) as batch_op:
|
||||
try:
|
||||
batch_op.drop_constraint('idx_account_money_account_id_unique', type_='unique')
|
||||
except Exception:
|
||||
# 如果约束不存在,忽略错误
|
||||
pass
|
||||
|
||||
# 重新创建普通索引
|
||||
batch_op.create_index('idx_account_money_account_id', ['account_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
@@ -26,6 +26,9 @@ from .dataset import (
|
||||
TidbAuthBinding,
|
||||
Whitelist,
|
||||
)
|
||||
|
||||
# extend: db
|
||||
from .engine import db
|
||||
from .enums import (
|
||||
AppTriggerStatus,
|
||||
AppTriggerType,
|
||||
@@ -174,7 +177,7 @@ __all__ = [
|
||||
"RecommendedApp",
|
||||
"SavedMessage",
|
||||
"Site",
|
||||
"SystemIntegrationExtend", # Extend System Integration
|
||||
"SystemIntegrationExtend", # Extend System Integration
|
||||
"Tag",
|
||||
"TagBinding",
|
||||
"Tenant",
|
||||
@@ -209,4 +212,5 @@ __all__ = [
|
||||
"WorkflowToolProvider",
|
||||
"WorkflowTriggerStatus",
|
||||
"WorkflowType",
|
||||
"db", # extend: db
|
||||
]
|
||||
|
||||
@@ -577,6 +577,7 @@ class RecommendedAppsCategoryJoinExtend(db.Model):
|
||||
recommended_id = db.Column(StringUUID, nullable=False)
|
||||
category_id = db.Column(StringUUID, nullable=False)
|
||||
|
||||
|
||||
class RecommendedApp(Base): # bug
|
||||
__tablename__ = "recommended_apps"
|
||||
__table_args__ = (
|
||||
|
||||
@@ -8,6 +8,7 @@ class EndUserAccountJoinsExtend(db.Model):
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="end_user_account_joins_pkey"),
|
||||
db.Index("end_user_account_joins_account_id_idx", "account_id"),
|
||||
db.Index("end_user_account_joins_end_user_id_idx", "end_user_id"), # 单独索引,用于计费查询优化
|
||||
db.Index("end_user_account_joins_end_user_id_app_id_idx", "end_user_id", "app_id"),
|
||||
)
|
||||
|
||||
@@ -17,3 +18,33 @@ class EndUserAccountJoinsExtend(db.Model):
|
||||
app_id = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
|
||||
|
||||
# Extend: 记忆上下文功能
|
||||
class AppExtend(db.Model):
|
||||
__tablename__ = "app_extend"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_extend_joins_pkey"),
|
||||
db.Index("app_extend_id_app_id_idx", "app_id"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
app_id = db.Column(StringUUID, nullable=False)
|
||||
retention_number = db.Column(db.Integer, nullable=True)
|
||||
# Extend: 记忆上下文功能
|
||||
|
||||
|
||||
# Extend: 消息上下文分割功能
|
||||
class MessageContextExtend(db.Model):
|
||||
__tablename__ = "message_context_extend"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="message_context_extend_joins_pkey"),
|
||||
db.Index("message_context_conversation_id_idx", "conversation_id"),
|
||||
db.Index("message_context_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
conversation_id = db.Column(db.String(36), nullable=True)
|
||||
message_id = db.Column(db.String(36), nullable=False)
|
||||
# Extend: 消息上下文分割功能
|
||||
|
||||
@@ -9,10 +9,10 @@ from .engine import db
|
||||
|
||||
|
||||
class SystemIntegrationClassify:
|
||||
SYSTEM_INTEGRATION_DINGTALK = 1 # 钉钉
|
||||
SYSTEM_INTEGRATION_WEIXIN = 2 # 微信
|
||||
SYSTEM_INTEGRATION_FEI_SU = 3 # 飞书
|
||||
SYSTEM_INTEGRATION_OAUTH_TWO = 4 # OAuth2
|
||||
SYSTEM_INTEGRATION_DINGTALK = 1 # 钉钉
|
||||
SYSTEM_INTEGRATION_WEIXIN = 2 # 微信
|
||||
SYSTEM_INTEGRATION_FEI_SU = 3 # 飞书
|
||||
SYSTEM_INTEGRATION_OAUTH_TWO = 4 # OAuth2
|
||||
|
||||
|
||||
class SystemIntegrationExtend(db.Model):
|
||||
|
||||
@@ -21,6 +21,7 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import App, AppMode, AppModelConfig, AppStatisticsExtend, RecommendedApp, Site
|
||||
from models.model_extend import AppExtend # Extend: 记忆上下文功能
|
||||
from models.tools import ApiToolProvider
|
||||
from services.billing_service import BillingService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@@ -259,6 +260,13 @@ class AppService:
|
||||
return model_config
|
||||
|
||||
app = ModifiedApp(app)
|
||||
# Extend: 记忆上下文功能 - Start
|
||||
app_extend: AppExtend = db.session.query(AppExtend).filter(AppExtend.app_id == app.id).first()
|
||||
if app_extend is not None:
|
||||
app.retention_number = app_extend.retention_number
|
||||
else:
|
||||
app.retention_number = dify_config.DEFAULT_NUMBER_CONTEXT
|
||||
# Extend: 记忆上下文功能 - Stop
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from services.account_service import AccountService, RegisterService, TenantServ
|
||||
from services.account_service_extend import TenantExtendService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DINGTALK_ACCOUNT_TOKEN = { "time": 0, "token": "" }
|
||||
DINGTALK_ACCOUNT_TOKEN = {"time": 0, "token": ""}
|
||||
|
||||
|
||||
class DingTalkService:
|
||||
@@ -97,7 +97,7 @@ class DingTalkService:
|
||||
dingTalkToken, err = cls.get_access_token()
|
||||
responses = requests.post(
|
||||
f'https://oapi.dingtalk.com/topapi/v2/user/get?access_token={dingTalkToken}',
|
||||
json={ "userid": userid },
|
||||
json={"userid": userid},
|
||||
)
|
||||
# Check the response status code
|
||||
if responses.status_code != 200:
|
||||
@@ -148,7 +148,7 @@ class DingTalkService:
|
||||
return "", f"Failed to obtain token: {err}"
|
||||
response = requests.get(
|
||||
"https://api.dingtalk.com/v1.0/contact/users/me",
|
||||
headers={ "x-acs-dingtalk-access-token": userToken },
|
||||
headers={"x-acs-dingtalk-access-token": userToken},
|
||||
)
|
||||
# Check the response status code
|
||||
if response.status_code != 200:
|
||||
@@ -161,7 +161,7 @@ class DingTalkService:
|
||||
dingTalkToken, err = cls.get_access_token()
|
||||
unionIdResponse = requests.post(
|
||||
f"https://oapi.dingtalk.com/topapi/user/getbyunionid?access_token={dingTalkToken}",
|
||||
json={ "unionid": req["unionId"] }
|
||||
json={"unionid": req["unionId"]}
|
||||
)
|
||||
# Check the response status code
|
||||
if unionIdResponse.status_code != 200:
|
||||
@@ -185,7 +185,7 @@ class DingTalkService:
|
||||
return "", f"Failed to obtain token: {err}"
|
||||
response = requests.post(
|
||||
f"{host}/getuserinfo?access_token={token}",
|
||||
json={ "code": code },
|
||||
json={"code": code},
|
||||
)
|
||||
# Check the response status code
|
||||
if response.status_code != 200:
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
import json
|
||||
|
||||
# extend start: oauth2 and DingTalk third-party login
|
||||
import re
|
||||
from enum import StrEnum
|
||||
|
||||
from flask import has_app_context, has_request_context, request
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.billing_service import BillingService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
# extend start: oauth2 and DingTalk third-party login
|
||||
import re
|
||||
import json
|
||||
from flask import request
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.system_extend import SystemIntegrationClassify, SystemIntegrationExtend
|
||||
from services.billing_service import BillingService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
# extend stop: oauth2 and DingTalk third-party login
|
||||
|
||||
|
||||
class SubscriptionModel(BaseModel):
|
||||
plan: str = CloudPlan.SANDBOX
|
||||
interval: str = ""
|
||||
@@ -180,9 +182,9 @@ class SystemFeatureModel(BaseModel):
|
||||
plugin_manager: PluginManagerModel = PluginManagerModel()
|
||||
is_custom_auth2: str = "" # extend: Customizing AUTH2
|
||||
is_custom_auth2_logout: str = "" # extend: Customizing AUTH2
|
||||
ding_talk_client_id: str = "" # extend: DingTalk third-party login
|
||||
ding_talk_corp_id: str = "" # extend: DingTalk sidebar login
|
||||
ding_talk: bool = "" # extend: DingTalk sidebar login
|
||||
ding_talk_client_id: str = "" # extend: DingTalk third-party login
|
||||
ding_talk_corp_id: str = "" # extend: DingTalk sidebar login
|
||||
ding_talk: bool = "" # extend: DingTalk sidebar login
|
||||
|
||||
|
||||
class FeatureService:
|
||||
@@ -216,9 +218,14 @@ class FeatureService:
|
||||
def get_system_features(cls) -> SystemFeatureModel:
|
||||
system_features = SystemFeatureModel()
|
||||
# extend start: oauth2
|
||||
api_host = request.host_url
|
||||
# 通过nginx代理转发会导致 request.host_url 获取的是内网ip,这个时候使用.env的配置
|
||||
if bool(re.search(r'^(?:[0-9]{1,3}\.){3}[0-9]{1,3}', request.host_url)):
|
||||
# 检查是否有请求上下文(在 Celery worker 中可能没有)
|
||||
if has_request_context():
|
||||
api_host = request.host_url
|
||||
# 通过nginx代理转发会导致 request.host_url 获取的是内网ip,这个时候使用.env的配置
|
||||
if bool(re.search(r'^(?:[0-9]{1,3}\.){3}[0-9]{1,3}', request.host_url)):
|
||||
api_host = dify_config.CONSOLE_WEB_URL
|
||||
else:
|
||||
# 没有请求上下文时(如 Celery worker),直接使用配置值
|
||||
api_host = dify_config.CONSOLE_WEB_URL
|
||||
redis_client.set("api_host", api_host)
|
||||
# extend stop: oauth2
|
||||
@@ -246,19 +253,21 @@ class FeatureService:
|
||||
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
|
||||
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
|
||||
# extend start: DingTalk third-party login
|
||||
for i in db.session.query(SystemIntegrationExtend).filter(SystemIntegrationExtend.status == True).all():
|
||||
if i.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_DINGTALK:
|
||||
system_features.ding_talk_client_id = i.app_key
|
||||
system_features.ding_talk_corp_id = i.corp_id
|
||||
system_features.ding_talk = i.status
|
||||
# Extend: OAuth2 Start
|
||||
elif i.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_OAUTH_TWO:
|
||||
config = json.loads(i.config)
|
||||
system_features.is_custom_auth2 = i.status
|
||||
if "logout_url" in config.keys():
|
||||
system_features.is_custom_auth2_logout = "{}{}".format(
|
||||
config['server_url'], config['logout_url'])
|
||||
# Extend: OAuth2 Stop
|
||||
# 检查是否有应用上下文(访问 db.session 需要应用上下文)
|
||||
if has_app_context():
|
||||
for i in db.session.query(SystemIntegrationExtend).filter(SystemIntegrationExtend.status == True).all():
|
||||
if i.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_DINGTALK:
|
||||
system_features.ding_talk_client_id = i.app_key
|
||||
system_features.ding_talk_corp_id = i.corp_id
|
||||
system_features.ding_talk = i.status
|
||||
# Extend: OAuth2 Start
|
||||
elif i.classify == SystemIntegrationClassify.SYSTEM_INTEGRATION_OAUTH_TWO:
|
||||
config = json.loads(i.config)
|
||||
system_features.is_custom_auth2 = i.status
|
||||
if "logout_url" in config.keys():
|
||||
system_features.is_custom_auth2_logout = "{}{}".format(
|
||||
config['server_url'], config['logout_url'])
|
||||
# Extend: OAuth2 Stop
|
||||
# extend stop: DingTalk third-party login
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -8,7 +8,7 @@ from services.model_provider_service_extend import ModelProviderExtendService
|
||||
class ModelExtendService:
|
||||
@staticmethod
|
||||
def sync_set_all_model_to_tenant(tenant_id: str) -> bool:
|
||||
logging.info(f"开始同步所有模型到工作区: {tenant_id}")
|
||||
logging.info("开始同步所有模型到工作区: %s", tenant_id)
|
||||
model_provider_service_extend = ModelProviderExtendService()
|
||||
# 同步供应商+模型名称的模型数据
|
||||
provider_model_records = TenantExtendService.get_sync_all_model()
|
||||
|
||||
@@ -2,6 +2,7 @@ from sqlalchemy import select
|
||||
|
||||
from constants.languages import languages
|
||||
from extensions.ext_database import db
|
||||
|
||||
# extend add category to categories
|
||||
from models.model import App, RecommendedApp, RecommendedAppsCategoryJoinExtend, RecommendedCategoryExtend
|
||||
from services.app_dsl_service import AppDslService
|
||||
@@ -67,22 +68,34 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
if not site:
|
||||
continue
|
||||
|
||||
recommended_app_result = {
|
||||
"id": recommended_app.id,
|
||||
"app": recommended_app.app,
|
||||
"app_id": recommended_app.app_id,
|
||||
"description": site.description,
|
||||
"copyright": site.copyright,
|
||||
"privacy_policy": site.privacy_policy,
|
||||
"custom_disclaimer": site.custom_disclaimer,
|
||||
"category": recommended_app.category,
|
||||
"position": recommended_app.position,
|
||||
"is_listed": recommended_app.is_listed,
|
||||
}
|
||||
recommended_apps_result.append(recommended_app_result)
|
||||
|
||||
categories.add(recommended_app.category)
|
||||
config = app.app_model_config
|
||||
if config is not None and config.pre_prompt is not None and len(config.pre_prompt) > 0:
|
||||
description = config.pre_prompt
|
||||
if recommended_app.id in recommended:
|
||||
classList = recommended[recommended_app.id]
|
||||
if len(classList) == 0:
|
||||
classList.append("")
|
||||
for classId in classList:
|
||||
category = "未分类"
|
||||
if classId in class_dick:
|
||||
category = class_dick[classId]
|
||||
recommended_app_result = {
|
||||
"id": recommended_app.id,
|
||||
"app": recommended_app.app,
|
||||
"app_id": recommended_app.app_id,
|
||||
"description": description,
|
||||
"copyright": site.copyright,
|
||||
"privacy_policy": site.privacy_policy,
|
||||
"custom_disclaimer": site.custom_disclaimer,
|
||||
"category": category,
|
||||
"position": recommended_app.position,
|
||||
"is_listed": recommended_app.is_listed,
|
||||
}
|
||||
recommended_apps_result.append(recommended_app_result)
|
||||
|
||||
categories.add(category) # add category to categories
|
||||
categories = sorted(categories)
|
||||
categories.append("未分类")
|
||||
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
|
||||
# -------------- extend stop: add category to categories ---------------
|
||||
|
||||
|
||||
@@ -186,3 +186,26 @@ class RecommendedAppService:
|
||||
return recommendedApp.id
|
||||
except:
|
||||
return ""
|
||||
|
||||
# Extend: start messages context handling
|
||||
@classmethod
|
||||
def message_context(cls, conversation_id: str):
|
||||
from models.model_extend import MessageContextExtend
|
||||
message_list = []
|
||||
message_context = db.session.query(MessageContextExtend).filter(
|
||||
MessageContextExtend.conversation_id == conversation_id).order_by(
|
||||
MessageContextExtend.created_at.desc()).all()
|
||||
for v in message_context:
|
||||
message_list.append(v.message_id)
|
||||
return message_list
|
||||
|
||||
@classmethod
|
||||
def delete_message_context(cls, conversation_id, message_id: str):
|
||||
from models.model_extend import MessageContextExtend
|
||||
db.session.query(MessageContextExtend).filter(
|
||||
MessageContextExtend.conversation_id == conversation_id,
|
||||
MessageContextExtend.message_id == message_id,
|
||||
).delete()
|
||||
db.session.commit()
|
||||
return 'ok'
|
||||
# Extend: stop messages context handling
|
||||
|
||||
@@ -4,6 +4,7 @@ from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from services.account_service import TenantService
|
||||
|
||||
# extend: 添加用户权限
|
||||
from services.account_service_extend import TenantExtendService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
+111
-45
@@ -1,26 +1,97 @@
|
||||
import json
|
||||
import logging
|
||||
from decimal import Decimal
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account
|
||||
from models.account_money_extend import AccountMoneyExtend
|
||||
from models.api_token_money_extend import ApiTokenMessageJoinsExtend, ApiTokenMoneyExtend
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model_extend import EndUserAccountJoinsExtend
|
||||
|
||||
# 缓存键前缀和过期时间
|
||||
PAYER_ID_CACHE_PREFIX = "billing:payer_id:"
|
||||
PAYER_ID_CACHE_TTL = 3600 # 1小时缓存
|
||||
|
||||
|
||||
def _get_payer_id_from_cache(end_user_id: str) -> Optional[str]:
|
||||
"""从Redis缓存获取付费人ID"""
|
||||
try:
|
||||
cache_key = f"{PAYER_ID_CACHE_PREFIX}{end_user_id}"
|
||||
cached_value = redis_client.get(cache_key)
|
||||
if cached_value:
|
||||
return cached_value.decode('utf-8') if isinstance(cached_value, bytes) else cached_value
|
||||
except Exception as e:
|
||||
logging.debug("缓存读取失败: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _set_payer_id_to_cache(end_user_id: str, payer_id: str) -> None:
|
||||
"""将付费人ID写入Redis缓存"""
|
||||
try:
|
||||
cache_key = f"{PAYER_ID_CACHE_PREFIX}{end_user_id}"
|
||||
redis_client.setex(cache_key, PAYER_ID_CACHE_TTL, payer_id)
|
||||
except Exception as e:
|
||||
logging.debug("缓存写入失败: %s", e)
|
||||
|
||||
|
||||
def _resolve_payer_id(created_by: str, created_by_role: Optional[str]) -> str:
|
||||
"""
|
||||
解析实际付费人ID
|
||||
使用缓存+高效查询优化性能
|
||||
"""
|
||||
payer_id = created_by
|
||||
|
||||
if created_by_role != CreatorUserRole.END_USER.value:
|
||||
return payer_id
|
||||
|
||||
# 先检查缓存
|
||||
cached_payer_id = _get_payer_id_from_cache(created_by)
|
||||
if cached_payer_id:
|
||||
return cached_payer_id
|
||||
|
||||
# 使用 EXISTS 子查询检查是否是真实账户,比 SELECT 更高效
|
||||
is_account = db.session.query(
|
||||
exists().where(Account.id == created_by)
|
||||
).scalar()
|
||||
|
||||
if is_account:
|
||||
# 是真实账户,缓存并返回
|
||||
_set_payer_id_to_cache(created_by, created_by)
|
||||
return created_by
|
||||
|
||||
# 查询关联表获取真正的付费账户
|
||||
# 只选择需要的字段,使用索引优化查询
|
||||
end_user_account = (
|
||||
db.session.query(EndUserAccountJoinsExtend.account_id)
|
||||
.filter(EndUserAccountJoinsExtend.end_user_id == created_by)
|
||||
.order_by(EndUserAccountJoinsExtend.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if end_user_account:
|
||||
payer_id = str(end_user_account.account_id)
|
||||
|
||||
# 缓存结果
|
||||
_set_payer_id_to_cache(created_by, payer_id)
|
||||
return payer_id
|
||||
|
||||
|
||||
@shared_task(queue="extend_high", bind=True, max_retries=3)
|
||||
def update_account_money_when_workflow_node_execution_created_extend(
|
||||
self, workflow_node_execution_dict: dict):
|
||||
"""
|
||||
计算工作流节点执行的费用并更新账户额度
|
||||
优化版本:使用缓存减少数据库查询,使用原子更新避免并发问题
|
||||
:param workflow_node_execution_dict: 工作流节点执行字典
|
||||
"""
|
||||
|
||||
@@ -35,7 +106,7 @@ def update_account_money_when_workflow_node_execution_created_extend(
|
||||
node_id = workflow_node_execution_dict.get("id")
|
||||
logging.info(click.style("工作流节点ID: {}".format(node_id), fg="cyan"))
|
||||
|
||||
# 拿到费用 - 从 outputs 字段获取费用信息(参考原始代码)
|
||||
# 拿到费用 - 从 outputs 字段获取费用信息
|
||||
outputs = workflow_node_execution_dict.get("outputs", {})
|
||||
|
||||
# 如果 outputs 是字符串,则解析 JSON;如果已经是字典,则直接使用
|
||||
@@ -55,36 +126,26 @@ def update_account_money_when_workflow_node_execution_created_extend(
|
||||
logging.info(click.style("扣除费用: {}".format(price), fg="green"))
|
||||
|
||||
try:
|
||||
# 当前是end_user,节点账号id
|
||||
# 分两种情况
|
||||
# web应用的请求,created_by记录的是登录账号的ID,可以拿这个ID来扣钱
|
||||
# API调用,created_by记录的是节点登录账号ID,真正需要扣钱的在关联表EndUserAccountJoinsExtend,需要多做一层查询
|
||||
created_by = workflow_node_execution_dict.get("created_by")
|
||||
created_by_role = workflow_node_execution_dict.get("created_by_role")
|
||||
payerId = created_by # 付钱的ID
|
||||
if created_by_role == CreatorUserRole.END_USER.value:
|
||||
account = db.session.query(Account).filter(Account.id == created_by).first()
|
||||
if not account:
|
||||
end_user_account_joins = (
|
||||
db.session.query(EndUserAccountJoinsExtend)
|
||||
.filter(EndUserAccountJoinsExtend.end_user_id == created_by)
|
||||
.order_by(EndUserAccountJoinsExtend.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if end_user_account_joins:
|
||||
payerId = end_user_account_joins.account_id
|
||||
|
||||
# 使用优化后的方法获取付费人ID
|
||||
payer_id = _resolve_payer_id(created_by, created_by_role)
|
||||
logging.info(click.style("更新账号额度,账号ID: {}".format(payer_id), fg="green"))
|
||||
|
||||
account_money = db.session.query(AccountMoneyExtend).filter(
|
||||
AccountMoneyExtend.account_id == payerId).first()
|
||||
logging.info(click.style("更新账号额度,账号ID: {}".format(payerId), fg="green"))
|
||||
if account_money:
|
||||
db.session.query(AccountMoneyExtend).filter(AccountMoneyExtend.account_id == payerId).update(
|
||||
{
|
||||
"used_quota": float(account_money.used_quota) + price}
|
||||
)
|
||||
else:
|
||||
# 使用原子更新,避免并发问题,并减少一次查询
|
||||
# UPDATE ... SET used_quota = used_quota + price WHERE account_id = ?
|
||||
rows_updated = db.session.query(AccountMoneyExtend).filter(
|
||||
AccountMoneyExtend.account_id == payer_id
|
||||
).update(
|
||||
{"used_quota": AccountMoneyExtend.used_quota + price},
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
if rows_updated == 0:
|
||||
# 记录不存在,创建新记录
|
||||
account_money_add = AccountMoneyExtend(
|
||||
account_id=payerId,
|
||||
account_id=payer_id,
|
||||
used_quota=price,
|
||||
total_quota=dify_config.ACCOUNT_TOTAL_QUOTA,
|
||||
)
|
||||
@@ -92,33 +153,38 @@ def update_account_money_when_workflow_node_execution_created_extend(
|
||||
|
||||
# 扣掉密钥的钱
|
||||
workflow_run_id = workflow_node_execution_dict.get("workflow_run_id")
|
||||
api_token_message = (
|
||||
db.session.query(ApiTokenMessageJoinsExtend)
|
||||
.filter(ApiTokenMessageJoinsExtend.record_id == workflow_run_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if api_token_message:
|
||||
logging.info(click.style("更新密钥额度,密钥ID: {}".format(
|
||||
api_token_message.app_token_id), fg="green"))
|
||||
db.session.query(ApiTokenMoneyExtend).filter(
|
||||
ApiTokenMoneyExtend.app_token_id == api_token_message.app_token_id
|
||||
).update(
|
||||
{
|
||||
"accumulated_quota": ApiTokenMoneyExtend.accumulated_quota + price,
|
||||
"day_used_quota": ApiTokenMoneyExtend.day_used_quota + price,
|
||||
"month_used_quota": ApiTokenMoneyExtend.month_used_quota + price,
|
||||
},
|
||||
if workflow_run_id:
|
||||
# 只查询需要的字段
|
||||
api_token_id_result = (
|
||||
db.session.query(ApiTokenMessageJoinsExtend.app_token_id)
|
||||
.filter(ApiTokenMessageJoinsExtend.record_id == workflow_run_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if api_token_id_result and api_token_id_result.app_token_id:
|
||||
app_token_id = api_token_id_result.app_token_id
|
||||
logging.info(click.style("更新密钥额度,密钥ID: {}".format(app_token_id), fg="green"))
|
||||
db.session.query(ApiTokenMoneyExtend).filter(
|
||||
ApiTokenMoneyExtend.app_token_id == app_token_id
|
||||
).update(
|
||||
{
|
||||
"accumulated_quota": ApiTokenMoneyExtend.accumulated_quota + price,
|
||||
"day_used_quota": ApiTokenMoneyExtend.day_used_quota + price,
|
||||
"month_used_quota": ApiTokenMoneyExtend.month_used_quota + price,
|
||||
},
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
except SQLAlchemyError as e:
|
||||
db.session.rollback()
|
||||
logging.exception(
|
||||
click.style(f"工作流节点ID: {format(node_id)},扣除费用:"
|
||||
f"{format(price)} 数据库异常,60秒后进行重试,", fg="red")
|
||||
)
|
||||
raise self.retry(exc=e, countdown=60) # Retry after 60 seconds
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logging.exception(
|
||||
click.style(f"工作流节点ID: {format(node_id)},扣除费用:"
|
||||
f"{format(price)} 异常报错,60秒后进行重试,", fg="red")
|
||||
|
||||
+2
-1
@@ -1,10 +1,11 @@
|
||||
import time
|
||||
from decimal import Decimal
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
|
||||
@@ -3,11 +3,11 @@ import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
|
||||
|
||||
_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user