remove bare list, dict, Sequence, None, Any (#25058)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Asuka Minato
2025-09-06 04:32:23 +09:00
committed by GitHub
parent 2b0695bdde
commit a78339a040
306 changed files with 787 additions and 817 deletions
+1 -1
View File
@@ -11,5 +11,5 @@ class RemoteSettingsSource:
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError raise NotImplementedError
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool):
return value return value
@@ -33,7 +33,7 @@ class NacosSettingsSource(RemoteSettingsSource):
logger.exception("[get-access-token] exception occurred") logger.exception("[get-access-token] exception occurred")
raise raise
def _parse_config(self, content: str) -> dict: def _parse_config(self, content: str):
if not content: if not content:
return {} return {}
try: try:
+1 -1
View File
@@ -207,7 +207,7 @@ class InstructionGenerationTemplateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self) -> dict: def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, default=False, location="json") parser.add_argument("type", type=str, required=True, default=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@@ -1,5 +1,5 @@
import logging import logging
from typing import Any, NoReturn from typing import NoReturn
from flask import Response from flask import Response
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
@@ -29,7 +29,7 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any: def _convert_values_to_json_serializable_object(value: Segment):
if isinstance(value, FileSegment): if isinstance(value, FileSegment):
return value.value.model_dump() return value.value.model_dump()
elif isinstance(value, ArrayFileSegment): elif isinstance(value, ArrayFileSegment):
@@ -40,7 +40,7 @@ def _convert_values_to_json_serializable_object(value: Segment) -> Any:
return value.value return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: def _serialize_var_value(variable: WorkflowDraftVariable):
value = variable.get_value() value = variable.get_value()
# create a copy of the value to avoid affecting the model cache. # create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True) value = value.model_copy(deep=True)
+1 -1
View File
@@ -99,7 +99,7 @@ class MCPAppApi(Resource):
return mcp_server, app return mcp_server, app
def _validate_server_status(self, mcp_server: AppMCPServer) -> None: def _validate_server_status(self, mcp_server: AppMCPServer):
"""Validate MCP server status""" """Validate MCP server status"""
if mcp_server.status != AppMCPServerStatus.ACTIVE: if mcp_server.status != AppMCPServerStatus.ACTIVE:
raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active") raise MCPRequestError(mcp_types.INVALID_REQUEST, "Server is not active")
+1 -1
View File
@@ -62,7 +62,7 @@ class BaseAgentRunner(AppRunner):
model_instance: ModelInstance, model_instance: ModelInstance,
memory: Optional[TokenBufferMemory] = None, memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None, prompt_messages: Optional[list[PromptMessage]] = None,
) -> None: ):
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.conversation = conversation self.conversation = conversation
+1 -1
View File
@@ -338,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return instruction return instruction
def _init_react_state(self, query) -> None: def _init_react_state(self, query):
""" """
init agent scratchpad init agent scratchpad
""" """
+1 -1
View File
@@ -41,7 +41,7 @@ class AgentScratchpadUnit(BaseModel):
action_name: str action_name: str
action_input: Union[dict, str] action_input: Union[dict, str]
def to_dict(self) -> dict: def to_dict(self):
""" """
Convert to dictionary. Convert to dictionary.
""" """
@@ -158,7 +158,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"] return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod @classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict) -> dict: def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
""" """
Extract dataset config for legacy compatibility Extract dataset config for legacy compatibility
@@ -105,7 +105,7 @@ class ModelConfigManager:
return dict(config), ["model"] return dict(config), ["model"]
@classmethod @classmethod
def validate_model_completion_params(cls, cp: dict) -> dict: def validate_model_completion_params(cls, cp: dict):
# model.completion_params # model.completion_params
if not isinstance(cp, dict): if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type") raise ValueError("model.completion_params must be of object type")
@@ -122,7 +122,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"] return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod @classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict) -> dict: def validate_post_prompt_and_set_defaults(cls, config: dict):
""" """
Validate post_prompt and set defaults for prompt feature Validate post_prompt and set defaults for prompt feature
@@ -41,7 +41,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
""" """
Validate for advanced chat app model config Validate for advanced chat app model config
@@ -481,7 +481,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id: str, message_id: str,
context: contextvars.Context, context: contextvars.Context,
variable_loader: VariableLoader, variable_loader: VariableLoader,
) -> None: ):
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@@ -54,7 +54,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow: Workflow, workflow: Workflow,
system_user_id: str, system_user_id: str,
app: App, app: App,
) -> None: ):
super().__init__( super().__init__(
queue_manager=queue_manager, queue_manager=queue_manager,
variable_loader=variable_loader, variable_loader=variable_loader,
@@ -68,7 +68,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self.system_user_id = system_user_id self.system_user_id = system_user_id
self._app = app self._app = app
def run(self) -> None: def run(self):
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config) app_config = cast(AdvancedChatAppConfig, app_config)
@@ -221,7 +221,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
return False return False
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None: def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy):
""" """
Direct output Direct output
""" """
@@ -101,7 +101,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
) -> None: ):
self._base_task_pipeline = BasedGenerateTaskPipeline( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
@@ -289,7 +289,7 @@ class AdvancedChatAppGenerateTaskPipeline:
session.rollback() session.rollback()
raise raise
def _ensure_workflow_initialized(self) -> None: def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state.""" """Fluent validation for workflow state."""
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
@@ -888,7 +888,7 @@ class AdvancedChatAppGenerateTaskPipeline:
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None):
message = self._get_message(session=session) message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer # If there are assistant files, remove markdown image links from answer
@@ -86,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict: def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
""" """
Validate for agent chat app model config Validate for agent chat app model config
@@ -222,7 +222,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation_id: str, conversation_id: str,
message_id: str, message_id: str,
) -> None: ):
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
+1 -1
View File
@@ -35,7 +35,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
) -> None: ):
""" """
Run assistant application Run assistant application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@@ -16,7 +16,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse _blocking_response_type = ChatbotAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response
@@ -94,7 +94,7 @@ class AppGenerateResponseConverter(ABC):
return metadata return metadata
@classmethod @classmethod
def _error_to_stream_response(cls, e: Exception) -> dict: def _error_to_stream_response(cls, e: Exception):
""" """
Error to stream response. Error to stream response.
:param e: exception :param e: exception
+1 -1
View File
@@ -157,7 +157,7 @@ class BaseAppGenerator:
return value return value
def _sanitize_value(self, value: Any) -> Any: def _sanitize_value(self, value: Any):
if isinstance(value, str): if isinstance(value, str):
return value.replace("\x00", "") return value.replace("\x00", "")
return value return value
+6 -6
View File
@@ -25,7 +25,7 @@ class PublishFrom(IntEnum):
class AppQueueManager: class AppQueueManager:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom):
if not user_id: if not user_id:
raise ValueError("user is required") raise ValueError("user is required")
@@ -73,14 +73,14 @@ class AppQueueManager:
self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10 last_ping_time = elapsed_time // 10
def stop_listen(self) -> None: def stop_listen(self):
""" """
Stop listen to queue Stop listen to queue
:return: :return:
""" """
self._q.put(None) self._q.put(None)
def publish_error(self, e, pub_from: PublishFrom) -> None: def publish_error(self, e, pub_from: PublishFrom):
""" """
Publish error Publish error
:param e: error :param e: error
@@ -89,7 +89,7 @@ class AppQueueManager:
""" """
self.publish(QueueErrorEvent(error=e), pub_from) self.publish(QueueErrorEvent(error=e), pub_from)
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
""" """
Publish event to queue Publish event to queue
:param event: :param event:
@@ -100,7 +100,7 @@ class AppQueueManager:
self._publish(event, pub_from) self._publish(event, pub_from)
@abstractmethod @abstractmethod
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
""" """
Publish event to queue Publish event to queue
:param event: :param event:
@@ -110,7 +110,7 @@ class AppQueueManager:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str):
""" """
Set task stop flag Set task stop flag
:return: :return:
+4 -6
View File
@@ -162,7 +162,7 @@ class AppRunner:
text: str, text: str,
stream: bool, stream: bool,
usage: Optional[LLMUsage] = None, usage: Optional[LLMUsage] = None,
) -> None: ):
""" """
Direct output Direct output
:param queue_manager: application queue manager :param queue_manager: application queue manager
@@ -204,7 +204,7 @@ class AppRunner:
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
stream: bool, stream: bool,
agent: bool = False, agent: bool = False,
) -> None: ):
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
@@ -220,9 +220,7 @@ class AppRunner:
else: else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}") raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct( def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
) -> None:
""" """
Handle invoke result direct Handle invoke result direct
:param invoke_result: invoke result :param invoke_result: invoke result
@@ -239,7 +237,7 @@ class AppRunner:
def _handle_invoke_result_stream( def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
) -> None: ):
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result
+1 -1
View File
@@ -81,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict: def config_validate(cls, tenant_id: str, config: dict):
""" """
Validate for chat app model config Validate for chat app model config
+1 -1
View File
@@ -211,7 +211,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation_id: str, conversation_id: str,
message_id: str, message_id: str,
) -> None: ):
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
+1 -1
View File
@@ -33,7 +33,7 @@ class ChatAppRunner(AppRunner):
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
) -> None: ):
""" """
Run application Run application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@@ -16,7 +16,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = ChatbotAppBlockingResponse _blocking_response_type = ChatbotAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@@ -37,7 +37,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: ChatbotAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response
@@ -62,7 +62,7 @@ class WorkflowResponseConverter:
*, *,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser], user: Union[Account, EndUser],
) -> None: ):
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._user = user self._user = user
@@ -66,7 +66,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict) -> dict: def config_validate(cls, tenant_id: str, config: dict):
""" """
Validate for completion app model config Validate for completion app model config
@@ -192,7 +192,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
application_generate_entity: CompletionAppGenerateEntity, application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
message_id: str, message_id: str,
) -> None: ):
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
+1 -1
View File
@@ -27,7 +27,7 @@ class CompletionAppRunner(AppRunner):
def run( def run(
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
) -> None: ):
""" """
Run application Run application
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
@@ -16,7 +16,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = CompletionAppBlockingResponse _blocking_response_type = CompletionAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@@ -36,7 +36,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
return response return response
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: CompletionAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response
@@ -14,14 +14,14 @@ from core.app.entities.queue_entities import (
class MessageBasedAppQueueManager(AppQueueManager): class MessageBasedAppQueueManager(AppQueueManager):
def __init__( def __init__(
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
) -> None: ):
super().__init__(task_id, user_id, invoke_from) super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id) self._conversation_id = str(conversation_id)
self._app_mode = app_mode self._app_mode = app_mode
self._message_id = str(message_id) self._message_id = str(message_id)
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
""" """
Publish event to queue Publish event to queue
:param event: :param event:
@@ -35,7 +35,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
return app_config return app_config
@classmethod @classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict: def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False):
""" """
Validate for workflow app model config Validate for workflow app model config
+1 -1
View File
@@ -435,7 +435,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
context: contextvars.Context, context: contextvars.Context,
variable_loader: VariableLoader, variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> None: ):
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
@@ -14,12 +14,12 @@ from core.app.entities.queue_entities import (
class WorkflowAppQueueManager(AppQueueManager): class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None: def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str):
super().__init__(task_id, user_id, invoke_from) super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: def _publish(self, event: AppQueueEvent, pub_from: PublishFrom):
""" """
Publish event to queue Publish event to queue
:param event: :param event:
+2 -2
View File
@@ -34,7 +34,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow, workflow: Workflow,
system_user_id: str, system_user_id: str,
) -> None: ):
super().__init__( super().__init__(
queue_manager=queue_manager, queue_manager=queue_manager,
variable_loader=variable_loader, variable_loader=variable_loader,
@@ -45,7 +45,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._workflow = workflow self._workflow = workflow
self._sys_user_id = system_user_id self._sys_user_id = system_user_id
def run(self) -> None: def run(self):
""" """
Run application Run application
""" """
@@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse _blocking_response_type = WorkflowAppBlockingResponse
@classmethod @classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking full response. Convert blocking full response.
:param blocking_response: blocking response :param blocking_response: blocking response
@@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.to_dict()) return dict(blocking_response.to_dict())
@classmethod @classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override] def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse): # type: ignore[override]
""" """
Convert blocking simple response. Convert blocking simple response.
:param blocking_response: blocking response :param blocking_response: blocking response
@@ -92,7 +92,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory, draft_var_saver_factory: DraftVariableSaverFactory,
) -> None: ):
self._base_task_pipeline = BasedGenerateTaskPipeline( self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
@@ -263,7 +263,7 @@ class WorkflowAppGenerateTaskPipeline:
session.rollback() session.rollback()
raise raise
def _ensure_workflow_initialized(self) -> None: def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state.""" """Fluent validation for workflow state."""
if not self._workflow_run_id: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
@@ -744,7 +744,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher: if tts_publisher:
tts_publisher.publish(None) tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
invoke_from = self._application_generate_entity.invoke_from invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API: if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API created_from = WorkflowAppLogCreatedFrom.SERVICE_API
+3 -3
View File
@@ -74,7 +74,7 @@ class WorkflowBasedAppRunner:
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str, app_id: str,
) -> None: ):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._variable_loader = variable_loader self._variable_loader = variable_loader
self._app_id = app_id self._app_id = app_id
@@ -292,7 +292,7 @@ class WorkflowBasedAppRunner:
return graph, variable_pool return graph, variable_pool
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None: def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
""" """
Handle event Handle event
:param workflow_entry: workflow entry :param workflow_entry: workflow entry
@@ -694,5 +694,5 @@ class WorkflowBasedAppRunner:
) )
) )
def _publish_event(self, event: AppQueueEvent) -> None: def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
@@ -35,7 +35,7 @@ class BasedGenerateTaskPipeline:
application_generate_entity: AppGenerateEntity, application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
stream: bool, stream: bool,
) -> None: ):
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self.queue_manager = queue_manager self.queue_manager = queue_manager
self._start_at = time.perf_counter() self._start_at = time.perf_counter()
@@ -80,7 +80,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
stream: bool, stream: bool,
) -> None: ):
super().__init__( super().__init__(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
@@ -362,7 +362,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None: def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None):
""" """
Save message. Save message.
:return: :return:
@@ -412,7 +412,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
) )
def _handle_stop(self, event: QueueStopEvent) -> None: def _handle_stop(self, event: QueueStopEvent):
""" """
Handle stop. Handle stop.
:return: :return:
@@ -48,7 +48,7 @@ class MessageCycleManager:
AdvancedChatAppGenerateEntity, AdvancedChatAppGenerateEntity,
], ],
task_state: Union[EasyUITaskState, WorkflowTaskState], task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None: ):
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._task_state = task_state self._task_state = task_state
@@ -132,7 +132,7 @@ class MessageCycleManager:
return None return None
def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent):
""" """
Handle retriever resources. Handle retriever resources.
:param event: event :param event: event
@@ -23,7 +23,7 @@ def get_colored_text(text: str, color: str) -> str:
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None: def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None):
"""Print text with highlighting and no end characters.""" """Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file) print(text_to_print, end=end, file=file)
@@ -37,7 +37,7 @@ class DifyAgentCallbackHandler(BaseModel):
color: Optional[str] = "" color: Optional[str] = ""
current_loop: int = 1 current_loop: int = 1
def __init__(self, color: Optional[str] = None) -> None: def __init__(self, color: Optional[str] = None):
super().__init__() super().__init__()
"""Initialize callback handler.""" """Initialize callback handler."""
# use a specific color is not specified # use a specific color is not specified
@@ -48,7 +48,7 @@ class DifyAgentCallbackHandler(BaseModel):
self, self,
tool_name: str, tool_name: str,
tool_inputs: Mapping[str, Any], tool_inputs: Mapping[str, Any],
) -> None: ):
"""Do nothing.""" """Do nothing."""
if dify_config.DEBUG: if dify_config.DEBUG:
print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@@ -61,7 +61,7 @@ class DifyAgentCallbackHandler(BaseModel):
message_id: Optional[str] = None, message_id: Optional[str] = None,
timer: Optional[Any] = None, timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None, trace_manager: Optional[TraceQueueManager] = None,
) -> None: ):
"""If not the final action, print out observation.""" """If not the final action, print out observation."""
if dify_config.DEBUG: if dify_config.DEBUG:
print_text("\n[on_tool_end]\n", color=self.color) print_text("\n[on_tool_end]\n", color=self.color)
@@ -82,12 +82,12 @@ class DifyAgentCallbackHandler(BaseModel):
) )
) )
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None: def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):
"""Do nothing.""" """Do nothing."""
if dify_config.DEBUG: if dify_config.DEBUG:
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red") print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
def on_agent_start(self, thought: str) -> None: def on_agent_start(self, thought: str):
"""Run on agent start.""" """Run on agent start."""
if dify_config.DEBUG: if dify_config.DEBUG:
if thought: if thought:
@@ -98,7 +98,7 @@ class DifyAgentCallbackHandler(BaseModel):
else: else:
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None: def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any):
"""Run on agent end.""" """Run on agent end."""
if dify_config.DEBUG: if dify_config.DEBUG:
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color) print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
@@ -21,14 +21,14 @@ class DatasetIndexToolCallbackHandler:
def __init__( def __init__(
self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
) -> None: ):
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._app_id = app_id self._app_id = app_id
self._message_id = message_id self._message_id = message_id
self._user_id = user_id self._user_id = user_id
self._invoke_from = invoke_from self._invoke_from = invoke_from
def on_query(self, query: str, dataset_id: str) -> None: def on_query(self, query: str, dataset_id: str):
""" """
Handle query. Handle query.
""" """
@@ -46,7 +46,7 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_query) db.session.add(dataset_query)
db.session.commit() db.session.commit()
def on_tool_end(self, documents: list[Document]) -> None: def on_tool_end(self, documents: list[Document]):
"""Handle tool end.""" """Handle tool end."""
for document in documents: for document in documents:
if document.metadata is not None: if document.metadata is not None:
+2 -2
View File
@@ -33,7 +33,7 @@ class SimpleModelProviderEntity(BaseModel):
icon_large: Optional[I18nObject] = None icon_large: Optional[I18nObject] = None
supported_model_types: list[ModelType] supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity) -> None: def __init__(self, provider_entity: ProviderEntity):
""" """
Init simple provider. Init simple provider.
@@ -57,7 +57,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
load_balancing_enabled: bool = False load_balancing_enabled: bool = False
has_invalid_load_balancing_configs: bool = False has_invalid_load_balancing_configs: bool = False
def raise_for_status(self) -> None: def raise_for_status(self):
""" """
Check model status and raise ValueError if not active. Check model status and raise ValueError if not active.
+16 -18
View File
@@ -280,9 +280,7 @@ class ProviderConfiguration(BaseModel):
else [], else [],
) )
def validate_provider_credentials( def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
self, credentials: dict, credential_id: str = "", session: Session | None = None
) -> dict:
""" """
Validate custom credentials. Validate custom credentials.
:param credentials: provider credentials :param credentials: provider credentials
@@ -291,7 +289,7 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
def _validate(s: Session) -> dict: def _validate(s: Session):
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas self.provider.provider_credential_schema.credential_form_schemas
@@ -402,7 +400,7 @@ class ProviderConfiguration(BaseModel):
logger.warning("Error generating next credential name: %s", str(e)) logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1" return "API KEY 1"
def create_provider_credential(self, credentials: dict, credential_name: str | None) -> None: def create_provider_credential(self, credentials: dict, credential_name: str | None):
""" """
Add custom provider credentials. Add custom provider credentials.
:param credentials: provider credentials :param credentials: provider credentials
@@ -458,7 +456,7 @@ class ProviderConfiguration(BaseModel):
credentials: dict, credentials: dict,
credential_id: str, credential_id: str,
credential_name: str | None, credential_name: str | None,
) -> None: ):
""" """
update a saved provider credential (by credential_id). update a saved provider credential (by credential_id).
@@ -519,7 +517,7 @@ class ProviderConfiguration(BaseModel):
credential_record: ProviderCredential | ProviderModelCredential, credential_record: ProviderCredential | ProviderModelCredential,
credential_source: str, credential_source: str,
session: Session, session: Session,
) -> None: ):
""" """
Update load balancing configurations that reference the given credential_id. Update load balancing configurations that reference the given credential_id.
@@ -559,7 +557,7 @@ class ProviderConfiguration(BaseModel):
session.commit() session.commit()
def delete_provider_credential(self, credential_id: str) -> None: def delete_provider_credential(self, credential_id: str):
""" """
Delete a saved provider credential (by credential_id). Delete a saved provider credential (by credential_id).
@@ -636,7 +634,7 @@ class ProviderConfiguration(BaseModel):
session.rollback() session.rollback()
raise raise
def switch_active_provider_credential(self, credential_id: str) -> None: def switch_active_provider_credential(self, credential_id: str):
""" """
Switch active provider credential (copy the selected one into current active snapshot). Switch active provider credential (copy the selected one into current active snapshot).
@@ -814,7 +812,7 @@ class ProviderConfiguration(BaseModel):
credentials: dict, credentials: dict,
credential_id: str = "", credential_id: str = "",
session: Session | None = None, session: Session | None = None,
) -> dict: ):
""" """
Validate custom model credentials. Validate custom model credentials.
@@ -825,7 +823,7 @@ class ProviderConfiguration(BaseModel):
:return: :return:
""" """
def _validate(s: Session) -> dict: def _validate(s: Session):
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables( provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas self.provider.model_credential_schema.credential_form_schemas
@@ -1009,7 +1007,7 @@ class ProviderConfiguration(BaseModel):
session.rollback() session.rollback()
raise raise
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
""" """
Delete a saved provider credential (by credential_id). Delete a saved provider credential (by credential_id).
@@ -1079,7 +1077,7 @@ class ProviderConfiguration(BaseModel):
session.rollback() session.rollback()
raise raise
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None: def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
""" """
if model list exist this custom model, switch the custom model credential. if model list exist this custom model, switch the custom model credential.
if model list not exist this custom model, use the credential to add a new custom model record. if model list not exist this custom model, use the credential to add a new custom model record.
@@ -1122,7 +1120,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record) session.add(provider_model_record)
session.commit() session.commit()
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None: def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
""" """
switch the custom model credential. switch the custom model credential.
@@ -1152,7 +1150,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record) session.add(provider_model_record)
session.commit() session.commit()
def delete_custom_model(self, model_type: ModelType, model: str) -> None: def delete_custom_model(self, model_type: ModelType, model: str):
""" """
Delete custom model. Delete custom model.
:param model_type: model type :param model_type: model type
@@ -1347,7 +1345,7 @@ class ProviderConfiguration(BaseModel):
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
) )
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None: def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
""" """
Switch preferred provider type. Switch preferred provider type.
:param provider_type: :param provider_type:
@@ -1359,7 +1357,7 @@ class ProviderConfiguration(BaseModel):
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
return return
def _switch(s: Session) -> None: def _switch(s: Session):
# get preferred provider # get preferred provider
model_provider_id = ModelProviderID(self.provider.provider) model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider] provider_names = [self.provider.provider]
@@ -1403,7 +1401,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
""" """
Obfuscated credentials. Obfuscated credentials.
+1 -1
View File
@@ -6,7 +6,7 @@ class LLMError(ValueError):
description: Optional[str] = None description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None: def __init__(self, description: Optional[str] = None):
self.description = description self.description = description
@@ -10,11 +10,11 @@ class APIBasedExtensionRequestor:
timeout: tuple[int, int] = (5, 60) timeout: tuple[int, int] = (5, 60)
"""timeout for request connect and read""" """timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str) -> None: def __init__(self, api_endpoint: str, api_key: str):
self.api_endpoint = api_endpoint self.api_endpoint = api_endpoint
self.api_key = api_key self.api_key = api_key
def request(self, point: APIBasedExtensionPoint, params: dict) -> dict: def request(self, point: APIBasedExtensionPoint, params: dict):
""" """
Request the api. Request the api.
+1 -1
View File
@@ -34,7 +34,7 @@ class Extensible:
tenant_id: str tenant_id: str
config: Optional[dict] = None config: Optional[dict] = None
def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: def __init__(self, tenant_id: str, config: Optional[dict] = None):
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.config = config self.config = config
+1 -1
View File
@@ -18,7 +18,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool""" """the unique name of external data tool"""
@classmethod @classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None: def validate_config(cls, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.
+2 -2
View File
@@ -16,14 +16,14 @@ class ExternalDataTool(Extensible, ABC):
variable: str variable: str
"""the tool variable name of app tool""" """the tool variable name of app tool"""
def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None: def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None):
super().__init__(tenant_id, config) super().__init__(tenant_id, config)
self.app_id = app_id self.app_id = app_id
self.variable = variable self.variable = variable
@classmethod @classmethod
@abstractmethod @abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None: def validate_config(cls, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.
+2 -2
View File
@@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory: class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None: def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name) extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class( self.__extension_instance = extension_class(
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
) )
@classmethod @classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: def validate_config(cls, name: str, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.
+1 -1
View File
@@ -7,6 +7,6 @@ if TYPE_CHECKING:
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]):
global _tool_file_manager_factory global _tool_file_manager_factory
_tool_file_manager_factory = factory _tool_file_manager_factory = factory
@@ -22,7 +22,7 @@ class CodeNodeProvider(BaseModel):
pass pass
@classmethod @classmethod
def get_default_config(cls) -> dict: def get_default_config(cls):
return { return {
"type": "code", "type": "code",
"config": { "config": {
@@ -5,7 +5,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
class Jinja2TemplateTransformer(TemplateTransformer): class Jinja2TemplateTransformer(TemplateTransformer):
@classmethod @classmethod
def transform_response(cls, response: str) -> dict: def transform_response(cls, response: str):
""" """
Transform response to dict Transform response to dict
:param response: response :param response: response
@@ -13,7 +13,7 @@ class Python3CodeProvider(CodeNodeProvider):
def get_default_code(cls) -> str: def get_default_code(cls) -> str:
return dedent( return dedent(
""" """
def main(arg1: str, arg2: str) -> dict: def main(arg1: str, arg2: str):
return { return {
"result": arg1 + arg2, "result": arg1 + arg2,
} }
+2 -2
View File
@@ -34,7 +34,7 @@ class ProviderCredentialsCache:
else: else:
return None return None
def set(self, credentials: dict) -> None: def set(self, credentials: dict):
""" """
Cache model provider credentials. Cache model provider credentials.
@@ -43,7 +43,7 @@ class ProviderCredentialsCache:
""" """
redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None: def delete(self):
""" """
Delete cached model provider credentials. Delete cached model provider credentials.
+4 -4
View File
@@ -28,11 +28,11 @@ class ProviderCredentialsCache(ABC):
return None return None
return None return None
def set(self, config: dict[str, Any]) -> None: def set(self, config: dict[str, Any]):
"""Cache provider credentials""" """Cache provider credentials"""
redis_client.setex(self.cache_key, 86400, json.dumps(config)) redis_client.setex(self.cache_key, 86400, json.dumps(config))
def delete(self) -> None: def delete(self):
"""Delete cached provider credentials""" """Delete cached provider credentials"""
redis_client.delete(self.cache_key) redis_client.delete(self.cache_key)
@@ -75,10 +75,10 @@ class NoOpProviderCredentialCache:
"""Get cached provider credentials""" """Get cached provider credentials"""
return None return None
def set(self, config: dict[str, Any]) -> None: def set(self, config: dict[str, Any]):
"""Cache provider credentials""" """Cache provider credentials"""
pass pass
def delete(self) -> None: def delete(self):
"""Delete cached provider credentials""" """Delete cached provider credentials"""
pass pass
+2 -2
View File
@@ -37,11 +37,11 @@ class ToolParameterCache:
else: else:
return None return None
def set(self, parameters: dict) -> None: def set(self, parameters: dict):
"""Cache model provider credentials.""" """Cache model provider credentials."""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
def delete(self) -> None: def delete(self):
""" """
Delete cached model provider credentials. Delete cached model provider credentials.
+1 -1
View File
@@ -49,7 +49,7 @@ def get_external_trace_id(request: Any) -> Optional[str]:
return None return None
def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict: def extract_external_trace_id_from_args(args: Mapping[str, Any]):
""" """
Extract 'external_trace_id' from args. Extract 'external_trace_id' from args.
+2 -2
View File
@@ -44,11 +44,11 @@ class HostingConfiguration:
provider_map: dict[str, HostingProvider] provider_map: dict[str, HostingProvider]
moderation_config: Optional[HostedModerationConfig] = None moderation_config: Optional[HostedModerationConfig] = None
def __init__(self) -> None: def __init__(self):
self.provider_map = {} self.provider_map = {}
self.moderation_config = None self.moderation_config = None
def init_app(self, app: Flask) -> None: def init_app(self, app: Flask):
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
return return
+3 -3
View File
@@ -512,7 +512,7 @@ class IndexingRunner:
dataset: Dataset, dataset: Dataset,
dataset_document: DatasetDocument, dataset_document: DatasetDocument,
documents: list[Document], documents: list[Document],
) -> None: ):
""" """
insert index and update document/segment status to completed insert index and update document/segment status to completed
""" """
@@ -651,7 +651,7 @@ class IndexingRunner:
@staticmethod @staticmethod
def _update_document_index_status( def _update_document_index_status(
document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
) -> None: ):
""" """
Update the document indexing status. Update the document indexing status.
""" """
@@ -670,7 +670,7 @@ class IndexingRunner:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None: def _update_segments_by_document(dataset_document_id: str, update_params: dict):
""" """
Update the document segment by document id. Update the document segment by document id.
""" """
+6 -8
View File
@@ -127,7 +127,7 @@ class LLMGenerator:
return questions return questions
@classmethod @classmethod
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict: def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool):
output_parser = RuleConfigGeneratorOutputParser() output_parser = RuleConfigGeneratorOutputParser()
error = "" error = ""
@@ -262,9 +262,7 @@ class LLMGenerator:
return rule_config return rule_config
@classmethod @classmethod
def generate_code( def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"
) -> dict:
if code_language == "python": if code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE) prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else: else:
@@ -373,7 +371,7 @@ class LLMGenerator:
@staticmethod @staticmethod
def instruction_modify_legacy( def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
) -> dict: ):
last_run: Message | None = ( last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
) )
@@ -413,7 +411,7 @@ class LLMGenerator:
instruction: str, instruction: str,
model_config: dict, model_config: dict,
ideal_output: str | None, ideal_output: str | None,
) -> dict: ):
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
app: App | None = db.session.query(App).where(App.id == flow_id).first() app: App | None = db.session.query(App).where(App.id == flow_id).first()
@@ -451,7 +449,7 @@ class LLMGenerator:
return [] return []
parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log) parsed: Sequence[AgentLogEvent] = json.loads(raw_agent_log)
def dict_of_event(event: AgentLogEvent) -> dict: def dict_of_event(event: AgentLogEvent):
return { return {
"status": event.status, "status": event.status,
"error": event.error, "error": event.error,
@@ -488,7 +486,7 @@ class LLMGenerator:
instruction: str, instruction: str,
node_type: str, node_type: str,
ideal_output: str | None, ideal_output: str | None,
) -> dict: ):
LAST_RUN = "{{#last_run#}}" LAST_RUN = "{{#last_run#}}"
CURRENT = "{{#current#}}" CURRENT = "{{#current#}}"
ERROR_MESSAGE = "{{#error_message#}}" ERROR_MESSAGE = "{{#error_message#}}"
@@ -1,5 +1,3 @@
from typing import Any
from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.prompts import ( from core.llm_generator.prompts import (
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
@@ -17,7 +15,7 @@ class RuleConfigGeneratorOutputParser:
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
) )
def parse(self, text: str) -> Any: def parse(self, text: str):
try: try:
expected_keys = ["prompt", "variables", "opening_statement"] expected_keys = ["prompt", "variables", "opening_statement"]
parsed = parse_and_check_json_markdown(text, expected_keys) parsed = parse_and_check_json_markdown(text, expected_keys)
@@ -210,7 +210,7 @@ def _handle_native_json_schema(
structured_output_schema: Mapping, structured_output_schema: Mapping,
model_parameters: dict, model_parameters: dict,
rules: list[ParameterRule], rules: list[ParameterRule],
) -> dict: ):
""" """
Handle structured output for models with native JSON schema support. Handle structured output for models with native JSON schema support.
@@ -232,7 +232,7 @@ def _handle_native_json_schema(
return model_parameters return model_parameters
def _set_response_format(model_parameters: dict, rules: list) -> None: def _set_response_format(model_parameters: dict, rules: list):
""" """
Set the appropriate response format parameter based on model rules. Set the appropriate response format parameter based on model rules.
@@ -306,7 +306,7 @@ def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
return structured_output return structured_output
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict: def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping):
""" """
Prepare JSON schema based on model requirements. Prepare JSON schema based on model requirements.
@@ -334,7 +334,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
return {"schema": processed_schema, "name": "llm_response"} return {"schema": processed_schema, "name": "llm_response"}
def remove_additional_properties(schema: dict) -> None: def remove_additional_properties(schema: dict):
""" """
Remove additionalProperties fields from JSON schema. Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property. Used for models like Gemini that don't support this property.
@@ -357,7 +357,7 @@ def remove_additional_properties(schema: dict) -> None:
remove_additional_properties(item) remove_additional_properties(item)
def convert_boolean_to_string(schema: dict) -> None: def convert_boolean_to_string(schema: dict):
""" """
Convert boolean type specifications to string in JSON schema. Convert boolean type specifications to string in JSON schema.
@@ -1,6 +1,5 @@
import json import json
import re import re
from typing import Any
from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT from core.llm_generator.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
@@ -9,7 +8,7 @@ class SuggestedQuestionsAfterAnswerOutputParser:
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT return SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT
def parse(self, text: str) -> Any: def parse(self, text: str):
action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL) action_match = re.search(r"\[.*?\]", text.strip(), re.DOTALL)
if action_match is not None: if action_match is not None:
json_obj = json.loads(action_match.group(0).strip()) json_obj = json.loads(action_match.group(0).strip())
+3 -3
View File
@@ -44,7 +44,7 @@ class OAuthClientProvider:
return None return None
return OAuthClientInformation.model_validate(client_information) return OAuthClientInformation.model_validate(client_information)
def save_client_information(self, client_information: OAuthClientInformationFull) -> None: def save_client_information(self, client_information: OAuthClientInformationFull):
"""Saves client information after dynamic registration.""" """Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials( MCPToolManageService.update_mcp_provider_credentials(
self.mcp_provider, self.mcp_provider,
@@ -63,13 +63,13 @@ class OAuthClientProvider:
refresh_token=credentials.get("refresh_token", ""), refresh_token=credentials.get("refresh_token", ""),
) )
def save_tokens(self, tokens: OAuthTokens) -> None: def save_tokens(self, tokens: OAuthTokens):
"""Stores new OAuth tokens for the current session.""" """Stores new OAuth tokens for the current session."""
# update mcp provider credentials # update mcp provider credentials
token_dict = tokens.model_dump() token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True) MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str) -> None: def save_code_verifier(self, code_verifier: str):
"""Saves a PKCE code verifier for the current session.""" """Saves a PKCE code verifier for the current session."""
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier}) MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
+8 -8
View File
@@ -47,7 +47,7 @@ class SSETransport:
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float = 5.0, timeout: float = 5.0,
sse_read_timeout: float = 5 * 60, sse_read_timeout: float = 5 * 60,
) -> None: ):
"""Initialize the SSE transport. """Initialize the SSE transport.
Args: Args:
@@ -76,7 +76,7 @@ class SSETransport:
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None: def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue):
"""Handle an 'endpoint' SSE event. """Handle an 'endpoint' SSE event.
Args: Args:
@@ -94,7 +94,7 @@ class SSETransport:
status_queue.put(_StatusReady(endpoint_url)) status_queue.put(_StatusReady(endpoint_url))
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None: def _handle_message_event(self, sse_data: str, read_queue: ReadQueue):
"""Handle a 'message' SSE event. """Handle a 'message' SSE event.
Args: Args:
@@ -110,7 +110,7 @@ class SSETransport:
logger.exception("Error parsing server message") logger.exception("Error parsing server message")
read_queue.put(exc) read_queue.put(exc)
def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue) -> None: def _handle_sse_event(self, sse: ServerSentEvent, read_queue: ReadQueue, status_queue: StatusQueue):
"""Handle a single SSE event. """Handle a single SSE event.
Args: Args:
@@ -126,7 +126,7 @@ class SSETransport:
case _: case _:
logger.warning("Unknown SSE event: %s", sse.event) logger.warning("Unknown SSE event: %s", sse.event)
def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue) -> None: def sse_reader(self, event_source: EventSource, read_queue: ReadQueue, status_queue: StatusQueue):
"""Read and process SSE events. """Read and process SSE events.
Args: Args:
@@ -144,7 +144,7 @@ class SSETransport:
finally: finally:
read_queue.put(None) read_queue.put(None)
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None: def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage):
"""Send a single message to the server. """Send a single message to the server.
Args: Args:
@@ -163,7 +163,7 @@ class SSETransport:
response.raise_for_status() response.raise_for_status()
logger.debug("Client message sent successfully: %s", response.status_code) logger.debug("Client message sent successfully: %s", response.status_code)
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None: def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue):
"""Handle writing messages to the server. """Handle writing messages to the server.
Args: Args:
@@ -303,7 +303,7 @@ def sse_client(
write_queue.put(None) write_queue.put(None)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None: def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
""" """
Send a message to the server using the provided HTTP client. Send a message to the server using the provided HTTP client.
+12 -12
View File
@@ -82,7 +82,7 @@ class StreamableHTTPTransport:
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
timeout: float | timedelta = 30, timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5, sse_read_timeout: float | timedelta = 60 * 5,
) -> None: ):
"""Initialize the StreamableHTTP transport. """Initialize the StreamableHTTP transport.
Args: Args:
@@ -122,7 +122,7 @@ class StreamableHTTPTransport:
def _maybe_extract_session_id_from_response( def _maybe_extract_session_id_from_response(
self, self,
response: httpx.Response, response: httpx.Response,
) -> None: ):
"""Extract and store session ID from response headers.""" """Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID) new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id: if new_session_id:
@@ -173,7 +173,7 @@ class StreamableHTTPTransport:
self, self,
client: httpx.Client, client: httpx.Client,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
) -> None: ):
"""Handle GET stream for server-initiated messages.""" """Handle GET stream for server-initiated messages."""
try: try:
if not self.session_id: if not self.session_id:
@@ -197,7 +197,7 @@ class StreamableHTTPTransport:
except Exception as exc: except Exception as exc:
logger.debug("GET stream error (non-fatal): %s", exc) logger.debug("GET stream error (non-fatal): %s", exc)
def _handle_resumption_request(self, ctx: RequestContext) -> None: def _handle_resumption_request(self, ctx: RequestContext):
"""Handle a resumption request using GET with SSE.""" """Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers) headers = self._update_headers_with_session(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token: if ctx.metadata and ctx.metadata.resumption_token:
@@ -230,7 +230,7 @@ class StreamableHTTPTransport:
if is_complete: if is_complete:
break break
def _handle_post_request(self, ctx: RequestContext) -> None: def _handle_post_request(self, ctx: RequestContext):
"""Handle a POST request with response processing.""" """Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers) headers = self._update_headers_with_session(ctx.headers)
message = ctx.session_message.message message = ctx.session_message.message
@@ -278,7 +278,7 @@ class StreamableHTTPTransport:
self, self,
response: httpx.Response, response: httpx.Response,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
) -> None: ):
"""Handle JSON response from the server.""" """Handle JSON response from the server."""
try: try:
content = response.read() content = response.read()
@@ -288,7 +288,7 @@ class StreamableHTTPTransport:
except Exception as exc: except Exception as exc:
server_to_client_queue.put(exc) server_to_client_queue.put(exc)
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None: def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext):
"""Handle SSE response from the server.""" """Handle SSE response from the server."""
try: try:
event_source = EventSource(response) event_source = EventSource(response)
@@ -307,7 +307,7 @@ class StreamableHTTPTransport:
self, self,
content_type: str, content_type: str,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
) -> None: ):
"""Handle unexpected content type in response.""" """Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}" error_msg = f"Unexpected content type: {content_type}"
logger.error(error_msg) logger.error(error_msg)
@@ -317,7 +317,7 @@ class StreamableHTTPTransport:
self, self,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
request_id: RequestId, request_id: RequestId,
) -> None: ):
"""Send a session terminated error response.""" """Send a session terminated error response."""
jsonrpc_error = JSONRPCError( jsonrpc_error = JSONRPCError(
jsonrpc="2.0", jsonrpc="2.0",
@@ -333,7 +333,7 @@ class StreamableHTTPTransport:
client_to_server_queue: ClientToServerQueue, client_to_server_queue: ClientToServerQueue,
server_to_client_queue: ServerToClientQueue, server_to_client_queue: ServerToClientQueue,
start_get_stream: Callable[[], None], start_get_stream: Callable[[], None],
) -> None: ):
"""Handle writing requests to the server. """Handle writing requests to the server.
This method processes messages from the client_to_server_queue and sends them to the server. This method processes messages from the client_to_server_queue and sends them to the server.
@@ -379,7 +379,7 @@ class StreamableHTTPTransport:
except Exception as exc: except Exception as exc:
server_to_client_queue.put(exc) server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client) -> None: def terminate_session(self, client: httpx.Client):
"""Terminate the session by sending a DELETE request.""" """Terminate the session by sending a DELETE request."""
if not self.session_id: if not self.session_id:
return return
@@ -441,7 +441,7 @@ def streamablehttp_client(
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client: ) as client:
# Define callbacks that need access to thread pool # Define callbacks that need access to thread pool
def start_get_stream() -> None: def start_get_stream():
"""Start a worker thread to handle server-initiated messages.""" """Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue) executor.submit(transport.handle_get_stream, client, server_to_client_queue)
+14 -16
View File
@@ -76,7 +76,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
ReceiveNotificationT ReceiveNotificationT
]""", ]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None: ):
self.request_id = request_id self.request_id = request_id
self.request_meta = request_meta self.request_meta = request_meta
self.request = request self.request = request
@@ -95,7 +95,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
exc_type: type[BaseException] | None, exc_type: type[BaseException] | None,
exc_val: BaseException | None, exc_val: BaseException | None,
exc_tb: TracebackType | None, exc_tb: TracebackType | None,
) -> None: ):
"""Exit the context manager, performing cleanup and notifying completion.""" """Exit the context manager, performing cleanup and notifying completion."""
try: try:
if self._completed: if self._completed:
@@ -103,7 +103,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
finally: finally:
self._entered = False self._entered = False
def respond(self, response: SendResultT | ErrorData) -> None: def respond(self, response: SendResultT | ErrorData):
"""Send a response for this request. """Send a response for this request.
Must be called within a context manager block. Must be called within a context manager block.
@@ -119,7 +119,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._session._send_response(request_id=self.request_id, response=response) self._session._send_response(request_id=self.request_id, response=response)
def cancel(self) -> None: def cancel(self):
"""Cancel this request and mark it as completed.""" """Cancel this request and mark it as completed."""
if not self._entered: if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager") raise RuntimeError("RequestResponder must be used as a context manager")
@@ -163,7 +163,7 @@ class BaseSession(
receive_notification_type: type[ReceiveNotificationT], receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out # If none, reading will never time out
read_timeout_seconds: timedelta | None = None, read_timeout_seconds: timedelta | None = None,
) -> None: ):
self._read_stream = read_stream self._read_stream = read_stream
self._write_stream = write_stream self._write_stream = write_stream
self._response_streams = {} self._response_streams = {}
@@ -183,7 +183,7 @@ class BaseSession(
self._receiver_future = self._executor.submit(self._receive_loop) self._receiver_future = self._executor.submit(self._receive_loop)
return self return self
def check_receiver_status(self) -> None: def check_receiver_status(self):
"""`check_receiver_status` ensures that any exceptions raised during the """`check_receiver_status` ensures that any exceptions raised during the
execution of `_receive_loop` are retrieved and propagated.""" execution of `_receive_loop` are retrieved and propagated."""
if self._receiver_future and self._receiver_future.done(): if self._receiver_future and self._receiver_future.done():
@@ -191,7 +191,7 @@ class BaseSession(
def __exit__( def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None: ):
self._read_stream.put(None) self._read_stream.put(None)
self._write_stream.put(None) self._write_stream.put(None)
@@ -277,7 +277,7 @@ class BaseSession(
self, self,
notification: SendNotificationT, notification: SendNotificationT,
related_request_id: RequestId | None = None, related_request_id: RequestId | None = None,
) -> None: ):
""" """
Emits a notification, which is a one-way message that does not expect Emits a notification, which is a one-way message that does not expect
a response. a response.
@@ -296,7 +296,7 @@ class BaseSession(
) )
self._write_stream.put(session_message) self._write_stream.put(session_message)
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
if isinstance(response, ErrorData): if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
@@ -310,7 +310,7 @@ class BaseSession(
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
self._write_stream.put(session_message) self._write_stream.put(session_message)
def _receive_loop(self) -> None: def _receive_loop(self):
""" """
Main message processing loop. Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread. In a real synchronous implementation, this would likely run in a separate thread.
@@ -382,7 +382,7 @@ class BaseSession(
logger.exception("Error in message processing loop") logger.exception("Error in message processing loop")
raise raise
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]):
""" """
Can be overridden by subclasses to handle a request without needing to Can be overridden by subclasses to handle a request without needing to
listen on the message stream. listen on the message stream.
@@ -391,15 +391,13 @@ class BaseSession(
forwarded on to the message stream. forwarded on to the message stream.
""" """
def _received_notification(self, notification: ReceiveNotificationT) -> None: def _received_notification(self, notification: ReceiveNotificationT):
""" """
Can be overridden by subclasses to handle a notification without needing Can be overridden by subclasses to handle a notification without needing
to listen on the message stream. to listen on the message stream.
""" """
def send_progress_notification( def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
""" """
Sends a progress notification for a request that is currently being Sends a progress notification for a request that is currently being
processed. processed.
@@ -408,5 +406,5 @@ class BaseSession(
def _handle_incoming( def _handle_incoming(
self, self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None: ):
"""A generic handler for incoming messages. Overwritten by subclasses.""" """A generic handler for incoming messages. Overwritten by subclasses."""
+10 -12
View File
@@ -28,19 +28,19 @@ class LoggingFnT(Protocol):
def __call__( def __call__(
self, self,
params: types.LoggingMessageNotificationParams, params: types.LoggingMessageNotificationParams,
) -> None: ... ): ...
class MessageHandlerFnT(Protocol): class MessageHandlerFnT(Protocol):
def __call__( def __call__(
self, self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ... ): ...
def _default_message_handler( def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ):
if isinstance(message, Exception): if isinstance(message, Exception):
raise ValueError(str(message)) raise ValueError(str(message))
elif isinstance(message, (types.ServerNotification | RequestResponder)): elif isinstance(message, (types.ServerNotification | RequestResponder)):
@@ -68,7 +68,7 @@ def _default_list_roots_callback(
def _default_logging_callback( def _default_logging_callback(
params: types.LoggingMessageNotificationParams, params: types.LoggingMessageNotificationParams,
) -> None: ):
pass pass
@@ -94,7 +94,7 @@ class ClientSession(
logging_callback: LoggingFnT | None = None, logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None, message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None, client_info: types.Implementation | None = None,
) -> None: ):
super().__init__( super().__init__(
read_stream, read_stream,
write_stream, write_stream,
@@ -155,9 +155,7 @@ class ClientSession(
types.EmptyResult, types.EmptyResult,
) )
def send_progress_notification( def send_progress_notification(self, progress_token: str | int, progress: float, total: float | None = None):
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification.""" """Send a progress notification."""
self.send_notification( self.send_notification(
types.ClientNotification( types.ClientNotification(
@@ -314,7 +312,7 @@ class ClientSession(
types.ListToolsResult, types.ListToolsResult,
) )
def send_roots_list_changed(self) -> None: def send_roots_list_changed(self):
"""Send a roots/list_changed notification.""" """Send a roots/list_changed notification."""
self.send_notification( self.send_notification(
types.ClientNotification( types.ClientNotification(
@@ -324,7 +322,7 @@ class ClientSession(
) )
) )
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]):
ctx = RequestContext[ClientSession, Any]( ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id, request_id=responder.request_id,
meta=responder.request_meta, meta=responder.request_meta,
@@ -352,11 +350,11 @@ class ClientSession(
def _handle_incoming( def _handle_incoming(
self, self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ):
"""Handle incoming messages by forwarding to the message handler.""" """Handle incoming messages by forwarding to the message handler."""
self._message_handler(req) self._message_handler(req)
def _received_notification(self, notification: types.ServerNotification) -> None: def _received_notification(self, notification: types.ServerNotification):
"""Handle notifications from the server.""" """Handle notifications from the server."""
# Process specific notification types # Process specific notification types
match notification.root: match notification.root:
+1 -1
View File
@@ -27,7 +27,7 @@ class TokenBufferMemory:
self, self,
conversation: Conversation, conversation: Conversation,
model_instance: ModelInstance, model_instance: ModelInstance,
) -> None: ):
self.conversation = conversation self.conversation = conversation
self.model_instance = model_instance self.model_instance = model_instance
+7 -7
View File
@@ -32,7 +32,7 @@ class ModelInstance:
Model instance class Model instance class
""" """
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None: def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
self.provider_model_bundle = provider_model_bundle self.provider_model_bundle = provider_model_bundle
self.model = model self.model = model
self.provider = provider_model_bundle.configuration.provider.provider self.provider = provider_model_bundle.configuration.provider.provider
@@ -46,7 +46,7 @@ class ModelInstance:
) )
@staticmethod @staticmethod
def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str) -> dict: def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str):
""" """
Fetch credentials from provider model bundle Fetch credentials from provider model bundle
:param provider_model_bundle: provider model bundle :param provider_model_bundle: provider model bundle
@@ -342,7 +342,7 @@ class ModelInstance:
), ),
) )
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs) -> Any: def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
""" """
Round-robin invoke Round-robin invoke
:param function: function to invoke :param function: function to invoke
@@ -379,7 +379,7 @@ class ModelInstance:
except Exception as e: except Exception as e:
raise e raise e
def get_tts_voices(self, language: Optional[str] = None) -> list: def get_tts_voices(self, language: Optional[str] = None):
""" """
Invoke large language tts model voices Invoke large language tts model voices
@@ -394,7 +394,7 @@ class ModelInstance:
class ModelManager: class ModelManager:
def __init__(self) -> None: def __init__(self):
self._provider_manager = ProviderManager() self._provider_manager = ProviderManager()
def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
@@ -453,7 +453,7 @@ class LBModelManager:
model: str, model: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration], load_balancing_configs: list[ModelLoadBalancingConfiguration],
managed_credentials: Optional[dict] = None, managed_credentials: Optional[dict] = None,
) -> None: ):
""" """
Load balancing model manager Load balancing model manager
:param tenant_id: tenant_id :param tenant_id: tenant_id
@@ -534,7 +534,7 @@ model: %s""",
return config return config
def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60):
""" """
Cooldown model load balancing config Cooldown model load balancing config
:param config: model load balancing config :param config: model load balancing config
@@ -35,7 +35,7 @@ class Callback(ABC):
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ):
""" """
Before invoke callback Before invoke callback
@@ -94,7 +94,7 @@ class Callback(ABC):
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ):
""" """
After invoke callback After invoke callback
@@ -124,7 +124,7 @@ class Callback(ABC):
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ):
""" """
Invoke error callback Invoke error callback
@@ -141,7 +141,7 @@ class Callback(ABC):
""" """
raise NotImplementedError() raise NotImplementedError()
def print_text(self, text: str, color: Optional[str] = None, end: str = "") -> None: def print_text(self, text: str, color: Optional[str] = None, end: str = ""):
"""Print text with highlighting and no end characters.""" """Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text text_to_print = self._get_colored_text(text, color) if color else text
print(text_to_print, end=end) print(text_to_print, end=end)
@@ -24,7 +24,7 @@ class LoggingCallback(Callback):
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ):
""" """
Before invoke callback Before invoke callback
@@ -110,7 +110,7 @@ class LoggingCallback(Callback):
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ):
""" """
After invoke callback After invoke callback
@@ -151,7 +151,7 @@ class LoggingCallback(Callback):
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ):
""" """
Invoke error callback Invoke error callback
+1 -1
View File
@@ -6,7 +6,7 @@ class InvokeError(ValueError):
description: Optional[str] = None description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None: def __init__(self, description: Optional[str] = None):
self.description = description self.description = description
def __str__(self): def __str__(self):
@@ -239,7 +239,7 @@ class AIModel(BaseModel):
""" """
return None return None
def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict: def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName):
""" """
Get default parameter rule for given name Get default parameter rule for given name
@@ -408,7 +408,7 @@ class LargeLanguageModel(AIModel):
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> None: ):
""" """
Trigger before invoke callbacks Trigger before invoke callbacks
@@ -456,7 +456,7 @@ class LargeLanguageModel(AIModel):
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> None: ):
""" """
Trigger new chunk callbacks Trigger new chunk callbacks
@@ -503,7 +503,7 @@ class LargeLanguageModel(AIModel):
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> None: ):
""" """
Trigger after invoke callbacks Trigger after invoke callbacks
@@ -553,7 +553,7 @@ class LargeLanguageModel(AIModel):
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> None: ):
""" """
Trigger invoke error callbacks Trigger invoke error callbacks
@@ -28,7 +28,7 @@ class GPT2Tokenizer:
return GPT2Tokenizer._get_num_tokens_by_gpt2(text) return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
@staticmethod @staticmethod
def get_encoder() -> Any: def get_encoder():
global _tokenizer, _lock global _tokenizer, _lock
if _tokenizer is not None: if _tokenizer is not None:
return _tokenizer return _tokenizer
@@ -56,7 +56,7 @@ class TTSModel(AIModel):
except Exception as e: except Exception as e:
raise self._transform_invoke_error(e) raise self._transform_invoke_error(e)
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list[dict]: def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None):
""" """
Retrieves the list of voices supported by a given text-to-speech (TTS) model. Retrieves the list of voices supported by a given text-to-speech (TTS) model.
@@ -36,7 +36,7 @@ class ModelProviderExtension(BaseModel):
class ModelProviderFactory: class ModelProviderFactory:
provider_position_map: dict[str, int] provider_position_map: dict[str, int]
def __init__(self, tenant_id: str) -> None: def __init__(self, tenant_id: str):
self.provider_position_map = {} self.provider_position_map = {}
self.tenant_id = tenant_id self.tenant_id = tenant_id
@@ -132,7 +132,7 @@ class ModelProviderFactory:
return plugin_model_provider_entity return plugin_model_provider_entity
def provider_credentials_validate(self, *, provider: str, credentials: dict) -> dict: def provider_credentials_validate(self, *, provider: str, credentials: dict):
""" """
Validate provider credentials Validate provider credentials
@@ -163,9 +163,7 @@ class ModelProviderFactory:
return filtered_credentials return filtered_credentials
def model_credentials_validate( def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict):
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
) -> dict:
""" """
Validate model credentials Validate model credentials
@@ -6,7 +6,7 @@ from core.model_runtime.entities.provider_entities import CredentialFormSchema,
class CommonValidator: class CommonValidator:
def _validate_and_filter_credential_form_schemas( def _validate_and_filter_credential_form_schemas(
self, credential_form_schemas: list[CredentialFormSchema], credentials: dict self, credential_form_schemas: list[CredentialFormSchema], credentials: dict
) -> dict: ):
need_validate_credential_form_schema_map = {} need_validate_credential_form_schema_map = {}
for credential_form_schema in credential_form_schemas: for credential_form_schema in credential_form_schemas:
if not credential_form_schema.show_on: if not credential_form_schema.show_on:
@@ -8,7 +8,7 @@ class ModelCredentialSchemaValidator(CommonValidator):
self.model_type = model_type self.model_type = model_type
self.model_credential_schema = model_credential_schema self.model_credential_schema = model_credential_schema
def validate_and_filter(self, credentials: dict) -> dict: def validate_and_filter(self, credentials: dict):
""" """
Validate model credentials Validate model credentials
@@ -6,7 +6,7 @@ class ProviderCredentialSchemaValidator(CommonValidator):
def __init__(self, provider_credential_schema: ProviderCredentialSchema): def __init__(self, provider_credential_schema: ProviderCredentialSchema):
self.provider_credential_schema = provider_credential_schema self.provider_credential_schema = provider_credential_schema
def validate_and_filter(self, credentials: dict) -> dict: def validate_and_filter(self, credentials: dict):
""" """
Validate provider credentials Validate provider credentials
+2 -2
View File
@@ -18,7 +18,7 @@ from pydantic_core import Url
from pydantic_extra_types.color import Color from pydantic_extra_types.color import Color
def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any):
return model.model_dump(mode=mode, **kwargs) return model.model_dump(mode=mode, **kwargs)
@@ -100,7 +100,7 @@ def jsonable_encoder(
exclude_none: bool = False, exclude_none: bool = False,
custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
sqlalchemy_safe: bool = True, sqlalchemy_safe: bool = True,
) -> Any: ):
custom_encoder = custom_encoder or {} custom_encoder = custom_encoder or {}
if custom_encoder: if custom_encoder:
if type(obj) in custom_encoder: if type(obj) in custom_encoder:
+2 -2
View File
@@ -25,7 +25,7 @@ class ApiModeration(Moderation):
name: str = "api" name: str = "api"
@classmethod @classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None: def validate_config(cls, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.
@@ -75,7 +75,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
) )
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict: def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
if self.config is None: if self.config is None:
raise ValueError("The config is not set.") raise ValueError("The config is not set.")
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", "")) extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
+3 -3
View File
@@ -34,13 +34,13 @@ class Moderation(Extensible, ABC):
module: ExtensionModule = ExtensionModule.MODERATION module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None: def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None):
super().__init__(tenant_id, config) super().__init__(tenant_id, config)
self.app_id = app_id self.app_id = app_id
@classmethod @classmethod
@abstractmethod @abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None: def validate_config(cls, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.
@@ -76,7 +76,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None: def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
# inputs_config # inputs_config
inputs_config = config.get("inputs_config") inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict): if not isinstance(inputs_config, dict):
+2 -2
View File
@@ -6,12 +6,12 @@ from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory: class ModerationFactory:
__extension_instance: Moderation __extension_instance: Moderation
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None: def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name) extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config) self.__extension_instance = extension_class(app_id, tenant_id, config)
@classmethod @classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None: def validate_config(cls, name: str, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.
+1 -1
View File
@@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords" name: str = "keywords"
@classmethod @classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None: def validate_config(cls, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.
@@ -7,7 +7,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation" name: str = "openai_moderation"
@classmethod @classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None: def validate_config(cls, tenant_id: str, config: dict):
""" """
Validate the incoming form config data. Validate the incoming form config data.
+1 -1
View File
@@ -40,7 +40,7 @@ class OutputModeration(BaseModel):
def get_final_output(self) -> str: def get_final_output(self) -> str:
return self.final_output or "" return self.final_output or ""
def append_new_token(self, token: str) -> None: def append_new_token(self, token: str):
self.buffer += token self.buffer += token
if not self.thread: if not self.thread:
@@ -6,7 +6,7 @@ from models.account import Tenant
class PluginEncrypter: class PluginEncrypter:
@classmethod @classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt):
encrypter, cache = create_provider_encrypter( encrypter, cache = create_provider_encrypter(
tenant_id=tenant.id, tenant_id=tenant.id,
config=payload.config, config=payload.config,
+2 -2
View File
@@ -27,7 +27,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
model_config: ParameterExtractorModelConfig, model_config: ParameterExtractorModelConfig,
instruction: str, instruction: str,
query: str, query: str,
) -> dict: ):
""" """
Invoke parameter extractor node. Invoke parameter extractor node.
@@ -78,7 +78,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
classes: list[ClassConfig], classes: list[ClassConfig],
instruction: str, instruction: str,
query: str, query: str,
) -> dict: ):
""" """
Invoke question classifier node. Invoke question classifier node.
+4 -4
View File
@@ -117,7 +117,7 @@ class PluginDeclaration(BaseModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_category(cls, values: dict) -> dict: def validate_category(cls, values: dict):
# auto detect category # auto detect category
if values.get("tool"): if values.get("tool"):
values["category"] = PluginCategory.Tool values["category"] = PluginCategory.Tool
@@ -168,7 +168,7 @@ class GenericProviderID:
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.organization}/{self.plugin_name}/{self.provider_name}" return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
def __init__(self, value: str, is_hardcoded: bool = False) -> None: def __init__(self, value: str, is_hardcoded: bool = False):
if not value: if not value:
raise NotFound("plugin not found, please add plugin") raise NotFound("plugin not found, please add plugin")
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
@@ -191,14 +191,14 @@ class GenericProviderID:
class ModelProviderID(GenericProviderID): class ModelProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False) -> None: def __init__(self, value: str, is_hardcoded: bool = False):
super().__init__(value, is_hardcoded) super().__init__(value, is_hardcoded)
if self.organization == "langgenius" and self.provider_name == "google": if self.organization == "langgenius" and self.provider_name == "google":
self.plugin_name = "gemini" self.plugin_name = "gemini"
class ToolProviderID(GenericProviderID): class ToolProviderID(GenericProviderID):
def __init__(self, value: str, is_hardcoded: bool = False) -> None: def __init__(self, value: str, is_hardcoded: bool = False):
super().__init__(value, is_hardcoded) super().__init__(value, is_hardcoded)
if self.organization == "langgenius": if self.organization == "langgenius":
if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]: if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
+2 -2
View File
@@ -17,7 +17,7 @@ class PluginAgentClient(BasePluginClient):
Fetch agent providers for the given tenant. Fetch agent providers for the given tenant.
""" """
def transformer(json_response: dict[str, Any]) -> dict: def transformer(json_response: dict[str, Any]):
for provider in json_response.get("data", []): for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {} declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name") provider_name = declaration.get("identity", {}).get("name")
@@ -49,7 +49,7 @@ class PluginAgentClient(BasePluginClient):
""" """
agent_provider_id = GenericProviderID(provider) agent_provider_id = GenericProviderID(provider)
def transformer(json_response: dict[str, Any]) -> dict: def transformer(json_response: dict[str, Any]):
# skip if error occurs # skip if error occurs
if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None: if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None:
return json_response return json_response
+1 -1
View File
@@ -8,7 +8,7 @@ from extensions.ext_logging import get_request_id
class PluginDaemonError(Exception): class PluginDaemonError(Exception):
"""Base class for all plugin daemon errors.""" """Base class for all plugin daemon errors."""
def __init__(self, description: str) -> None: def __init__(self, description: str):
self.description = description self.description = description
def __str__(self) -> str: def __str__(self) -> str:
+1 -1
View File
@@ -415,7 +415,7 @@ class PluginModelClient(BasePluginClient):
model: str, model: str,
credentials: dict, credentials: dict,
language: Optional[str] = None, language: Optional[str] = None,
) -> list[dict]: ):
""" """
Get tts model voices Get tts model voices
""" """
+2 -2
View File
@@ -16,7 +16,7 @@ class PluginToolManager(BasePluginClient):
Fetch tool providers for the given tenant. Fetch tool providers for the given tenant.
""" """
def transformer(json_response: dict[str, Any]) -> dict: def transformer(json_response: dict[str, Any]):
for provider in json_response.get("data", []): for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {} declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name") provider_name = declaration.get("identity", {}).get("name")
@@ -48,7 +48,7 @@ class PluginToolManager(BasePluginClient):
""" """
tool_provider_id = ToolProviderID(provider) tool_provider_id = ToolProviderID(provider)
def transformer(json_response: dict[str, Any]) -> dict: def transformer(json_response: dict[str, Any]):
data = json_response.get("data") data = json_response.get("data")
if data: if data:
for tool in data.get("declaration", {}).get("tools", []): for tool in data.get("declaration", {}).get("tools", []):
+1 -1
View File
@@ -18,7 +18,7 @@ class FileChunk:
bytes_written: int = field(default=0, init=False) bytes_written: int = field(default=0, init=False)
data: bytearray = field(init=False) data: bytearray = field(init=False)
def __post_init__(self) -> None: def __post_init__(self):
self.data = bytearray(self.total_length) self.data = bytearray(self.total_length)

Some files were not shown because too many files have changed in this diff Show More