orm filter -> where (#22801)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Asuka Minato
2025-07-24 01:57:45 +09:00
committed by GitHub
parent e64e7563f6
commit ef51678c73
161 changed files with 828 additions and 857 deletions
+11 -13
View File
@@ -643,7 +643,7 @@ class AccountService:
)
)
account = db.session.query(Account).filter(Account.email == email).first()
account = db.session.query(Account).where(Account.email == email).first()
if not account:
return None
@@ -900,7 +900,7 @@ class TenantService:
return (
db.session.query(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
.filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
.where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
.all()
)
@@ -929,7 +929,7 @@ class TenantService:
tenant_account_join = (
db.session.query(TenantAccountJoin)
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
.filter(
.where(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL,
@@ -940,7 +940,7 @@ class TenantService:
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
db.session.query(TenantAccountJoin).filter(
db.session.query(TenantAccountJoin).where(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
).update({"current": False})
tenant_account_join.current = True
@@ -955,7 +955,7 @@ class TenantService:
db.session.query(Account, TenantAccountJoin.role)
.select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(TenantAccountJoin.tenant_id == tenant.id)
.where(TenantAccountJoin.tenant_id == tenant.id)
)
# Initialize an empty list to store the updated accounts
@@ -974,8 +974,8 @@ class TenantService:
db.session.query(Account, TenantAccountJoin.role)
.select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(TenantAccountJoin.tenant_id == tenant.id)
.filter(TenantAccountJoin.role == "dataset_operator")
.where(TenantAccountJoin.tenant_id == tenant.id)
.where(TenantAccountJoin.role == "dataset_operator")
)
# Initialize an empty list to store the updated accounts
@@ -995,9 +995,7 @@ class TenantService:
return (
db.session.query(TenantAccountJoin)
.filter(
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]))
.first()
is not None
)
@@ -1007,7 +1005,7 @@ class TenantService:
"""Get the role of the current account for a given tenant"""
join = (
db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.first()
)
return TenantAccountRole(join.role) if join else None
@@ -1274,7 +1272,7 @@ class RegisterService:
tenant = (
db.session.query(Tenant)
.filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.first()
)
@@ -1284,7 +1282,7 @@ class RegisterService:
tenant_account = (
db.session.query(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
.filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.first()
)
+4 -6
View File
@@ -25,7 +25,7 @@ class AgentService:
conversation: Conversation | None = (
db.session.query(Conversation)
.filter(
.where(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
)
@@ -37,7 +37,7 @@ class AgentService:
message: Optional[Message] = (
db.session.query(Message)
.filter(
.where(
Message.id == message_id,
Message.conversation_id == conversation_id,
)
@@ -52,12 +52,10 @@ class AgentService:
if conversation.from_end_user_id:
# only select name field
executor = (
db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first()
db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first()
)
else:
executor = (
db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first()
)
executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first()
if executor:
executor = executor.name
+28 -34
View File
@@ -26,7 +26,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -35,7 +35,7 @@ class AppAnnotationService:
if args.get("message_id"):
message_id = str(args["message_id"])
# get message info
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first()
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first()
if not message:
raise NotFound("Message Not Exists.")
@@ -61,9 +61,7 @@ class AppAnnotationService:
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@@ -117,7 +115,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -126,8 +124,8 @@ class AppAnnotationService:
if keyword:
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter(
.where(MessageAnnotation.app_id == app_id)
.where(
or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)),
MessageAnnotation.content.ilike("%{}%".format(keyword)),
@@ -138,7 +136,7 @@ class AppAnnotationService:
else:
stmt = (
select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
)
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
@@ -149,7 +147,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -157,7 +155,7 @@ class AppAnnotationService:
raise NotFound("App not found")
annotations = (
db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc())
.all()
)
@@ -168,7 +166,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -181,9 +179,7 @@ class AppAnnotationService:
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@@ -199,14 +195,14 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
@@ -217,7 +213,7 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , add annotation to index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
@@ -236,14 +232,14 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
@@ -252,7 +248,7 @@ class AppAnnotationService:
annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
.filter(AppAnnotationHitHistory.annotation_id == annotation_id)
.where(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
if annotation_hit_histories:
@@ -262,7 +258,7 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , delete annotation index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
@@ -275,7 +271,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -314,21 +310,21 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
stmt = (
select(AppAnnotationHitHistory)
.filter(
.where(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
@@ -341,7 +337,7 @@ class AppAnnotationService:
@classmethod
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
return None
@@ -361,7 +357,7 @@ class AppAnnotationService:
score: float,
):
# add hit count to annotation
db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update(
db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update(
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
)
@@ -384,16 +380,14 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
)
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
@@ -412,7 +406,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -421,7 +415,7 @@ class AppAnnotationService:
annotation_setting = (
db.session.query(AppAnnotationSetting)
.filter(
.where(
AppAnnotationSetting.app_id == app_id,
AppAnnotationSetting.id == annotation_setting_id,
)
+1 -1
View File
@@ -73,7 +73,7 @@ class APIBasedExtensionService:
db.session.query(APIBasedExtension)
.filter_by(tenant_id=extension_data.tenant_id)
.filter_by(name=extension_data.name)
.filter(APIBasedExtension.id != extension_data.id)
.where(APIBasedExtension.id != extension_data.id)
.first()
)
+3 -3
View File
@@ -382,7 +382,7 @@ class AppService:
elif provider_type == "api":
try:
provider: Optional[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
)
if provider is None:
raise ValueError(f"provider not found for tool {tool_name}")
@@ -399,7 +399,7 @@ class AppService:
:param app_id: app id
:return: app code
"""
site = db.session.query(Site).filter(Site.app_id == app_id).first()
site = db.session.query(Site).where(Site.app_id == app_id).first()
if not site:
raise ValueError(f"App with id {app_id} not found")
return str(site.code)
@@ -411,7 +411,7 @@ class AppService:
:param app_code: app code
:return: app id
"""
site = db.session.query(Site).filter(Site.code == app_code).first()
site = db.session.query(Site).where(Site.code == app_code).first()
if not site:
raise ValueError(f"App with code {app_code} not found")
return str(site.app_id)
+1 -1
View File
@@ -135,7 +135,7 @@ class AudioService:
uuid.UUID(message_id)
except ValueError:
return None
message = db.session.query(Message).filter(Message.id == message_id).first()
message = db.session.query(Message).where(Message.id == message_id).first()
if message is None:
return None
if message.answer == "" and message.status == MessageStatus.NORMAL:
+3 -3
View File
@@ -11,7 +11,7 @@ class ApiKeyAuthService:
def get_provider_auth_list(tenant_id: str) -> list:
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all()
)
return data_source_api_key_bindings
@@ -36,7 +36,7 @@ class ApiKeyAuthService:
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.filter(
.where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
@@ -53,7 +53,7 @@ class ApiKeyAuthService:
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding)
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
.first()
)
if data_source_api_key_binding:
+1 -1
View File
@@ -75,7 +75,7 @@ class BillingService:
join: Optional[TenantAccountJoin] = (
db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()
)
@@ -24,13 +24,13 @@ class ClearFreePlanTenantExpiredLogs:
@classmethod
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
with flask_app.app_context():
apps = db.session.query(App).filter(App.tenant_id == tenant_id).all()
apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
app_ids = [app.id for app in apps]
while True:
with Session(db.engine).no_autoflush as session:
messages = (
session.query(Message)
.filter(
.where(
Message.app_id.in_(app_ids),
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
@@ -54,7 +54,7 @@ class ClearFreePlanTenantExpiredLogs:
message_ids = [message.id for message in messages]
# delete messages
session.query(Message).filter(
session.query(Message).where(
Message.id.in_(message_ids),
).delete(synchronize_session=False)
@@ -70,7 +70,7 @@ class ClearFreePlanTenantExpiredLogs:
with Session(db.engine).no_autoflush as session:
conversations = (
session.query(Conversation)
.filter(
.where(
Conversation.app_id.in_(app_ids),
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
@@ -93,7 +93,7 @@ class ClearFreePlanTenantExpiredLogs:
)
conversation_ids = [conversation.id for conversation in conversations]
session.query(Conversation).filter(
session.query(Conversation).where(
Conversation.id.in_(conversation_ids),
).delete(synchronize_session=False)
session.commit()
@@ -276,7 +276,7 @@ class ClearFreePlanTenantExpiredLogs:
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
.filter(Tenant.created_at.between(current_time, current_time + test_interval))
.where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
)
if tenant_count <= 100:
@@ -301,7 +301,7 @@ class ClearFreePlanTenantExpiredLogs:
rs = (
session.query(Tenant.id)
.filter(Tenant.created_at.between(current_time, batch_end))
.where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)
+2 -2
View File
@@ -123,7 +123,7 @@ class ConversationService:
# get conversation first message
message = (
db.session.query(Message)
.filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
.where(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
.order_by(Message.created_at.asc())
.first()
)
@@ -148,7 +148,7 @@ class ConversationService:
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
conversation = (
db.session.query(Conversation)
.filter(
.where(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
+50 -50
View File
@@ -80,7 +80,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user:
# get permitted dataset ids
@@ -92,14 +92,14 @@ class DatasetService:
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
# only show datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(Dataset.id.in_(permitted_dataset_ids))
query = query.where(Dataset.id.in_(permitted_dataset_ids))
else:
return [], 0
else:
if user.current_role != TenantAccountRole.OWNER or not include_all:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
query = query.where(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
@@ -112,7 +112,7 @@ class DatasetService:
)
)
else:
query = query.filter(
query = query.where(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
@@ -122,15 +122,15 @@ class DatasetService:
)
else:
# if no user, only show datasets that are shared with all team members
query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
if search:
query = query.filter(Dataset.name.ilike(f"%{search}%"))
query = query.where(Dataset.name.ilike(f"%{search}%"))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
if target_ids:
query = query.filter(Dataset.id.in_(target_ids))
query = query.where(Dataset.id.in_(target_ids))
else:
return [], 0
@@ -143,7 +143,7 @@ class DatasetService:
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == dataset_id)
.where(DatasetProcessRule.dataset_id == dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
@@ -158,7 +158,7 @@ class DatasetService:
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
stmt = select(Dataset).where(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
@@ -697,7 +697,7 @@ class DatasetService:
def get_related_apps(dataset_id: str):
return (
db.session.query(AppDatasetJoin)
.filter(AppDatasetJoin.dataset_id == dataset_id)
.where(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
@@ -714,7 +714,7 @@ class DatasetService:
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.filter(
.where(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
@@ -843,7 +843,7 @@ class DocumentService:
def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
if document_id:
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
return document
else:
@@ -851,7 +851,7 @@ class DocumentService:
@staticmethod
def get_document_by_id(document_id: str) -> Optional[Document]:
document = db.session.query(Document).filter(Document.id == document_id).first()
document = db.session.query(Document).where(Document.id == document_id).first()
return document
@@ -859,7 +859,7 @@ class DocumentService:
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
.where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
@@ -873,7 +873,7 @@ class DocumentService:
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
.where(
Document.dataset_id == dataset_id,
Document.enabled == True,
)
@@ -886,7 +886,7 @@ class DocumentService:
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
.where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
@@ -901,7 +901,7 @@ class DocumentService:
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.all()
)
return documents
@@ -910,7 +910,7 @@ class DocumentService:
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
.where(
Document.batch == batch,
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id,
@@ -922,7 +922,7 @@ class DocumentService:
@staticmethod
def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none()
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
return file_detail
@staticmethod
@@ -950,7 +950,7 @@ class DocumentService:
@staticmethod
def delete_documents(dataset: Dataset, document_ids: list[str]):
documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all()
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
for document in documents
@@ -1189,7 +1189,7 @@ class DocumentService:
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
@@ -1270,7 +1270,7 @@ class DocumentService:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
.where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
@@ -1413,7 +1413,7 @@ class DocumentService:
def get_tenant_documents_count():
documents_count = (
db.session.query(Document)
.filter(
.where(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
@@ -1469,7 +1469,7 @@ class DocumentService:
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
@@ -1489,7 +1489,7 @@ class DocumentService:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
.where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
@@ -2005,7 +2005,7 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == document.id)
.where(DocumentSegment.document_id == document.id)
.scalar()
)
segment_document = DocumentSegment(
@@ -2043,7 +2043,7 @@ class SegmentService:
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
@classmethod
@@ -2062,7 +2062,7 @@ class SegmentService:
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == document.id)
.where(DocumentSegment.document_id == document.id)
.scalar()
)
pre_segment_data_list = []
@@ -2201,7 +2201,7 @@ class SegmentService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -2276,7 +2276,7 @@ class SegmentService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -2295,7 +2295,7 @@ class SegmentService:
segment.status = "error"
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
return new_segment
@classmethod
@@ -2321,7 +2321,7 @@ class SegmentService:
index_node_ids = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
.filter(
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2332,7 +2332,7 @@ class SegmentService:
index_node_ids = [index_node_id[0] for index_node_id in index_node_ids]
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete()
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
db.session.commit()
@classmethod
@@ -2340,7 +2340,7 @@ class SegmentService:
if action == "enable":
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2367,7 +2367,7 @@ class SegmentService:
elif action == "disable":
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2404,7 +2404,7 @@ class SegmentService:
index_node_hash = helper.generate_text_hash(content)
child_chunk_count = (
db.session.query(ChildChunk)
.filter(
.where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
@@ -2414,7 +2414,7 @@ class SegmentService:
)
max_position = (
db.session.query(func.max(ChildChunk.position))
.filter(
.where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
@@ -2457,7 +2457,7 @@ class SegmentService:
) -> list[ChildChunk]:
child_chunks = (
db.session.query(ChildChunk)
.filter(
.where(
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
@@ -2578,7 +2578,7 @@ class SegmentService:
"""Get a child chunk by its ID."""
result = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, ChildChunk) else None
@@ -2594,15 +2594,15 @@ class SegmentService:
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
query = select(DocumentSegment).filter(
query = select(DocumentSegment).where(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
)
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
query = query.where(DocumentSegment.status.in_(status_list))
if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@@ -2615,7 +2615,7 @@ class SegmentService:
) -> tuple[DocumentSegment, Document]:
"""Update a segment by its ID with validation and checks."""
# check dataset
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@@ -2647,7 +2647,7 @@ class SegmentService:
# check segment
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first()
)
if not segment:
@@ -2664,7 +2664,7 @@ class SegmentService:
"""Get a segment by its ID."""
result = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, DocumentSegment) else None
@@ -2677,7 +2677,7 @@ class DatasetCollectionBindingService:
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(
.where(
DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type,
@@ -2703,7 +2703,7 @@ class DatasetCollectionBindingService:
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(
.where(
DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type
)
.order_by(DatasetCollectionBinding.created_at)
@@ -2722,7 +2722,7 @@ class DatasetPermissionService:
db.session.query(
DatasetPermission.account_id,
)
.filter(DatasetPermission.dataset_id == dataset_id)
.where(DatasetPermission.dataset_id == dataset_id)
.all()
)
@@ -2735,7 +2735,7 @@ class DatasetPermissionService:
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
try:
db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
permissions = []
for user in user_list:
permission = DatasetPermission(
@@ -2771,7 +2771,7 @@ class DatasetPermissionService:
@classmethod
def clear_partial_member_list(cls, dataset_id):
try:
db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
db.session.commit()
except Exception as e:
db.session.rollback()
+2 -2
View File
@@ -30,11 +30,11 @@ class ExternalDatasetService:
) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = (
select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id)
.where(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
+4 -4
View File
@@ -144,7 +144,7 @@ class FileService:
@staticmethod
def get_file_preview(file_id: str):
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
@@ -167,7 +167,7 @@ class FileService:
if not result:
raise NotFound("File not found or signature is invalid")
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@@ -187,7 +187,7 @@ class FileService:
if not result:
raise NotFound("File not found or signature is invalid")
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@@ -198,7 +198,7 @@ class FileService:
@staticmethod
def get_public_image_preview(file_id: str):
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
+10 -12
View File
@@ -50,7 +50,7 @@ class MessageService:
if first_id:
first_message = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == first_id)
.where(Message.conversation_id == conversation.id, Message.id == first_id)
.first()
)
@@ -59,7 +59,7 @@ class MessageService:
history_messages = (
db.session.query(Message)
.filter(
.where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
@@ -71,7 +71,7 @@ class MessageService:
else:
history_messages = (
db.session.query(Message)
.filter(Message.conversation_id == conversation.id)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(fetch_limit)
.all()
@@ -109,19 +109,19 @@ class MessageService:
app_model=app_model, user=user, conversation_id=conversation_id
)
base_query = base_query.filter(Message.conversation_id == conversation.id)
base_query = base_query.where(Message.conversation_id == conversation.id)
if include_ids is not None:
base_query = base_query.filter(Message.id.in_(include_ids))
base_query = base_query.where(Message.id.in_(include_ids))
if last_id:
last_message = base_query.filter(Message.id == last_id).first()
last_message = base_query.where(Message.id == last_id).first()
if not last_message:
raise LastMessageNotExistsError()
history_messages = (
base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id)
base_query.where(Message.created_at < last_message.created_at, Message.id != last_message.id)
.order_by(Message.created_at.desc())
.limit(fetch_limit)
.all()
@@ -183,7 +183,7 @@ class MessageService:
offset = (page - 1) * limit
feedbacks = (
db.session.query(MessageFeedback)
.filter(MessageFeedback.app_id == app_model.id)
.where(MessageFeedback.app_id == app_model.id)
.order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
.limit(limit)
.offset(offset)
@@ -196,7 +196,7 @@ class MessageService:
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
message = (
db.session.query(Message)
.filter(
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
@@ -248,9 +248,7 @@ class MessageService:
if not conversation.override_model_configs:
app_model_config = (
db.session.query(AppModelConfig)
.filter(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
)
else:
+4 -4
View File
@@ -103,7 +103,7 @@ class ModelLoadBalancingService:
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.filter(
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
@@ -219,7 +219,7 @@ class ModelLoadBalancingService:
# Get load balancing configurations
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
.filter(
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
@@ -307,7 +307,7 @@ class ModelLoadBalancingService:
current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
.filter(
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
@@ -457,7 +457,7 @@ class ModelLoadBalancingService:
# Get load balancing config
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
.filter(
.where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
+7 -7
View File
@@ -17,7 +17,7 @@ class OpsService:
"""
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
@@ -25,7 +25,7 @@ class OpsService:
return None
# decrypt_token and obfuscated_token
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if not app:
return None
tenant_id = app.tenant_id
@@ -148,7 +148,7 @@ class OpsService:
# check if trace config already exists
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
@@ -156,7 +156,7 @@ class OpsService:
return None
# get tenant id
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if not app:
return None
tenant_id = app.tenant_id
@@ -190,7 +190,7 @@ class OpsService:
# check if trace config already exists
current_trace_config = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
@@ -198,7 +198,7 @@ class OpsService:
return None
# get tenant id
app = db.session.query(App).filter(App.id == app_id).first()
app = db.session.query(App).where(App.id == app_id).first()
if not app:
return None
tenant_id = app.tenant_id
@@ -227,7 +227,7 @@ class OpsService:
"""
trace_config = (
db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
+6 -6
View File
@@ -101,7 +101,7 @@ class PluginMigration:
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
.filter(Tenant.created_at.between(current_time, current_time + test_interval))
.where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
)
if tenant_count <= 100:
@@ -126,7 +126,7 @@ class PluginMigration:
rs = (
session.query(Tenant.id)
.filter(Tenant.created_at.between(current_time, batch_end))
.where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)
@@ -212,7 +212,7 @@ class PluginMigration:
Extract tool tables.
"""
with Session(db.engine) as session:
rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all()
result = []
for row in rs:
result.append(ToolProviderID(row.provider).plugin_id)
@@ -226,7 +226,7 @@ class PluginMigration:
"""
with Session(db.engine) as session:
rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all()
result = []
for row in rs:
graph = row.graph_dict
@@ -249,7 +249,7 @@ class PluginMigration:
Extract app tables.
"""
with Session(db.engine) as session:
apps = session.query(App).filter(App.tenant_id == tenant_id).all()
apps = session.query(App).where(App.tenant_id == tenant_id).all()
if not apps:
return []
@@ -257,7 +257,7 @@ class PluginMigration:
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
]
rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
result = []
for row in rs:
agent_config = row.agent_mode_dict
@@ -51,7 +51,7 @@ class PluginParameterService:
with Session(db.engine) as session:
db_record = (
session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
@@ -8,7 +8,7 @@ class PluginPermissionService:
@staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with Session(db.engine) as session:
return session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
@staticmethod
def change_permission(
@@ -18,7 +18,7 @@ class PluginPermissionService:
):
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
)
if not permission:
permission = TenantPluginPermission(
@@ -33,14 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
"""
recommended_apps = (
db.session.query(RecommendedApp)
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.all()
)
if len(recommended_apps) == 0:
recommended_apps = (
db.session.query(RecommendedApp)
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.all()
)
@@ -83,7 +83,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
# is in public recommended list
recommended_app = (
db.session.query(RecommendedApp)
.filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
.where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
.first()
)
@@ -91,7 +91,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return None
# get app detail
app_model = db.session.query(App).filter(App.id == app_id).first()
app_model = db.session.query(App).where(App.id == app_id).first()
if not app_model or not app_model.is_public:
return None
+3 -3
View File
@@ -17,7 +17,7 @@ class SavedMessageService:
raise ValueError("User is required")
saved_messages = (
db.session.query(SavedMessage)
.filter(
.where(
SavedMessage.app_id == app_model.id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
SavedMessage.created_by == user.id,
@@ -37,7 +37,7 @@ class SavedMessageService:
return
saved_message = (
db.session.query(SavedMessage)
.filter(
.where(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
@@ -67,7 +67,7 @@ class SavedMessageService:
return
saved_message = (
db.session.query(SavedMessage)
.filter(
.where(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+14 -14
View File
@@ -16,10 +16,10 @@ class TagService:
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
.filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results
@@ -28,7 +28,7 @@ class TagService:
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
tags = (
db.session.query(Tag)
.filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
if not tags:
@@ -36,7 +36,7 @@ class TagService:
tag_ids = [tag.id for tag in tags]
tag_bindings = (
db.session.query(TagBinding.target_id)
.filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.all()
)
if not tag_bindings:
@@ -50,7 +50,7 @@ class TagService:
return []
tags = (
db.session.query(Tag)
.filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
if not tags:
@@ -62,7 +62,7 @@ class TagService:
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.filter(
.where(
TagBinding.target_id == target_id,
TagBinding.tenant_id == current_tenant_id,
Tag.tenant_id == current_tenant_id,
@@ -92,7 +92,7 @@ class TagService:
def update_tags(args: dict, tag_id: str) -> Tag:
if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
raise ValueError("Tag name already exists")
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
tag = db.session.query(Tag).where(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
tag.name = args["name"]
@@ -101,17 +101,17 @@ class TagService:
@staticmethod
def get_tag_binding_count(tag_id: str) -> int:
count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
count = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).count()
return count
@staticmethod
def delete_tag(tag_id: str):
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
tag = db.session.query(Tag).where(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)
@@ -125,7 +125,7 @@ class TagService:
for tag_id in args["tag_ids"]:
tag_binding = (
db.session.query(TagBinding)
.filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
.where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
.first()
)
if tag_binding:
@@ -146,7 +146,7 @@ class TagService:
# delete tag binding
tag_bindings = (
db.session.query(TagBinding)
.filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
.where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
.first()
)
if tag_bindings:
@@ -158,7 +158,7 @@ class TagService:
if type == "knowledge":
dataset = (
db.session.query(Dataset)
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
.where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
.first()
)
if not dataset:
@@ -166,7 +166,7 @@ class TagService:
elif type == "app":
app = (
db.session.query(App)
.filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
.where(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
.first()
)
if not app:
@@ -119,7 +119,7 @@ class ApiToolManageService:
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
@@ -210,7 +210,7 @@ class ApiToolManageService:
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
@@ -257,7 +257,7 @@ class ApiToolManageService:
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider,
)
@@ -326,7 +326,7 @@ class ApiToolManageService:
"""
provider = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
@@ -376,7 +376,7 @@ class ApiToolManageService:
db_provider = (
db.session.query(ApiToolProvider)
.filter(
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
@@ -444,7 +444,7 @@ class ApiToolManageService:
"""
# get all api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
)
result: list[ToolProviderApiEntity] = []
@@ -154,7 +154,7 @@ class BuiltinToolManageService:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
@@ -404,7 +404,7 @@ class BuiltinToolManageService:
with Session(db.engine) as session:
db_provider = (
session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
@@ -613,7 +613,7 @@ class BuiltinToolManageService:
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
)
@@ -626,7 +626,7 @@ class BuiltinToolManageService:
else:
provider = (
session.query(BuiltinToolProvider)
.filter(
.where(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
@@ -647,7 +647,7 @@ class BuiltinToolManageService:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
@@ -31,7 +31,7 @@ class MCPToolManageService:
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
.first()
)
if not res:
@@ -42,7 +42,7 @@ class MCPToolManageService:
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
.first()
)
if not res:
@@ -63,7 +63,7 @@ class MCPToolManageService:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
db.session.query(MCPToolProvider)
.filter(
.where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
@@ -100,7 +100,7 @@ class MCPToolManageService:
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
mcp_providers = (
db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id)
.where(MCPToolProvider.tenant_id == tenant_id)
.order_by(MCPToolProvider.name)
.all()
)
@@ -43,7 +43,7 @@ class WorkflowToolManageService:
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
.filter(
.where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
@@ -54,7 +54,7 @@ class WorkflowToolManageService:
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
@@ -123,7 +123,7 @@ class WorkflowToolManageService:
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
.filter(
.where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
@@ -136,7 +136,7 @@ class WorkflowToolManageService:
workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
@@ -144,7 +144,7 @@ class WorkflowToolManageService:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App | None = (
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)
if app is None:
@@ -186,7 +186,7 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
@@ -224,7 +224,7 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
db.session.query(WorkflowToolProvider).filter(
db.session.query(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
@@ -243,7 +243,7 @@ class WorkflowToolManageService:
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
return cls._get_workflow_tool(tenant_id, db_tool)
@@ -259,7 +259,7 @@ class WorkflowToolManageService:
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
)
return cls._get_workflow_tool(tenant_id, db_tool)
@@ -275,7 +275,7 @@ class WorkflowToolManageService:
raise ValueError("Tool not found")
workflow_app: App | None = (
db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
)
if workflow_app is None:
@@ -318,7 +318,7 @@ class WorkflowToolManageService:
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
+1 -1
View File
@@ -36,7 +36,7 @@ class VectorService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
+2 -2
View File
@@ -65,7 +65,7 @@ class WebConversationService:
return
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(
.where(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
@@ -97,7 +97,7 @@ class WebConversationService:
return
pinned_conversation = (
db.session.query(PinnedConversation)
.filter(
.where(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
+3 -3
View File
@@ -52,7 +52,7 @@ class WebAppAuthService:
@classmethod
def get_user_through_email(cls, email: str):
account = db.session.query(Account).filter(Account.email == email).first()
account = db.session.query(Account).where(Account.email == email).first()
if not account:
return None
@@ -91,10 +91,10 @@ class WebAppAuthService:
@classmethod
def create_end_user(cls, app_code, email) -> EndUser:
site = db.session.query(Site).filter(Site.code == app_code).first()
site = db.session.query(Site).where(Site.code == app_code).first()
if not site:
raise NotFound("Site not found.")
app_model = db.session.query(App).filter(App.id == site.app_id).first()
app_model = db.session.query(App).where(App.id == site.app_id).first()
if not app_model:
raise NotFound("App not found.")
end_user = EndUser(
+1 -1
View File
@@ -620,7 +620,7 @@ class WorkflowConverter:
"""
api_based_extension = (
db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
@@ -138,7 +138,7 @@ class WorkflowDraftVariableService:
)
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
return self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable_id).first()
def get_draft_variables_by_selectors(
self,
@@ -166,7 +166,7 @@ class WorkflowDraftVariableService:
def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList:
criteria = WorkflowDraftVariable.app_id == app_id
total = None
query = self._session.query(WorkflowDraftVariable).filter(criteria)
query = self._session.query(WorkflowDraftVariable).where(criteria)
if page == 1:
total = query.count()
variables = (
@@ -185,7 +185,7 @@ class WorkflowDraftVariableService:
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
)
query = self._session.query(WorkflowDraftVariable).filter(*criteria)
query = self._session.query(WorkflowDraftVariable).where(*criteria)
variables = query.order_by(WorkflowDraftVariable.created_at.desc()).all()
return WorkflowDraftVariableList(variables=variables)
@@ -328,7 +328,7 @@ class WorkflowDraftVariableService:
def delete_workflow_variables(self, app_id: str):
(
self._session.query(WorkflowDraftVariable)
.filter(WorkflowDraftVariable.app_id == app_id)
.where(WorkflowDraftVariable.app_id == app_id)
.delete(synchronize_session=False)
)
@@ -379,7 +379,7 @@ class WorkflowDraftVariableService:
if conv_id is not None:
conversation = (
self._session.query(Conversation)
.filter(
.where(
Conversation.id == conv_id,
Conversation.app_id == workflow.app_id,
)
+5 -5
View File
@@ -89,7 +89,7 @@ class WorkflowService:
def is_workflow_exist(self, app_model: App) -> bool:
return (
db.session.query(Workflow)
.filter(
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
@@ -104,7 +104,7 @@ class WorkflowService:
# fetch draft workflow by app_model
workflow = (
db.session.query(Workflow)
.filter(
.where(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
)
.first()
@@ -117,7 +117,7 @@ class WorkflowService:
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id,
@@ -141,7 +141,7 @@ class WorkflowService:
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
.where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id,
@@ -658,7 +658,7 @@ class WorkflowService:
# Check if there's a tool provider using this specific workflow version
tool_provider = (
session.query(WorkflowToolProvider)
.filter(
.where(
WorkflowToolProvider.tenant_id == workflow.tenant_id,
WorkflowToolProvider.app_id == workflow.app_id,
WorkflowToolProvider.version == workflow.version,
+1 -1
View File
@@ -25,7 +25,7 @@ class WorkspaceService:
# Get role of user
tenant_account_join = (
db.session.query(TenantAccountJoin)
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
.first()
)
assert tenant_account_join is not None, "TenantAccountJoin not found"