fix: 上下文,聊天计费,额度

This commit is contained in:
npc0-hue
2026-01-21 18:10:18 +08:00
parent 0cf63b0f08
commit 9ed0d7c891
111 changed files with 1954 additions and 729 deletions
@@ -11,7 +11,7 @@ from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import annotation_fields, build_annotation_model
from models.model import App
from models.model import ApiToken, App # extend - 密钥额度限制,新增ApiToken
from services.annotation_service import AppAnnotationService
@@ -111,7 +111,7 @@ class AnnotationListApi(Resource):
)
@validate_app_token
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""List annotations for the application."""
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
+4 -4
View File
@@ -5,7 +5,7 @@ from controllers.service_api import service_api_ns
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode
from models.model import ApiToken, App, AppMode
from services.app_service import AppService
@@ -23,7 +23,7 @@ class AppParameterApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Retrieve app parameters.
Returns the input form parameters and configuration for the application.
@@ -60,7 +60,7 @@ class AppMetaApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Get app metadata.
Returns metadata about the application including configuration and settings.
@@ -80,7 +80,7 @@ class AppInfoApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Get app information.
Returns basic information about the application including name, description, tags, and mode.
+1 -1
View File
@@ -22,7 +22,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import ApiToken, App, EndUser # extend: 二开部分 密钥额度限制,新增api_token
from models.model import ApiToken, App, EndUser # extend: 二开部分 密钥额度限制,新增api_token
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -30,9 +30,9 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken
from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken
from services.app_generate_service import AppGenerateService
from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken
from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken
from services.app_task_service import AppTaskService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
from services.errors.llm import InvokeRateLimitError
+2 -2
View File
@@ -15,7 +15,7 @@ from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_database import db
from fields.file_fields import FileResponse
from models import ApiToken, App, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错
from models import ApiToken, App, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错
from services.file_service import FileService
register_schema_models(service_api_ns, FileResponse)
@@ -36,7 +36,7 @@ class FileApi(Resource):
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore
@service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__])
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
"""Upload a file for use in conversations.
Accepts a single file upload via multipart/form-data.
+5 -5
View File
@@ -15,7 +15,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken
from services.errors.message import (
FirstMessageNotExistsError,
MessageNotExistsError,
@@ -58,7 +58,7 @@ class MessageListApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
def get(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
"""List messages in a conversation.
Retrieves messages with pagination support using first_id.
@@ -102,7 +102,7 @@ class MessageFeedbackApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
def post(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
"""Submit feedback for a message.
Allows users to rate messages as like/dislike and provide optional feedback content.
@@ -137,7 +137,7 @@ class AppGetFeedbacksApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Get all feedbacks for the application.
Returns paginated list of all feedback submitted for messages in this app.
@@ -162,7 +162,7 @@ class MessageSuggestedApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
def get(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
"""Get suggested follow-up questions for a message.
Returns AI-generated follow-up questions based on the message content.
+2 -2
View File
@@ -6,7 +6,7 @@ from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.account import TenantStatus
from models.model import App, Site
from models.model import ApiToken, App, Site # extend - 密钥额度限制,新增ApiToken
@service_api_ns.route("/site")
@@ -23,7 +23,7 @@ class AppSiteApi(Resource):
}
)
@validate_app_token
def get(self, app_model: App):
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
"""Retrieve app site info.
Returns the site configuration for the application including theme, icons, and text.
+8 -4
View File
@@ -34,7 +34,7 @@ from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
from libs.helper import TimestampField
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,ApiToken
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,ApiToken
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
@@ -97,7 +97,7 @@ class WorkflowRunDetailApi(Resource):
)
@validate_app_token
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
def get(self, app_model: App, workflow_run_id: str):
def get(self, app_model: App, workflow_run_id: str, api_token: ApiToken): # extend - 密钥额度限制,新增api_token参数
"""Get a workflow task running detail.
Returns detailed information about a specific workflow run.
@@ -134,7 +134,7 @@ class WorkflowRunApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken, workflow_id: str): # extend: 密钥额度限制
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend: 密钥额度限制
"""Execute a workflow.
Runs a workflow with the provided inputs and returns the results.
@@ -195,7 +195,7 @@ class WorkflowRunByIdApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken, workflow_id: str): # extend: 密钥额度限制
"""Run specific workflow by ID.
Executes a specific workflow version identified by its ID.
@@ -215,6 +215,10 @@ class WorkflowRunByIdApi(Resource):
args["external_trace_id"] = external_trace_id
streaming = payload.response_mode == "streaming"
# ------------------- 二开部分Begin - 密钥额度限制 -------------------
args["api_token"] = api_token
# ------------------- 二开部分End - 密钥额度限制 -------------------
try:
response = AppGenerateService.generate(
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
+33 -22
View File
@@ -14,25 +14,26 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
# extend: start 额度限制,API调用计费,新增TenantAccountRole
from controllers.service_api.app.error_extend import (
AccountNoMoneyErrorExtend,
ApiTokenDayNoMoneyErrorExtend,
ApiTokenMonthNoMoneyErrorExtend,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
from models import Account, Tenant, TenantAccountJoin, TenantStatus
from models.account_money_extend import AccountMoneyExtend
from models.api_token_money_extend import ApiTokenMoneyExtend
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App
from models.model_extend import EndUserAccountJoinsExtend
from services.end_user_service import EndUserService
from services.feature_service import FeatureService
# extend: stop 额度限制,API调用计费,新增TenantAccountRole
@@ -80,17 +81,27 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.")
# ---------------------extend: 二开部分Begin 额度限制,API调用计费 ---------------------
# TODO 需要写入缓存,读缓存
account_money = (
db.session.query(AccountMoneyExtend)
.filter(AccountMoneyExtend.account_id == ta.account_id)
.first()
)
if account_money and account_money.used_quota >= account_money.total_quota:
raise AccountNoMoneyErrorExtend()
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
# TODO 需要写入缓存,读缓存
account_money = (
db.session.query(AccountMoneyExtend)
.filter(AccountMoneyExtend.account_id == ta.account_id)
.first()
)
if account_money and account_money.used_quota >= account_money.total_quota:
raise AccountNoMoneyErrorExtend()
else:
raise Unauthorized("Tenant does not exist.")
# 密钥额度判断
kwargs["api_token"] = api_token # API token消息数据传递下去
api_token_money = (
@@ -382,6 +393,7 @@ def validate_and_get_api_token(scope: str | None = None):
return api_token
# ---------------------二开部分Begin 额度限制,API调用计费 ---------------------
def create_or_update_end_user_account_join_extend(end_user_id, account_id, app_id: str) -> EndUserAccountJoinsExtend:
# 插入节点账号id和用户账号id关联关系,以方便扣钱查询
@@ -402,7 +414,6 @@ def create_or_update_end_user_account_join_extend(end_user_id, account_id, app_i
# ---------------------二开部分End 额度限制,API调用计费 ---------------------
class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]