chore: add ast-grep rule to convert Optional[T] to T | None (#25560)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
-LAN-
2025-09-15 13:06:33 +08:00
committed by GitHub
parent 2e44ebe98d
commit bab4975809
394 changed files with 2555 additions and 2792 deletions
+55 -59
View File
@@ -5,7 +5,7 @@ import secrets
import uuid
from datetime import UTC, datetime, timedelta
from hashlib import sha256
from typing import Any, Optional, cast
from typing import Any, cast
from pydantic import BaseModel
from sqlalchemy import func
@@ -171,7 +171,7 @@ class AccountService:
return token
@staticmethod
def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
def authenticate(email: str, password: str, invite_token: str | None = None) -> Account:
"""authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first()
@@ -228,9 +228,9 @@ class AccountService:
email: str,
name: str,
interface_language: str,
password: Optional[str] = None,
password: str | None = None,
interface_theme: str = "light",
is_setup: Optional[bool] = False,
is_setup: bool | None = False,
) -> Account:
"""create account"""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
@@ -276,7 +276,7 @@ class AccountService:
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: Optional[str] = None
email: str, name: str, interface_language: str, password: str | None = None
) -> Account:
"""create account"""
account = AccountService.create_account(
@@ -330,7 +330,7 @@ class AccountService:
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = (
account_integrate: AccountIntegrate | None = (
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
)
@@ -391,7 +391,7 @@ class AccountService:
db.session.commit()
@staticmethod
def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair:
def login(account: Account, *, ip_address: str | None = None) -> TokenPair:
if ip_address:
AccountService.update_login_info(account=account, ip_address=ip_address)
@@ -439,8 +439,8 @@ class AccountService:
@classmethod
def send_reset_password_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
is_allow_register: bool = False,
):
@@ -473,8 +473,8 @@ class AccountService:
@classmethod
def send_email_register_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
account_email = account.email if account else email
@@ -507,11 +507,11 @@ class AccountService:
@classmethod
def send_change_email_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
old_email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
old_email: str | None = None,
language: str = "en-US",
phase: Optional[str] = None,
phase: str | None = None,
):
account_email = account.email if account else email
if account_email is None:
@@ -538,8 +538,8 @@ class AccountService:
@classmethod
def send_change_email_completed_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
account_email = account.email if account else email
@@ -554,10 +554,10 @@ class AccountService:
@classmethod
def send_owner_transfer_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
):
account_email = account.email if account else email
if account_email is None:
@@ -583,10 +583,10 @@ class AccountService:
@classmethod
def send_old_owner_transfer_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
new_owner_email: str = "",
):
account_email = account.email if account else email
@@ -604,10 +604,10 @@ class AccountService:
@classmethod
def send_new_owner_transfer_notify_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
workspace_name: Optional[str] = "",
workspace_name: str | None = "",
):
account_email = account.email if account else email
if account_email is None:
@@ -624,8 +624,8 @@ class AccountService:
def generate_reset_password_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@@ -640,7 +640,7 @@ class AccountService:
def generate_email_register_token(
cls,
email: str,
code: Optional[str] = None,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@@ -653,9 +653,9 @@ class AccountService:
def generate_change_email_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
old_email: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
old_email: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@@ -671,8 +671,8 @@ class AccountService:
def generate_owner_transfer_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
account: Account | None = None,
code: str | None = None,
additional_data: dict[str, Any] = {},
):
if not code:
@@ -700,26 +700,26 @@ class AccountService:
TokenManager.revoke_token(token, "owner_transfer")
@classmethod
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_reset_password_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "reset_password")
@classmethod
def get_email_register_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_register_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_register")
@classmethod
def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_change_email_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "change_email")
@classmethod
def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_owner_transfer_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "owner_transfer")
@classmethod
def send_email_code_login_email(
cls,
account: Optional[Account] = None,
email: Optional[str] = None,
account: Account | None = None,
email: str | None = None,
language: str = "en-US",
):
email = account.email if account else email
@@ -743,7 +743,7 @@ class AccountService:
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
@@ -965,7 +965,7 @@ class AccountService:
class TenantService:
@staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
def create_tenant(name: str, is_setup: bool | None = False, is_from_dashboard: bool | None = False) -> Tenant:
"""Create tenant"""
if (
not FeatureService.get_system_features().is_allow_create_workspace
@@ -996,9 +996,7 @@ class TenantService:
return tenant
@staticmethod
def create_owner_tenant_if_not_exist(
account: Account, name: Optional[str] = None, is_setup: Optional[bool] = False
):
def create_owner_tenant_if_not_exist(account: Account, name: str | None = None, is_setup: bool | None = False):
"""Check if user have a workspace or not"""
available_ta = (
db.session.query(TenantAccountJoin)
@@ -1070,7 +1068,7 @@ class TenantService:
return tenant
@staticmethod
def switch_tenant(account: Account, tenant_id: Optional[str] = None):
def switch_tenant(account: Account, tenant_id: str | None = None):
"""Switch the current workspace for the account"""
# Ensure tenant_id is provided
@@ -1152,7 +1150,7 @@ class TenantService:
)
@staticmethod
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountRole]:
def get_user_role(account: Account, tenant: Tenant) -> TenantAccountRole | None:
"""Get the role of the current account for a given tenant"""
join = (
db.session.query(TenantAccountJoin)
@@ -1292,13 +1290,13 @@ class RegisterService:
cls,
email,
name,
password: Optional[str] = None,
open_id: Optional[str] = None,
provider: Optional[str] = None,
language: Optional[str] = None,
status: Optional[AccountStatus] = None,
is_setup: Optional[bool] = False,
create_workspace_required: Optional[bool] = True,
password: str | None = None,
open_id: str | None = None,
provider: str | None = None,
language: str | None = None,
status: AccountStatus | None = None,
is_setup: bool | None = False,
create_workspace_required: bool | None = True,
) -> Account:
db.session.begin_nested()
"""Register account"""
@@ -1415,9 +1413,7 @@ class RegisterService:
redis_client.delete(cls._get_invitation_token_key(token))
@classmethod
def get_invitation_if_token_valid(
cls, workspace_id: Optional[str], email: str, token: str
) -> Optional[dict[str, Any]]:
def get_invitation_if_token_valid(cls, workspace_id: str | None, email: str, token: str) -> dict[str, Any] | None:
invitation_data = cls.get_invitation_by_token(token, workspace_id, email)
if not invitation_data:
return None
@@ -1456,8 +1452,8 @@ class RegisterService:
@classmethod
def get_invitation_by_token(
cls, token: str, workspace_id: Optional[str] = None, email: Optional[str] = None
) -> Optional[dict[str, str]]:
cls, token: str, workspace_id: str | None = None, email: str | None = None
) -> dict[str, str] | None:
if workspace_id is not None and email is not None:
email_hash = sha256(email.encode()).hexdigest()
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
+2 -2
View File
@@ -1,5 +1,5 @@
import threading
from typing import Any, Optional
from typing import Any
import pytz
@@ -35,7 +35,7 @@ class AgentService:
if not conversation:
raise ValueError(f"Conversation not found: {conversation_id}")
message: Optional[Message] = (
message: Message | None = (
db.session.query(Message)
.where(
Message.id == message_id,
+1 -2
View File
@@ -1,5 +1,4 @@
import uuid
from typing import Optional
import pandas as pd
from sqlalchemy import or_, select
@@ -42,7 +41,7 @@ class AppAnnotationService:
if not message:
raise NotFound("Message Not Exists.")
annotation: Optional[MessageAnnotation] = message.annotation
annotation: MessageAnnotation | None = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]
+19 -20
View File
@@ -4,7 +4,6 @@ import logging
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Optional
from urllib.parse import urlparse
from uuid import uuid4
@@ -61,8 +60,8 @@ class ImportStatus(StrEnum):
class Import(BaseModel):
id: str
status: ImportStatus
app_id: Optional[str] = None
app_mode: Optional[str] = None
app_id: str | None = None
app_mode: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
@@ -121,14 +120,14 @@ class AppDslService:
*,
account: Account,
import_mode: str,
yaml_content: Optional[str] = None,
yaml_url: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
app_id: Optional[str] = None,
yaml_content: str | None = None,
yaml_url: str | None = None,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
app_id: str | None = None,
) -> Import:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
@@ -407,15 +406,15 @@ class AppDslService:
def _create_or_update_app(
self,
*,
app: Optional[App],
app: App | None,
data: dict,
account: Account,
name: Optional[str] = None,
description: Optional[str] = None,
icon_type: Optional[str] = None,
icon: Optional[str] = None,
icon_background: Optional[str] = None,
dependencies: Optional[list[PluginDependency]] = None,
name: str | None = None,
description: str | None = None,
icon_type: str | None = None,
icon: str | None = None,
icon_background: str | None = None,
dependencies: list[PluginDependency] | None = None,
) -> App:
"""Create a new app or update an existing one."""
app_data = data.get("app", {})
@@ -533,7 +532,7 @@ class AppDslService:
return app
@classmethod
def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: Optional[str] = None) -> str:
def export_dsl(cls, app_model: App, include_secret: bool = False, workflow_id: str | None = None) -> str:
"""
Export app
:param app_model: App instance
@@ -566,7 +565,7 @@ class AppDslService:
@classmethod
def _append_workflow_export_data(
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: str | None = None
):
"""
Append workflow export data
+2 -2
View File
@@ -1,6 +1,6 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from typing import Any, Union
from openai._exceptions import RateLimitError
@@ -214,7 +214,7 @@ class AppGenerateService:
)
@classmethod
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: Optional[str] = None) -> Workflow:
def _get_workflow(cls, app_model: App, invoke_from: InvokeFrom, workflow_id: str | None = None) -> Workflow:
"""
Get workflow
:param app_model: app model
+3 -3
View File
@@ -1,6 +1,6 @@
import json
import logging
from typing import Optional, TypedDict, cast
from typing import TypedDict, cast
from flask_sqlalchemy.pagination import Pagination
@@ -370,7 +370,7 @@ class AppService:
}
)
else:
app_model_config: Optional[AppModelConfig] = app_model.app_model_config
app_model_config: AppModelConfig | None = app_model.app_model_config
if not app_model_config:
return meta
@@ -393,7 +393,7 @@ class AppService:
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":
try:
provider: Optional[ApiToolProvider] = (
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
)
if provider is None:
+6 -7
View File
@@ -2,7 +2,6 @@ import io
import logging
import uuid
from collections.abc import Generator
from typing import Optional
from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
@@ -30,7 +29,7 @@ logger = logging.getLogger(__name__)
class AudioService:
@classmethod
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None):
def transcript_asr(cls, app_model: App, file: FileStorage, end_user: str | None = None):
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
@@ -77,15 +76,15 @@ class AudioService:
def transcript_tts(
cls,
app_model: App,
text: Optional[str] = None,
voice: Optional[str] = None,
end_user: Optional[str] = None,
message_id: Optional[str] = None,
text: str | None = None,
voice: str | None = None,
end_user: str | None = None,
message_id: str | None = None,
is_draft: bool = False,
):
from app import app
def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
def invoke_tts(text_content: str, app_model: App, voice: str | None = None, is_draft: bool = False):
with app.app_context():
if voice is None:
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+2 -2
View File
@@ -1,5 +1,5 @@
import os
from typing import Literal, Optional
from typing import Literal
import httpx
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
@@ -73,7 +73,7 @@ class BillingService:
def is_tenant_owner_or_admin(current_user: Account):
tenant_id = current_user.current_tenant_id
join: Optional[TenantAccountJoin] = (
join: TenantAccountJoin | None = (
db.session.query(TenantAccountJoin)
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()
+11 -11
View File
@@ -1,7 +1,7 @@
import contextlib
import logging
from collections.abc import Callable, Sequence
from typing import Any, Optional, Union
from typing import Any, Union
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
@@ -36,12 +36,12 @@ class ConversationService:
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
include_ids: Sequence[str] | None = None,
exclude_ids: Sequence[str] | None = None,
sort_by: str = "-updated_at",
) -> InfiniteScrollPagination:
if not user:
@@ -118,7 +118,7 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
name: str,
auto_generate: bool,
):
@@ -158,7 +158,7 @@ class ConversationService:
return conversation
@classmethod
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
conversation = (
db.session.query(Conversation)
.where(
@@ -178,7 +178,7 @@ class ConversationService:
return conversation
@classmethod
def delete(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def delete(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
try:
logger.info(
"Initiating conversation deletion for app_name %s, conversation_id: %s",
@@ -200,9 +200,9 @@ class ConversationService:
cls,
app_model: App,
conversation_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
limit: int,
last_id: Optional[str],
last_id: str | None,
) -> InfiniteScrollPagination:
conversation = cls.get_conversation(app_model, conversation_id, user)
@@ -248,7 +248,7 @@ class ConversationService:
app_model: App,
conversation_id: str,
variable_id: str,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
new_value: Any,
):
"""
+19 -19
View File
@@ -7,7 +7,7 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
import sqlalchemy as sa
from sqlalchemy import exists, func, select
@@ -185,16 +185,16 @@ class DatasetService:
def create_empty_dataset(
tenant_id: str,
name: str,
description: Optional[str],
indexing_technique: Optional[str],
description: str | None,
indexing_technique: str | None,
account: Account,
permission: Optional[str] = None,
permission: str | None = None,
provider: str = "vendor",
external_knowledge_api_id: Optional[str] = None,
external_knowledge_id: Optional[str] = None,
embedding_model_provider: Optional[str] = None,
embedding_model_name: Optional[str] = None,
retrieval_model: Optional[RetrievalModel] = None,
external_knowledge_api_id: str | None = None,
external_knowledge_id: str | None = None,
embedding_model_provider: str | None = None,
embedding_model_name: str | None = None,
retrieval_model: RetrievalModel | None = None,
):
# check if dataset name already exists
if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
@@ -257,8 +257,8 @@ class DatasetService:
return dataset
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod
@@ -694,7 +694,7 @@ class DatasetService:
raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod
def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None):
def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None):
if not dataset:
raise ValueError("Dataset not found")
@@ -868,7 +868,7 @@ class DocumentService:
}
@staticmethod
def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
if document_id:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
@@ -878,7 +878,7 @@ class DocumentService:
return None
@staticmethod
def get_document_by_id(document_id: str) -> Optional[Document]:
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
return document
@@ -1099,7 +1099,7 @@ class DocumentService:
dataset: Dataset,
knowledge_config: KnowledgeConfig,
account: Account | Any,
dataset_process_rule: Optional[DatasetProcessRule] = None,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
) -> tuple[list[Document], str]:
# check doc_form
@@ -1463,7 +1463,7 @@ class DocumentService:
dataset: Dataset,
document_data: KnowledgeConfig,
account: Account,
dataset_process_rule: Optional[DatasetProcessRule] = None,
dataset_process_rule: DatasetProcessRule | None = None,
created_from: str = "web",
):
assert isinstance(current_user, Account)
@@ -2655,7 +2655,7 @@ class SegmentService:
@classmethod
def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: str | None = None
):
assert isinstance(current_user, Account)
@@ -2674,7 +2674,7 @@ class SegmentService:
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None:
"""Get a child chunk by its ID."""
result = (
db.session.query(ChildChunk)
@@ -2711,7 +2711,7 @@ class SegmentService:
return paginated_segments.items, paginated_segments.total
@classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None:
"""Get a segment by its ID."""
result = (
db.session.query(DocumentSegment)
@@ -1,4 +1,4 @@
from typing import Literal, Optional, Union
from typing import Literal, Union
from pydantic import BaseModel
@@ -11,7 +11,7 @@ class AuthorizationConfig(BaseModel):
class Authorization(BaseModel):
type: Literal["no-auth", "api-key"]
config: Optional[AuthorizationConfig] = None
config: AuthorizationConfig | None = None
class ProcessStatusSetting(BaseModel):
@@ -22,5 +22,5 @@ class ProcessStatusSetting(BaseModel):
class ExternalKnowledgeApiSetting(BaseModel):
url: str
request_method: str
headers: Optional[dict] = None
params: Optional[dict] = None
headers: dict | None = None
params: dict | None = None
@@ -1,5 +1,5 @@
from enum import StrEnum
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel
@@ -11,14 +11,14 @@ class ParentMode(StrEnum):
class NotionIcon(BaseModel):
type: str
url: Optional[str] = None
emoji: Optional[str] = None
url: str | None = None
emoji: str | None = None
class NotionPage(BaseModel):
page_id: str
page_name: str
page_icon: Optional[NotionIcon] = None
page_icon: NotionIcon | None = None
type: str
@@ -40,9 +40,9 @@ class FileInfo(BaseModel):
class InfoList(BaseModel):
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
notion_info_list: Optional[list[NotionInfo]] = None
file_info_list: Optional[FileInfo] = None
website_info_list: Optional[WebsiteInfo] = None
notion_info_list: list[NotionInfo] | None = None
file_info_list: FileInfo | None = None
website_info_list: WebsiteInfo | None = None
class DataSource(BaseModel):
@@ -61,20 +61,20 @@ class Segmentation(BaseModel):
class Rule(BaseModel):
pre_processing_rules: Optional[list[PreProcessingRule]] = None
segmentation: Optional[Segmentation] = None
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
subchunk_segmentation: Optional[Segmentation] = None
pre_processing_rules: list[PreProcessingRule] | None = None
segmentation: Segmentation | None = None
parent_mode: Literal["full-doc", "paragraph"] | None = None
subchunk_segmentation: Segmentation | None = None
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Optional[Rule] = None
rules: Rule | None = None
class RerankingModel(BaseModel):
reranking_provider_name: Optional[str] = None
reranking_model_name: Optional[str] = None
reranking_provider_name: str | None = None
reranking_model_name: str | None = None
class WeightVectorSetting(BaseModel):
@@ -88,20 +88,20 @@ class WeightKeywordSetting(BaseModel):
class WeightModel(BaseModel):
weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None
vector_setting: Optional[WeightVectorSetting] = None
keyword_setting: Optional[WeightKeywordSetting] = None
weight_type: Literal["semantic_first", "keyword_first", "customized"] | None = None
vector_setting: WeightVectorSetting | None = None
keyword_setting: WeightKeywordSetting | None = None
class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
reranking_enable: bool
reranking_model: Optional[RerankingModel] = None
reranking_mode: Optional[str] = None
reranking_model: RerankingModel | None = None
reranking_mode: str | None = None
top_k: int
score_threshold_enabled: bool
score_threshold: Optional[float] = None
weights: Optional[WeightModel] = None
score_threshold: float | None = None
weights: WeightModel | None = None
class MetaDataConfig(BaseModel):
@@ -110,29 +110,29 @@ class MetaDataConfig(BaseModel):
class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
original_document_id: str | None = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: Optional[DataSource] = None
process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None
data_source: DataSource | None = None
process_rule: ProcessRule | None = None
retrieval_model: RetrievalModel | None = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: Optional[str] = None
embedding_model_provider: Optional[str] = None
name: Optional[str] = None
embedding_model: str | None = None
embedding_model_provider: str | None = None
name: str | None = None
class SegmentUpdateArgs(BaseModel):
content: Optional[str] = None
answer: Optional[str] = None
keywords: Optional[list[str]] = None
content: str | None = None
answer: str | None = None
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
enabled: Optional[bool] = None
enabled: bool | None = None
class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
id: str | None = None
content: str
@@ -143,13 +143,13 @@ class MetadataArgs(BaseModel):
class MetadataUpdateArgs(BaseModel):
name: str
value: Optional[str | int | float] = None
value: str | int | float | None = None
class MetadataDetail(BaseModel):
id: str
name: str
value: Optional[str | int | float] = None
value: str | int | float | None = None
class DocumentMetadataOperation(BaseModel):
@@ -1,5 +1,4 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, ConfigDict
@@ -42,11 +41,11 @@ class CustomConfigurationResponse(BaseModel):
"""
status: CustomConfigurationStatus
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_credentials: Optional[list[CredentialConfiguration]] = None
custom_models: Optional[list[CustomModelConfiguration]] = None
can_added_models: Optional[list[UnaddedModelConfiguration]] = None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] | None = None
custom_models: list[CustomModelConfiguration] | None = None
can_added_models: list[UnaddedModelConfiguration] | None = None
class SystemConfigurationResponse(BaseModel):
@@ -55,7 +54,7 @@ class SystemConfigurationResponse(BaseModel):
"""
enabled: bool
current_quota_type: Optional[ProviderQuotaType] = None
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
@@ -67,15 +66,15 @@ class ProviderResponse(BaseModel):
tenant_id: str
provider: str
label: I18nObject
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: list[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: Optional[ProviderCredentialSchema] = None
model_credential_schema: Optional[ModelCredentialSchema] = None
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
preferred_provider_type: ProviderType
custom_configuration: CustomConfigurationResponse
system_configuration: SystemConfigurationResponse
@@ -108,8 +107,8 @@ class ProviderWithModelsResponse(BaseModel):
tenant_id: str
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
+1 -4
View File
@@ -1,6 +1,3 @@
from typing import Optional
class BaseServiceError(ValueError):
def __init__(self, description: Optional[str] = None):
def __init__(self, description: str | None = None):
self.description = description
+2 -5
View File
@@ -1,12 +1,9 @@
from typing import Optional
class InvokeError(Exception):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
description: str | None = None
def __init__(self, description: Optional[str] = None):
def __init__(self, description: str | None = None):
self.description = description
def __str__(self):
+6 -6
View File
@@ -1,6 +1,6 @@
import json
from copy import deepcopy
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from urllib.parse import urlparse
import httpx
@@ -100,7 +100,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
)
if external_knowledge_api is None:
@@ -109,7 +109,7 @@ class ExternalDatasetService:
@staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = (
external_knowledge_api: ExternalKnowledgeApis | None = (
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
)
if external_knowledge_api is None:
@@ -151,7 +151,7 @@ class ExternalDatasetService:
@staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
external_knowledge_binding: ExternalKnowledgeBindings | None = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)
if not external_knowledge_binding:
@@ -203,7 +203,7 @@ class ExternalDatasetService:
return response
@staticmethod
def assembling_headers(authorization: Authorization, headers: Optional[dict] = None) -> dict[str, Any]:
def assembling_headers(authorization: Authorization, headers: dict | None = None) -> dict[str, Any]:
authorization = deepcopy(authorization)
if headers:
headers = deepcopy(headers)
@@ -277,7 +277,7 @@ class ExternalDatasetService:
dataset_id: str,
query: str,
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
metadata_condition: MetadataCondition | None = None,
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
+12 -12
View File
@@ -1,5 +1,5 @@
import json
from typing import Optional, Union
from typing import Union
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -29,9 +29,9 @@ class MessageService:
def pagination_by_first_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
user: Union[Account, EndUser] | None,
conversation_id: str,
first_id: Optional[str],
first_id: str | None,
limit: int,
order: str = "asc",
) -> InfiniteScrollPagination:
@@ -91,11 +91,11 @@ class MessageService:
def pagination_by_last_id(
cls,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
conversation_id: Optional[str] = None,
include_ids: Optional[list] = None,
conversation_id: str | None = None,
include_ids: list | None = None,
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@@ -145,9 +145,9 @@ class MessageService:
*,
app_model: App,
message_id: str,
user: Optional[Union[Account, EndUser]],
rating: Optional[str],
content: Optional[str],
user: Union[Account, EndUser] | None,
rating: str | None,
content: str | None,
):
if not user:
raise ValueError("user cannot be None")
@@ -196,7 +196,7 @@ class MessageService:
return [record.to_dict() for record in feedbacks]
@classmethod
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def get_message(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
message = (
db.session.query(Message)
.where(
@@ -216,7 +216,7 @@ class MessageService:
@classmethod
def get_suggested_questions_after_answer(
cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str, invoke_from: InvokeFrom
) -> list[Message]:
if not user:
raise ValueError("user cannot be None")
+1 -2
View File
@@ -1,6 +1,5 @@
import copy
import logging
from typing import Optional
from flask_login import current_user
@@ -237,7 +236,7 @@ class MetadataService:
redis_client.delete(lock_key)
@staticmethod
def knowledge_base_metadata_lock_check(dataset_id: Optional[str], document_id: Optional[str]):
def knowledge_base_metadata_lock_check(dataset_id: str | None, document_id: str | None):
if dataset_id:
lock_key = f"dataset_metadata_lock_{dataset_id}"
if redis_client.get(lock_key):
+4 -4
View File
@@ -1,7 +1,7 @@
import json
import logging
from json import JSONDecodeError
from typing import Optional, Union
from typing import Union
from sqlalchemy import or_, select
@@ -211,7 +211,7 @@ class ModelLoadBalancingService:
def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> Optional[dict]:
) -> dict | None:
"""
Get load balancing configuration.
:param tenant_id: workspace id
@@ -478,7 +478,7 @@ class ModelLoadBalancingService:
model: str,
model_type: str,
credentials: dict,
config_id: Optional[str] = None,
config_id: str | None = None,
):
"""
Validate load balancing credentials.
@@ -536,7 +536,7 @@ class ModelLoadBalancingService:
model_type: ModelType,
model: str,
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
load_balancing_model_config: LoadBalancingModelConfig | None = None,
validate: bool = True,
):
"""
+5 -8
View File
@@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
@@ -52,7 +51,7 @@ class ModelProviderService:
return provider_configuration
def get_provider_list(self, tenant_id: str, model_type: Optional[str] = None) -> list[ProviderResponse]:
def get_provider_list(self, tenant_id: str, model_type: str | None = None) -> list[ProviderResponse]:
"""
get provider list.
@@ -128,9 +127,7 @@ class ModelProviderService:
for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credential(
self, tenant_id: str, provider: str, credential_id: Optional[str] = None
) -> Optional[dict]:
def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None:
"""
get provider credentials.
@@ -216,7 +213,7 @@ class ModelProviderService:
def get_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str | None
) -> Optional[dict]:
) -> dict | None:
"""
Retrieve model-specific credentials.
@@ -449,7 +446,7 @@ class ModelProviderService:
return model_schema.parameter_rules if model_schema else []
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> DefaultModelResponse | None:
"""
get default model of model type.
@@ -498,7 +495,7 @@ class ModelProviderService:
def get_model_provider_icon(
self, tenant_id: str, provider: str, icon_type: str, lang: str
) -> tuple[Optional[bytes], Optional[str]]:
) -> tuple[bytes | None, str | None]:
"""
get model provider icon.
+3 -3
View File
@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from core.ops.entities.config_entity import BaseTracingConfig
from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
@@ -15,7 +15,7 @@ class OpsService:
:param tracing_provider: tracing provider
:return:
"""
trace_config_data: Optional[TraceAppConfig] = (
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
@@ -153,7 +153,7 @@ class OpsService:
project_url = None
# check if trace config already exists
trace_config_data: Optional[TraceAppConfig] = (
trace_config_data: TraceAppConfig | None = (
db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
+2 -2
View File
@@ -5,7 +5,7 @@ import time
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional
from typing import Any
from uuid import uuid4
import click
@@ -281,7 +281,7 @@ class PluginMigration:
return result
@classmethod
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> Optional[str]:
def _fetch_plugin_unique_identifier(cls, plugin_id: str) -> str | None:
"""
Fetch plugin unique identifier using plugin id.
"""
+4 -5
View File
@@ -1,7 +1,6 @@
import logging
from collections.abc import Mapping, Sequence
from mimetypes import guess_type
from typing import Optional
from pydantic import BaseModel
@@ -46,11 +45,11 @@ class PluginService:
REDIS_TTL = 60 * 5 # 5 minutes
@staticmethod
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
def fetch_latest_plugin_version(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
Fetch the latest plugin version
"""
result: dict[str, Optional[PluginService.LatestPluginCache]] = {}
result: dict[str, PluginService.LatestPluginCache | None] = {}
try:
cache_not_exists = []
@@ -109,7 +108,7 @@ class PluginService:
raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
@staticmethod
def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]):
def _check_plugin_installation_scope(plugin_verification: PluginVerification | None):
"""
Check the plugin installation scope
"""
@@ -144,7 +143,7 @@ class PluginService:
return manager.get_debugging_key(tenant_id)
@staticmethod
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, Optional[LatestPluginCache]]:
def list_latest_versions(plugin_ids: Sequence[str]) -> Mapping[str, LatestPluginCache | None]:
"""
List the latest versions of the plugins
"""
@@ -1,7 +1,6 @@
import json
from os import path
from pathlib import Path
from typing import Optional
from flask import current_app
@@ -14,7 +13,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
Retrieval recommended app from buildin, the location is constants/recommended_apps.json
"""
builtin_data: Optional[dict] = None
builtin_data: dict | None = None
def get_type(self) -> str:
return RecommendAppType.BUILDIN
@@ -54,7 +53,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return builtin_data.get("recommended_apps", {}).get(language, {})
@classmethod
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_builtin(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from builtin.
:param app_id: App ID
@@ -1,5 +1,3 @@
from typing import Optional
from sqlalchemy import select
from constants.languages import languages
@@ -72,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
@classmethod
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from db.
:param app_id: App ID
@@ -1,5 +1,4 @@
import logging
from typing import Optional
import requests
@@ -36,7 +35,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.REMOTE
@classmethod
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> Optional[dict]:
def fetch_recommended_app_detail_from_dify_official(cls, app_id: str) -> dict | None:
"""
Fetch recommended app detail from dify official.
:param app_id: App ID
+1 -3
View File
@@ -1,5 +1,3 @@
from typing import Optional
from configs import dify_config
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@@ -25,7 +23,7 @@ class RecommendedAppService:
return result
@classmethod
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
def get_recommend_app_detail(cls, app_id: str) -> dict | None:
"""
Get recommend app detail.
:param app_id: app id
+4 -4
View File
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -11,7 +11,7 @@ from services.message_service import MessageService
class SavedMessageService:
@classmethod
def pagination_by_last_id(
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
cls, app_model: App, user: Union[Account, EndUser] | None, last_id: str | None, limit: int
) -> InfiniteScrollPagination:
if not user:
raise ValueError("User is required")
@@ -32,7 +32,7 @@ class SavedMessageService:
)
@classmethod
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def save(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (
@@ -62,7 +62,7 @@ class SavedMessageService:
db.session.commit()
@classmethod
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
def delete(cls, app_model: App, user: Union[Account, EndUser] | None, message_id: str):
if not user:
return
saved_message = (
+1 -2
View File
@@ -1,5 +1,4 @@
import uuid
from typing import Optional
from flask_login import current_user
from sqlalchemy import func, select
@@ -12,7 +11,7 @@ from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None):
def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
@@ -3,7 +3,7 @@ import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
@@ -604,7 +604,7 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
@@ -665,8 +665,8 @@ class BuiltinToolManageService:
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
client_params: dict | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""
setup oauth custom client
+1 -2
View File
@@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
@@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
"""
list tool providers
@@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from yarl import URL
@@ -94,7 +94,7 @@ class ToolTransformService:
def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
db_provider: Optional[BuiltinToolProvider],
db_provider: BuiltinToolProvider | None,
decrypt_credentials: bool = True,
) -> ToolProviderApiEntity:
"""
+2 -3
View File
@@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -19,7 +18,7 @@ logger = logging.getLogger(__name__)
class VectorService:
@classmethod
def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents: list[Document] = []
@@ -79,7 +78,7 @@ class VectorService:
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
@classmethod
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset):
# update segment index task
# format new index
+6 -6
View File
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -19,11 +19,11 @@ class WebConversationService:
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
user: Union[Account, EndUser] | None,
last_id: str | None,
limit: int,
invoke_from: InvokeFrom,
pinned: Optional[bool] = None,
pinned: bool | None = None,
sort_by="-updated_at",
) -> InfiniteScrollPagination:
if not user:
@@ -60,7 +60,7 @@ class WebConversationService:
)
@classmethod
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (
@@ -92,7 +92,7 @@ class WebConversationService:
db.session.commit()
@classmethod
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
if not user:
return
pinned_conversation = (
+4 -4
View File
@@ -1,7 +1,7 @@
import enum
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any, Optional
from typing import Any
from werkzeug.exceptions import NotFound, Unauthorized
@@ -63,7 +63,7 @@ class WebAppAuthService:
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: str = "en-US"
cls, account: Account | None = None, email: str | None = None, language: str = "en-US"
):
email = account.email if account else email
if email is None:
@@ -82,7 +82,7 @@ class WebAppAuthService:
return token
@classmethod
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
def get_email_code_login_data(cls, token: str) -> dict[str, Any] | None:
return TokenManager.get_token_data(token, "email_code_login")
@classmethod
@@ -130,7 +130,7 @@ class WebAppAuthService:
@classmethod
def is_app_require_permission_check(
cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None
cls, app_code: str | None = None, app_id: str | None = None, access_mode: str | None = None
) -> bool:
"""
Check if the app requires permission check based on its access mode.
+4 -4
View File
@@ -1,7 +1,7 @@
import datetime
import json
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any
import requests
from flask_login import current_user
@@ -21,9 +21,9 @@ class CrawlOptions:
limit: int = 1
crawl_sub_pages: bool = False
only_main_content: bool = False
includes: Optional[str] = None
excludes: Optional[str] = None
max_depth: Optional[int] = None
includes: str | None = None
excludes: str | None = None
max_depth: int | None = None
use_sitemap: bool = True
def get_include_paths(self) -> list[str]:
+4 -4
View File
@@ -1,5 +1,5 @@
import json
from typing import Any, Optional
from typing import Any
from core.app.app_config.entities import (
DatasetEntity,
@@ -327,7 +327,7 @@ class WorkflowConverter:
def _convert_to_knowledge_retrieval_node(
self, new_app_mode: AppMode, dataset_config: DatasetEntity, model_config: ModelConfigEntity
) -> Optional[dict]:
) -> dict | None:
"""
Convert datasets to Knowledge Retrieval Node
:param new_app_mode: new app mode
@@ -383,7 +383,7 @@ class WorkflowConverter:
graph: dict,
model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity,
file_upload: Optional[FileUploadConfig] = None,
file_upload: FileUploadConfig | None = None,
external_data_variable_node_mapping: dict[str, str] | None = None,
):
"""
@@ -403,7 +403,7 @@ class WorkflowConverter:
)
role_prefix = None
prompts: Optional[Any] = None
prompts: Any | None = None
# Chat Model
if model_config.mode == LLMMode.CHAT.value:
+1 -2
View File
@@ -1,6 +1,5 @@
import threading
from collections.abc import Sequence
from typing import Optional
from sqlalchemy.orm import sessionmaker
@@ -80,7 +79,7 @@ class WorkflowRunService:
last_id=last_id,
)
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun | None:
"""
Get workflow run detail
+7 -7
View File
@@ -2,7 +2,7 @@ import json
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from uuid import uuid4
from sqlalchemy import exists, select
@@ -88,7 +88,7 @@ class WorkflowService:
)
return db.session.execute(stmt).scalar_one()
def get_draft_workflow(self, app_model: App, workflow_id: Optional[str] = None) -> Optional[Workflow]:
def get_draft_workflow(self, app_model: App, workflow_id: str | None = None) -> Workflow | None:
"""
Get draft workflow
"""
@@ -108,7 +108,7 @@ class WorkflowService:
# return draft workflow
return workflow
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Workflow | None:
"""
fetch published workflow by workflow_id
"""
@@ -130,7 +130,7 @@ class WorkflowService:
)
return workflow
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
def get_published_workflow(self, app_model: App) -> Workflow | None:
"""
Get published workflow
"""
@@ -195,7 +195,7 @@ class WorkflowService:
app_model: App,
graph: dict,
features: dict,
unique_hash: Optional[str],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
@@ -561,7 +561,7 @@ class WorkflowService:
return default_block_configs
def get_default_block_config(self, node_type: str, filters: Optional[dict] = None) -> Optional[dict]:
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> dict | None:
"""
Get default config of node.
:param node_type: node type
@@ -857,7 +857,7 @@ class WorkflowService:
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
) -> Workflow | None:
"""
Update workflow attributes