mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-12 18:11:42 +08:00
fix: 上下文,聊天计费,额度
This commit is contained in:
@@ -11,7 +11,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from models.model import App
|
||||
from models.model import ApiToken, App # extend - 密钥额度限制,新增ApiToken
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ class AnnotationListApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""List annotations for the application."""
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
|
||||
@@ -5,7 +5,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from models.model import App, AppMode
|
||||
from models.model import ApiToken, App, AppMode
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class AppParameterApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Retrieve app parameters.
|
||||
|
||||
Returns the input form parameters and configuration for the application.
|
||||
@@ -60,7 +60,7 @@ class AppMetaApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Get app metadata.
|
||||
|
||||
Returns metadata about the application including configuration and settings.
|
||||
@@ -80,7 +80,7 @@ class AppInfoApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Get app information.
|
||||
|
||||
Returns basic information about the application including name, description, tags, and mode.
|
||||
|
||||
@@ -22,7 +22,7 @@ from controllers.service_api.app.error import (
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import ApiToken, App, EndUser # extend: 二开部分 密钥额度限制,新增api_token
|
||||
from models.model import ApiToken, App, EndUser # extend: 二开部分 密钥额度限制,新增api_token
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
|
||||
@@ -30,9 +30,9 @@ from core.errors.error import (
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend: 密钥额度限制,新增ApiToken
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken
|
||||
from services.app_generate_service_extend import AppGenerateServiceExtend # extend: 密钥额度限制,新增ApiToken
|
||||
from services.app_task_service import AppTaskService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
@@ -15,7 +15,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse
|
||||
from models import ApiToken, App, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
from models import ApiToken, App, EndUser # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
from services.file_service import FileService
|
||||
|
||||
register_schema_models(service_api_ns, FileResponse)
|
||||
@@ -36,7 +36,7 @@ class FileApi(Resource):
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) # type: ignore
|
||||
@service_api_ns.response(HTTPStatus.CREATED, "File uploaded", service_api_ns.models[FileResponse.__name__])
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
"""Upload a file for use in conversations.
|
||||
|
||||
Accepts a single file upload via multipart/form-data.
|
||||
|
||||
@@ -15,7 +15,7 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,新增ApiToken
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
MessageNotExistsError,
|
||||
@@ -58,7 +58,7 @@ class MessageListApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
def get(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
"""List messages in a conversation.
|
||||
|
||||
Retrieves messages with pagination support using first_id.
|
||||
@@ -102,7 +102,7 @@ class MessageFeedbackApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
def post(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
"""Submit feedback for a message.
|
||||
|
||||
Allows users to rate messages as like/dislike and provide optional feedback content.
|
||||
@@ -137,7 +137,7 @@ class AppGetFeedbacksApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Get all feedbacks for the application.
|
||||
|
||||
Returns paginated list of all feedback submitted for messages in this app.
|
||||
@@ -162,7 +162,7 @@ class MessageSuggestedApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
|
||||
def get(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
def get(self, app_model: App, end_user: EndUser, message_id, api_token: ApiToken): # extend - 密钥额度限制,新增api_token,否则上传文件会报错
|
||||
"""Get suggested follow-up questions for a message.
|
||||
|
||||
Returns AI-generated follow-up questions based on the message content.
|
||||
|
||||
@@ -6,7 +6,7 @@ from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from models.model import ApiToken, App, Site # extend - 密钥额度限制,新增ApiToken
|
||||
|
||||
|
||||
@service_api_ns.route("/site")
|
||||
@@ -23,7 +23,7 @@ class AppSiteApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, api_token: ApiToken): # extend - 密钥额度限制,新增api_token
|
||||
"""Retrieve app site info.
|
||||
|
||||
Returns the site configuration for the application including theme, icons, and text.
|
||||
|
||||
@@ -34,7 +34,7 @@ from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,ApiToken
|
||||
from models.model import ApiToken, App, AppMode, EndUser # extend - 密钥额度限制,ApiToken
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
@@ -97,7 +97,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@service_api_ns.marshal_with(build_workflow_run_model(service_api_ns))
|
||||
def get(self, app_model: App, workflow_run_id: str):
|
||||
def get(self, app_model: App, workflow_run_id: str, api_token: ApiToken): # extend - 密钥额度限制,新增api_token参数
|
||||
"""Get a workflow task running detail.
|
||||
|
||||
Returns detailed information about a specific workflow run.
|
||||
@@ -134,7 +134,7 @@ class WorkflowRunApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken, workflow_id: str): # extend: 密钥额度限制
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken): # extend: 密钥额度限制
|
||||
"""Execute a workflow.
|
||||
|
||||
Runs a workflow with the provided inputs and returns the results.
|
||||
@@ -195,7 +195,7 @@ class WorkflowRunByIdApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
|
||||
def post(self, app_model: App, end_user: EndUser, workflow_id: str):
|
||||
def post(self, app_model: App, end_user: EndUser, api_token: ApiToken, workflow_id: str): # extend: 密钥额度限制
|
||||
"""Run specific workflow by ID.
|
||||
|
||||
Executes a specific workflow version identified by its ID.
|
||||
@@ -215,6 +215,10 @@ class WorkflowRunByIdApi(Resource):
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
# ------------------- 二开部分Begin - 密钥额度限制 -------------------
|
||||
args["api_token"] = api_token
|
||||
# ------------------- 二开部分End - 密钥额度限制 -------------------
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming
|
||||
|
||||
@@ -14,25 +14,26 @@ from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
# extend: start 额度限制,API调用计费,新增TenantAccountRole
|
||||
from controllers.service_api.app.error_extend import (
|
||||
AccountNoMoneyErrorExtend,
|
||||
ApiTokenDayNoMoneyErrorExtend,
|
||||
ApiTokenMonthNoMoneyErrorExtend,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.account_money_extend import AccountMoneyExtend
|
||||
from models.api_token_money_extend import ApiTokenMoneyExtend
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App
|
||||
from models.model_extend import EndUserAccountJoinsExtend
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
# extend: stop 额度限制,API调用计费,新增TenantAccountRole
|
||||
|
||||
|
||||
@@ -80,17 +81,27 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
|
||||
if tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden("The workspace's status is archived.")
|
||||
|
||||
|
||||
# ---------------------extend: 二开部分Begin 额度限制,API调用计费 ---------------------
|
||||
# TODO 需要写入缓存,读缓存
|
||||
account_money = (
|
||||
db.session.query(AccountMoneyExtend)
|
||||
.filter(AccountMoneyExtend.account_id == ta.account_id)
|
||||
.first()
|
||||
)
|
||||
if account_money and account_money.used_quota >= account_money.total_quota:
|
||||
raise AccountNoMoneyErrorExtend()
|
||||
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.filter(Tenant.id == api_token.tenant_id)
|
||||
.filter(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(TenantAccountJoin.role.in_(["owner"]))
|
||||
.filter(Tenant.status == TenantStatus.NORMAL)
|
||||
.one_or_none()
|
||||
) # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
# TODO 需要写入缓存,读缓存
|
||||
account_money = (
|
||||
db.session.query(AccountMoneyExtend)
|
||||
.filter(AccountMoneyExtend.account_id == ta.account_id)
|
||||
.first()
|
||||
)
|
||||
if account_money and account_money.used_quota >= account_money.total_quota:
|
||||
raise AccountNoMoneyErrorExtend()
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
# 密钥额度判断
|
||||
kwargs["api_token"] = api_token # API token消息数据传递下去
|
||||
api_token_money = (
|
||||
@@ -382,6 +393,7 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
|
||||
return api_token
|
||||
|
||||
|
||||
# ---------------------二开部分Begin 额度限制,API调用计费 ---------------------
|
||||
def create_or_update_end_user_account_join_extend(end_user_id, account_id, app_id: str) -> EndUserAccountJoinsExtend:
|
||||
# 插入节点账号id和用户账号id关联关系,以方便扣钱查询
|
||||
@@ -402,7 +414,6 @@ def create_or_update_end_user_account_join_extend(end_user_id, account_id, app_i
|
||||
# ---------------------二开部分End 额度限制,API调用计费 ---------------------
|
||||
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
method_decorators = [validate_dataset_token]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user