Merge 升级到1.11.4

# Conflicts:
#	.github/workflows/tool-test-sdks.yaml
#	api/.env.example
#	api/README.md
#	api/commands.py
#	api/controllers/console/explore/wraps.py
#	api/controllers/web/workflow.py
#	api/extensions/ext_commands.py
#	api/models/model.py
#	api/pyproject.toml
#	api/services/feature_service.py
#	web/README.md
#	web/app/components/explore/app-card/index.tsx
#	web/app/components/explore/app-list/index.tsx
#	web/app/components/explore/sidebar/index.tsx
#	web/app/signin/components/mail-and-password-auth.tsx
#	web/i18n/uk-UA/app-overview.json
#	web/i18n/uk-UA/app.json
#	web/i18n/uk-UA/billing.json
#	web/i18n/uk-UA/common.json
#	web/i18n/uk-UA/dataset-creation.json
#	web/i18n/uk-UA/dataset-documents.json
#	web/i18n/uk-UA/dataset-hit-testing.json
#	web/i18n/uk-UA/dataset-settings.json
#	web/i18n/uk-UA/dataset.json
#	web/i18n/uk-UA/explore.json
#	web/i18n/uk-UA/plugin.json
#	web/i18n/uk-UA/tools.json
#	web/next.config.js
#	web/package.json
#	web/pnpm-lock.yaml
#	web/service/common.ts
#	web/service/explore.ts
#	web/service/fetch.ts
#	web/service/use-explore.ts
#	web/types/feature.ts
This commit is contained in:
npc0-hue
2026-01-26 07:08:40 +08:00
1275 changed files with 84140 additions and 16144 deletions
+33 -20
View File
@@ -1,6 +1,7 @@
import json
import logging
import uuid
from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
@@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@@ -289,6 +291,7 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@@ -296,20 +299,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
answer_unit_price=Decimal(0),
answer_price_unit=Decimal("0.001"),
tokens=0,
total_price=0,
total_price=Decimal(0),
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)
@@ -342,7 +345,8 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
agent_thought.thought += thought
existing_thought = agent_thought.thought or ""
agent_thought.thought = f"{existing_thought}{thought}"
if tool_name:
agent_thought.tool = tool_name
@@ -440,21 +444,30 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tool_names_raw = agent_thought.tool
if tool_names_raw:
tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)
tool_input_payload = agent_thought.tool_input
if tool_input_payload:
try:
tool_inputs = json.loads(tool_input_payload)
except Exception:
tool_inputs = {tool: {} for tool in tool_names}
else:
tool_inputs = {tool: {} for tool in tool_names}
for tool in tools:
observation_payload = agent_thought.observation
if observation_payload:
try:
tool_responses = json.loads(observation_payload)
except Exception:
tool_responses = dict.fromkeys(tool_names, observation_payload)
else:
tool_responses = dict.fromkeys(tool_names, observation_payload)
for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@@ -484,7 +497,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
if not tools:
if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:
+1 -3
View File
@@ -188,7 +188,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
),
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
@@ -200,8 +200,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)
+3 -10
View File
@@ -1,4 +1,3 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: str | None = Field(default=None)
json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -135,17 +134,11 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: str | None) -> str | None:
def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
try:
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
@@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,
@@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
)
@@ -319,7 +319,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
def _initialize_conversation_variables(self) -> list[VariableUnion]:
def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@@ -344,7 +344,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[VariableUnion], conversation_variables)
return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""
@@ -374,25 +374,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if node_finish_resp:
yield node_finish_resp
# For ANSWER nodes, check if we need to send a message_replace event
# Only send if the final output differs from the accumulated task_state.answer
# This happens when variables were updated by variable_assigner during workflow execution
if event.node_type == NodeType.ANSWER and event.outputs:
final_answer = event.outputs.get("answer")
if final_answer is not None and final_answer != self._task_state.answer:
logger.info(
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
final_answer,
self._task_state.answer,
)
# Update the task state answer
self._task_state.answer = str(final_answer)
# Send message_replace event to update the UI
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=str(final_answer),
reason="variable_update",
)
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
+20 -13
View File
@@ -1,4 +1,3 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@@ -76,12 +75,24 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
invalid_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, dict)
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
]
if invalid_dict_keys:
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
invalid_list_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, list)
and any(isinstance(item, dict) for item in v)
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
]
if invalid_list_dict_keys:
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
return user_inputs
@@ -178,12 +189,8 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
if value and not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
case _:
raise AssertionError("this statement should be unreachable.")
@@ -9,13 +9,13 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.app.workflow.node_factory import DifyNodeFactory
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
+3 -2
View File
@@ -10,7 +10,7 @@ from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@@ -25,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@@ -492,7 +493,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session:
with session_factory.create_session() as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
+7 -41
View File
@@ -25,6 +25,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -53,7 +54,6 @@ from core.workflow.graph_events import (
)
from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@@ -166,18 +166,22 @@ class WorkflowBasedAppRunner:
# Determine which type of single node execution and get graph/variable_pool
if single_iteration_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
elif single_loop_run:
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
)
else:
raise ValueError("Neither single_iteration_run nor single_loop_run is specified")
@@ -314,44 +318,6 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
def _get_graph_and_variable_pool_of_single_loop(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single loop
"""
return self._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
node_id=node_id,
user_inputs=user_inputs,
graph_runtime_state=graph_runtime_state,
node_type_filter_key="loop_id",
node_type_label="loop",
)
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
@@ -1,6 +1,6 @@
import logging
from core.variables import Variable
from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
@@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,
+2 -4
View File
@@ -3,8 +3,8 @@ from datetime import UTC, datetime
from typing import Any, ClassVar
from pydantic import TypeAdapter
from sqlalchemy.orm import Session, sessionmaker
from core.db.session_factory import session_factory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
@@ -31,13 +31,11 @@ class TriggerPostLayer(GraphEngineLayer):
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
start_time: datetime,
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
self.session_maker = session_maker
def on_graph_start(self):
pass
@@ -47,7 +45,7 @@ class TriggerPostLayer(GraphEngineLayer):
Update trigger log with success or failure.
"""
if isinstance(event, tuple(self._STATUS_MAP.keys())):
with self.session_maker() as session:
with session_factory.create_session() as session:
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = repo.get_by_id(self.trigger_log_id)
if not trigger_log:
+3
View File
@@ -0,0 +1,3 @@
from .node_factory import DifyNodeFactory
__all__ = ["DifyNodeFactory"]
@@ -1,16 +1,22 @@
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.file import file_manager
from core.helper import ssrf_proxy
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
@@ -18,8 +24,6 @@ from core.workflow.nodes.template_transform.template_renderer import (
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
@@ -43,6 +47,9 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
http_request_http_client: HttpClientProtocol = ssrf_proxy,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol = file_manager,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@@ -61,6 +68,9 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._http_request_http_client = http_request_http_client
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@@ -113,6 +123,7 @@ class DifyNodeFactory(NodeFactory):
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
@@ -122,6 +133,17 @@ class DifyNodeFactory(NodeFactory):
template_renderer=self._template_renderer,
)
if node_type == NodeType.HTTP_REQUEST:
return HttpRequestNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
http_client=self._http_request_http_client,
tool_file_manager_factory=self._http_request_tool_file_manager_factory,
file_manager=self._http_request_file_manager,
)
return node_class(
id=node_id,
config=node_config,
+4
View File
@@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError):
pass
request_error = httpx.RequestError
max_retries_exceeded_error = MaxRetriesExceededError
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(
+11 -16
View File
@@ -71,8 +71,8 @@ class LLMGenerator:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = cast(str, response.message.content)
if answer is None:
answer = response.message.get_text_content()
if answer == "":
return ""
try:
result_dict = json.loads(answer)
@@ -184,7 +184,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
rule_config["prompt"] = cast(str, response.message.content)
rule_config["prompt"] = response.message.get_text_content()
except InvokeError as e:
error = str(e)
@@ -237,13 +237,11 @@ class LLMGenerator:
return rule_config
rule_config["prompt"] = cast(str, prompt_content.message.content)
rule_config["prompt"] = prompt_content.message.get_text_content()
if not isinstance(prompt_content.message.content, str):
raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.content,
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -253,7 +251,7 @@ class LLMGenerator:
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
"INPUT_TEXT": prompt_content.message.content,
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -263,7 +261,7 @@ class LLMGenerator:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
except InvokeError as e:
error = str(e)
error_step = "generate variables"
@@ -272,7 +270,7 @@ class LLMGenerator:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
rule_config["opening_statement"] = statement_content.message.get_text_content()
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
@@ -315,7 +313,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_code = cast(str, response.message.content)
generated_code = response.message.get_text_content()
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
@@ -351,7 +349,7 @@ class LLMGenerator:
raise TypeError("Expected LLMResult when stream=False")
response = result
answer = cast(str, response.message.content)
answer = response.message.get_text_content()
return answer.strip()
@classmethod
@@ -375,10 +373,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
raw_content = response.message.content
if not isinstance(raw_content, str):
raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
raw_content = response.message.get_text_content()
try:
parsed_content = json.loads(raw_content)
@@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_calls:
return False
return True
return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):
+6 -2
View File
@@ -1,6 +1,7 @@
import logging
from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
@@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)
@@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=trace_api.SpanKind.INTERNAL,
kind=span_data.span_kind,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,
@@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel, Field
@@ -34,3 +34,4 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")
+3 -1
View File
@@ -35,7 +35,6 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
@@ -473,6 +472,9 @@ class TraceTask:
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
# Lazy import to avoid circular import during module initialization
from repositories.factory import DifyAPIRepositoryFactory
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
+7 -8
View File
@@ -320,18 +320,17 @@ class BasePluginClient:
case PluginInvokeError.__name__:
error_object = json.loads(message)
invoke_error_type = error_object.get("error_type")
args = error_object.get("args")
match invoke_error_type:
case InvokeRateLimitError.__name__:
raise InvokeRateLimitError(description=args.get("description"))
raise InvokeRateLimitError(description=error_object.get("message"))
case InvokeAuthorizationError.__name__:
raise InvokeAuthorizationError(description=args.get("description"))
raise InvokeAuthorizationError(description=error_object.get("message"))
case InvokeBadRequestError.__name__:
raise InvokeBadRequestError(description=args.get("description"))
raise InvokeBadRequestError(description=error_object.get("message"))
case InvokeConnectionError.__name__:
raise InvokeConnectionError(description=args.get("description"))
raise InvokeConnectionError(description=error_object.get("message"))
case InvokeServerUnavailableError.__name__:
raise InvokeServerUnavailableError(description=args.get("description"))
raise InvokeServerUnavailableError(description=error_object.get("message"))
case CredentialsValidateFailedError.__name__:
raise CredentialsValidateFailedError(error_object.get("message"))
case EndpointSetupFailedError.__name__:
@@ -339,11 +338,11 @@ class BasePluginClient:
case TriggerProviderCredentialValidationError.__name__:
raise TriggerProviderCredentialValidationError(error_object.get("message"))
case TriggerPluginInvokeError.__name__:
raise TriggerPluginInvokeError(description=error_object.get("description"))
raise TriggerPluginInvokeError(description=error_object.get("message"))
case TriggerInvokeError.__name__:
raise TriggerInvokeError(error_object.get("message"))
case EventIgnoreError.__name__:
raise EventIgnoreError(description=error_object.get("description"))
raise EventIgnoreError(description=error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:
+21 -11
View File
@@ -1,5 +1,6 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
@@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient):
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
Delete the given endpoint.
This operation is idempotent: if the endpoint is already deleted (record not found),
it will return True instead of raising an error.
"""
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/endpoint/remove",
bool,
data={
"endpoint_id": endpoint_id,
},
headers={
"Content-Type": "application/json",
},
)
try:
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/endpoint/remove",
bool,
data={
"endpoint_id": endpoint_id,
},
headers={
"Content-Type": "application/json",
},
)
except PluginDaemonInternalServerError as e:
# Make delete idempotent: if record is not found, consider it a success
if "record not found" in str(e.description).lower():
return True
raise
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
+114 -17
View File
@@ -154,7 +154,7 @@ class IrisConnectionPool:
# Add to cache to skip future checks
self._schemas_initialized.add(schema)
except Exception as e:
except Exception:
conn.rollback()
logger.exception("Failed to ensure schema %s exists", schema)
raise
@@ -177,6 +177,9 @@ class IrisConnectionPool:
class IrisVector(BaseVector):
"""IRIS vector database implementation using native VECTOR type and HNSW indexing."""
# Fallback score for full-text search when Rank function unavailable or TEXT_INDEX disabled
_FULL_TEXT_FALLBACK_SCORE = 0.5
def __init__(self, collection_name: str, config: IrisVectorConfig) -> None:
super().__init__(collection_name)
self.config = config
@@ -272,41 +275,131 @@ class IrisVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search documents by full-text using iFind index or fallback to LIKE search."""
"""Search documents by full-text using iFind index with BM25 relevance scoring.
When IRIS_TEXT_INDEX is enabled, this method uses the auto-generated Rank
function from %iFind.Index.Basic to calculate BM25 relevance scores. The Rank
function is automatically created with naming: {schema}.{table_name}_{index}Rank
Args:
query: Search query string
**kwargs: Optional parameters including top_k, document_ids_filter
Returns:
List of Document objects with relevance scores in metadata["score"]
"""
top_k = kwargs.get("top_k", 5)
document_ids_filter = kwargs.get("document_ids_filter")
with self._get_cursor() as cursor:
if self.config.IRIS_TEXT_INDEX:
# Use iFind full-text search with index
# Use iFind full-text search with auto-generated Rank function
text_index_name = f"idx_{self.table_name}_text"
# IRIS removes underscores from function names
table_no_underscore = self.table_name.replace("_", "")
index_no_underscore = text_index_name.replace("_", "")
rank_function = f"{self.schema}.{table_no_underscore}_{index_no_underscore}Rank"
# Build WHERE clause with document ID filter if provided
where_clause = f"WHERE %ID %FIND search_index({text_index_name}, ?)"
# First param for Rank function, second for FIND
params = [query, query]
if document_ids_filter:
# Add document ID filter
placeholders = ",".join("?" * len(document_ids_filter))
where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)
sql = f"""
SELECT TOP {top_k} id, text, meta
SELECT TOP {top_k}
id,
text,
meta,
{rank_function}(%ID, ?) AS score
FROM {self.schema}.{self.table_name}
WHERE %ID %FIND search_index({text_index_name}, ?)
{where_clause}
ORDER BY score DESC
"""
cursor.execute(sql, (query,))
logger.debug(
"iFind search: query='%s', index='%s', rank='%s'",
query,
text_index_name,
rank_function,
)
try:
cursor.execute(sql, params)
except Exception: # pylint: disable=broad-exception-caught
# Fallback to query without Rank function if it fails
logger.warning(
"Rank function '%s' failed, using fixed score",
rank_function,
exc_info=True,
)
sql_fallback = f"""
SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
FROM {self.schema}.{self.table_name}
{where_clause}
"""
# Skip first param (for Rank function)
cursor.execute(sql_fallback, params[1:])
else:
# Fallback to LIKE search (inefficient for large datasets)
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern
# Fallback to LIKE search (IRIS_TEXT_INDEX disabled)
from libs.helper import ( # pylint: disable=import-outside-toplevel
escape_like_pattern,
)
escaped_query = escape_like_pattern(query)
query_pattern = f"%{escaped_query}%"
# Build WHERE clause with document ID filter if provided
where_clause = "WHERE text LIKE ? ESCAPE '\\\\'"
params = [query_pattern]
if document_ids_filter:
placeholders = ",".join("?" * len(document_ids_filter))
where_clause += f" AND JSON_VALUE(meta, '$.document_id') IN ({placeholders})"
params.extend(document_ids_filter)
sql = f"""
SELECT TOP {top_k} id, text, meta
SELECT TOP {top_k} id, text, meta, {self._FULL_TEXT_FALLBACK_SCORE} AS score
FROM {self.schema}.{self.table_name}
WHERE text LIKE ? ESCAPE '\\'
{where_clause}
ORDER BY LENGTH(text) ASC
"""
cursor.execute(sql, (query_pattern,))
logger.debug(
"LIKE fallback (TEXT_INDEX disabled): query='%s'",
query_pattern,
)
cursor.execute(sql, params)
docs = []
for row in cursor.fetchall():
if len(row) >= 3:
metadata = json.loads(row[2]) if row[2] else {}
docs.append(Document(page_content=row[1], metadata=metadata))
# Expecting 4 columns: id, text, meta, score
if len(row) >= 4:
text_content = row[1]
meta_str = row[2]
score_value = row[3]
metadata = json.loads(meta_str) if meta_str else {}
# Add score to metadata for hybrid search compatibility
score = float(score_value) if score_value is not None else 0.0
metadata["score"] = score
docs.append(Document(page_content=text_content, metadata=metadata))
logger.info(
"Full-text search completed: query='%s', results=%d/%d",
query,
len(docs),
top_k,
)
if not docs:
logger.info("Full-text search for '%s' returned no results", query)
logger.warning("Full-text search for '%s' returned no results", query)
return docs
@@ -370,7 +463,11 @@ class IrisVector(BaseVector):
AS %iFind.Index.Basic
(LANGUAGE = '{language}', LOWER = 1, INDEXOPTION = 0)
"""
logger.info("Creating text index: %s with language: %s", text_index_name, language)
logger.info(
"Creating text index: %s with language: %s",
text_index_name,
language,
)
logger.info("SQL for text index: %s", sql_text_index)
cursor.execute(sql_text_index)
logger.info("Text index created successfully: %s", text_index_name)
+22 -3
View File
@@ -130,7 +130,7 @@ class ToolInvokeMessage(BaseModel):
text: str
class JsonMessage(BaseModel):
json_object: dict
json_object: dict | list
suppress_output: bool = Field(default=False, description="Whether to suppress JSON output in result string")
class BlobMessage(BaseModel):
@@ -144,7 +144,14 @@ class ToolInvokeMessage(BaseModel):
end: bool = Field(..., description="Whether the chunk is the last chunk")
class FileMessage(BaseModel):
pass
file_marker: str = Field(default="file_marker")
@model_validator(mode="before")
@classmethod
def validate_file_message(cls, values):
if isinstance(values, dict) and "file_marker" not in values:
raise ValueError("Invalid FileMessage: missing file_marker")
return values
class VariableMessage(BaseModel):
variable_name: str = Field(..., description="The name of the variable")
@@ -234,10 +241,22 @@ class ToolInvokeMessage(BaseModel):
@field_validator("message", mode="before")
@classmethod
def decode_blob_message(cls, v):
def decode_blob_message(cls, v, info: ValidationInfo):
# 处理 blob 解码
if isinstance(v, dict) and "blob" in v:
with contextlib.suppress(Exception):
v["blob"] = base64.b64decode(v["blob"])
# Force correct message type based on type field
# Only wrap dict types to avoid wrapping already parsed Pydantic model objects
if info.data and isinstance(info.data, dict) and isinstance(v, dict):
msg_type = info.data.get("type")
if msg_type == cls.MessageType.JSON:
if "json_object" not in v:
v = {"json_object": v}
elif msg_type == cls.MessageType.FILE:
v = {"file_marker": "file_marker"}
return v
@field_serializer("message")
+9
View File
@@ -1,5 +1,6 @@
import contextlib
import json
import logging
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import UTC, datetime
@@ -36,6 +37,8 @@ from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import Message, MessageFile
logger = logging.getLogger(__name__)
class ToolEngine:
"""
@@ -123,25 +126,31 @@ class ToolEngine:
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
logger.error(e, exc_info=True)
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.entity.identity.name}"
logger.error(e, exc_info=True)
agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
except ToolEngineInvokeError as e:
meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
return error_response, [], ToolInvokeMeta.error_instance(error_response)
+30 -43
View File
@@ -5,10 +5,9 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
@@ -20,9 +19,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models import Account, Tenant
from models.model import App, EndUser
from models.workflow import Workflow
@@ -210,50 +207,38 @@ class WorkflowTool(Tool):
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
if has_request_context():
return self._resolve_user_from_request()
else:
return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_request(self) -> Account | EndUser | None:
"""
Resolve user from Flask request context.
"""
try:
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
return getattr(current_user, "_get_current_object", lambda: current_user)()
except Exception as e:
logger.warning("Failed to resolve user from request context: %s", e)
return None
return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
"""
Resolve user from database (worker/Celery context).
"""
with session_factory.create_session() as session:
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = session.scalar(user_stmt)
if user:
user.current_tenant = tenant
session.expunge(user)
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = session.scalar(end_user_stmt)
if end_user:
session.expunge(end_user)
return end_user
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = db.session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if user:
user.current_tenant = tenant
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = db.session.scalar(end_user_stmt)
if end_user:
return end_user
return None
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
if not version:
stmt = (
select(Workflow)
@@ -265,22 +250,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
if not workflow:
raise ValueError("workflow not found or not published")
return workflow
session.expunge(workflow)
return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt)
if not app:
raise ValueError("app not found")
if not app:
raise ValueError("app not found")
return app
session.expunge(app)
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""
+2
View File
@@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
__all__ = [
@@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
"VariableBase",
]
+1 -1
View File
@@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
+14 -14
View File
@@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
class Variable(Segment):
class VariableBase(Segment):
"""
A variable is a segment that has a name.
@@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
class StringVariable(StringSegment, Variable):
class StringVariable(StringSegment, VariableBase):
pass
class FloatVariable(FloatSegment, Variable):
class FloatVariable(FloatSegment, VariableBase):
pass
class IntegerVariable(IntegerSegment, Variable):
class IntegerVariable(IntegerSegment, VariableBase):
pass
class ObjectVariable(ObjectSegment, Variable):
class ObjectVariable(ObjectSegment, VariableBase):
pass
class ArrayVariable(ArraySegment, Variable):
class ArrayVariable(ArraySegment, VariableBase):
pass
@@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):
class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
class FileVariable(FileSegment, Variable):
class FileVariable(FileSegment, VariableBase):
pass
class BooleanVariable(BooleanSegment, Variable):
class BooleanVariable(BooleanSegment, VariableBase):
pass
@@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
# The `Variable` type is used to enable serialization and deserialization with Pydantic.
# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
VariableUnion: TypeAlias = Annotated[
# - All variants in `Variable` must inherit from the `VariableBase` class.
# - The union must include all non-abstract subclasses of `VariableBase`.
Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]
+34
View File
@@ -0,0 +1,34 @@
"""
Execution Context - Context management for workflow execution.
This package provides Flask-independent context management for workflow
execution in multi-threaded environments.
"""
from core.workflow.context.execution_context import (
AppContext,
ContextProviderNotFoundError,
ExecutionContext,
IExecutionContext,
NullAppContext,
capture_current_context,
read_context,
register_context,
register_context_capturer,
reset_context_provider,
)
from core.workflow.context.models import SandboxContext
__all__ = [
"AppContext",
"ContextProviderNotFoundError",
"ExecutionContext",
"IExecutionContext",
"NullAppContext",
"SandboxContext",
"capture_current_context",
"read_context",
"register_context",
"register_context_capturer",
"reset_context_provider",
]
@@ -0,0 +1,284 @@
"""
Execution Context - Abstracted context management for workflow execution.
"""
import contextvars
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, TypeVar, final, runtime_checkable
from pydantic import BaseModel
class AppContext(ABC):
"""
Abstract application context interface.
This abstraction allows workflow execution to work with or without Flask
by providing a common interface for application context management.
"""
@abstractmethod
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
pass
@abstractmethod
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name (e.g., 'db', 'cache')."""
pass
@abstractmethod
def enter(self) -> AbstractContextManager[None]:
"""Enter the application context."""
pass
@runtime_checkable
class IExecutionContext(Protocol):
"""
Protocol for execution context.
This protocol defines the interface that all execution contexts must implement,
allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
"""
def __enter__(self) -> "IExecutionContext":
"""Enter the execution context."""
...
def __exit__(self, *args: Any) -> None:
"""Exit the execution context."""
...
@property
def user(self) -> Any:
"""Get user object."""
...
@final
class ExecutionContext:
"""
Execution context for workflow execution in worker threads.
This class encapsulates all context needed for workflow execution:
- Application context (Flask app or standalone)
- Context variables for Python contextvars
- User information (optional)
It is designed to be serializable and passable to worker threads.
"""
def __init__(
self,
app_context: AppContext | None = None,
context_vars: contextvars.Context | None = None,
user: Any = None,
) -> None:
"""
Initialize execution context.
Args:
app_context: Application context (Flask or standalone)
context_vars: Python contextvars to preserve
user: User object (optional)
"""
self._app_context = app_context
self._context_vars = context_vars
self._user = user
self._local = threading.local()
@property
def app_context(self) -> AppContext | None:
"""Get application context."""
return self._app_context
@property
def context_vars(self) -> contextvars.Context | None:
"""Get context variables."""
return self._context_vars
@property
def user(self) -> Any:
"""Get user object."""
return self._user
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""
Enter this execution context.
This is a convenience method that creates a context manager.
"""
# Restore context variables if provided
if self._context_vars:
for var, val in self._context_vars.items():
var.set(val)
# Enter app context if available
if self._app_context is not None:
with self._app_context.enter():
yield
else:
yield
def __enter__(self) -> "ExecutionContext":
"""Enter the execution context."""
cm = self.enter()
self._local.cm = cm
cm.__enter__()
return self
def __exit__(self, *args: Any) -> None:
"""Exit the execution context."""
cm = getattr(self._local, "cm", None)
if cm is not None:
cm.__exit__(*args)
class NullAppContext(AppContext):
"""
Null implementation of AppContext for non-Flask environments.
This is used when running without Flask (e.g., in tests or standalone mode).
"""
def __init__(self, config: dict[str, Any] | None = None) -> None:
"""
Initialize null app context.
Args:
config: Optional configuration dictionary
"""
self._config = config or {}
self._extensions: dict[str, Any] = {}
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
return self._config.get(key, default)
def get_extension(self, name: str) -> Any:
"""Get extension by name."""
return self._extensions.get(name)
def set_extension(self, name: str, extension: Any) -> None:
"""Set extension by name."""
self._extensions[name] = extension
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""Enter null context (no-op)."""
yield
class ExecutionContextBuilder:
"""
Builder for creating ExecutionContext instances.
This provides a fluent API for building execution contexts.
"""
def __init__(self) -> None:
self._app_context: AppContext | None = None
self._context_vars: contextvars.Context | None = None
self._user: Any = None
def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
"""Set application context."""
self._app_context = app_context
return self
def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
"""Set context variables."""
self._context_vars = context_vars
return self
def with_user(self, user: Any) -> "ExecutionContextBuilder":
"""Set user."""
self._user = user
return self
def build(self) -> ExecutionContext:
"""Build the execution context."""
return ExecutionContext(
app_context=self._app_context,
context_vars=self._context_vars,
user=self._user,
)
_capturer: Callable[[], IExecutionContext] | None = None
# Tenant-scoped providers using tuple keys for clarity and constant-time lookup.
# Key mapping:
# (name, tenant_id) -> provider
# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox")
# - tenant_id: tenant identifier string
# Value:
# provider: Callable[[], BaseModel] returning the typed context value
# Type-safety note:
# - This registry cannot enforce that all providers for a given name return the same BaseModel type.
# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice),
# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and
# def read_sandbox_ctx(tenant_id: str) -> SandboxContext.
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
T = TypeVar("T", bound=BaseModel)
class ContextProviderNotFoundError(KeyError):
"""Raised when a tenant-scoped context provider is missing for a given (name, tenant_id)."""
pass
def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
"""Register a single enterable execution context capturer (e.g., Flask)."""
global _capturer
_capturer = capturer
def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None:
"""Register a tenant-specific provider for a named context.
Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions.
Consider adding a typed wrapper for this registration in your feature module.
"""
_tenant_context_providers[(name, tenant_id)] = provider
def read_context(name: str, *, tenant_id: str) -> BaseModel:
"""
Read a context value for a specific tenant.
Raises KeyError if the provider for (name, tenant_id) is not registered.
"""
prov = _tenant_context_providers.get((name, tenant_id))
if prov is None:
raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'")
return prov()
def capture_current_context() -> IExecutionContext:
"""
Capture current execution context from the calling environment.
If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal
context with NullAppContext + copy of current contextvars.
"""
if _capturer is None:
return ExecutionContext(
app_context=NullAppContext(),
context_vars=contextvars.copy_context(),
)
return _capturer()
def reset_context_provider() -> None:
"""Reset the capturer and all tenant-scoped context providers (primarily for tests)."""
global _capturer
_capturer = None
_tenant_context_providers.clear()
+13
View File
@@ -0,0 +1,13 @@
from __future__ import annotations
from pydantic import AnyHttpUrl, BaseModel
class SandboxContext(BaseModel):
"""Typed context for sandbox integration. All fields optional by design."""
sandbox_url: AnyHttpUrl | None = None
sandbox_token: str | None = None # optional, if later needed for auth
__all__ = ["SandboxContext"]
@@ -1,7 +1,7 @@
import abc
from typing import Protocol
from core.variables import Variable
from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable"):
def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
:param variable: The `VariableBase` instance containing the updated value.
"""
pass
+4
View File
@@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum):
def is_ended(self) -> bool:
return self in _END_STATE
@classmethod
def ended_values(cls) -> list[str]:
return [status.value for status in _END_STATE]
_END_STATE = frozenset(
[
@@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
class CommandType(StrEnum):
@@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
+4 -16
View File
@@ -7,15 +7,13 @@ Domain-Driven Design principles for improved maintainability and testability.
from __future__ import annotations
import contextvars
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
from flask import Flask, current_app
from core.workflow.context import capture_current_context
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
@@ -159,17 +157,8 @@ class GraphEngine:
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
flask_app: Flask | None = None
try:
app = current_app._get_current_object() # type: ignore
if isinstance(app, Flask):
flask_app = app
except RuntimeError:
pass
# Capture context variables for worker threads
context_vars = contextvars.copy_context()
# Capture execution context for worker threads
execution_context = capture_current_context()
# Create worker pool for parallel node execution
self._worker_pool = WorkerPool(
@@ -177,8 +166,7 @@ class GraphEngine:
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
flask_app=flask_app,
context_vars=context_vars,
execution_context=execution_context,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
+13 -18
View File
@@ -5,26 +5,26 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
from typing import final
from uuid import uuid4
from typing import TYPE_CHECKING, final
from flask import Flask
from typing_extensions import override
from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
if TYPE_CHECKING:
pass
@final
class Worker(threading.Thread):
@@ -44,8 +44,7 @@ class Worker(threading.Thread):
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
execution_context: IExecutionContext | None = None,
) -> None:
"""
Initialize worker thread.
@@ -56,19 +55,17 @@ class Worker(threading.Thread):
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
execution_context: Optional execution context for context preservation
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
self._flask_app = flask_app
self._context_vars = context_vars
self._last_task_time = time.time()
self._execution_context = execution_context
self._stop_event = stop_event
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
"""Worker is controlled via shared stop_event from GraphEngine.
@@ -115,7 +112,7 @@ class Worker(threading.Thread):
self._ready_queue.task_done()
except Exception as e:
error_event = NodeRunFailedEvent(
id=str(uuid4()),
id=node.execution_id,
node_id=node.id,
node_type=node.node_type,
in_iteration_id=None,
@@ -135,11 +132,9 @@ class Worker(threading.Thread):
error: Exception | None = None
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self._flask_app,
context_vars=self._context_vars,
):
# Execute the node with preserved context if execution context is provided
if self._execution_context is not None:
with self._execution_context:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
@@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class.
import logging
import queue
import threading
from typing import TYPE_CHECKING, final
from typing import final
from configs import dify_config
from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
@@ -20,11 +21,6 @@ from ..worker import Worker
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from contextvars import Context
from flask import Flask
@final
class WorkerPool:
@@ -42,8 +38,7 @@ class WorkerPool:
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
execution_context: IExecutionContext | None = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
@@ -57,8 +52,7 @@ class WorkerPool:
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
execution_context: Optional execution context for context preservation
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
@@ -67,8 +61,7 @@ class WorkerPool:
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._flask_app = flask_app
self._context_vars = context_vars
self._execution_context = execution_context
self._layers = layers
# Scaling parameters with defaults
@@ -152,8 +145,7 @@ class WorkerPool:
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
flask_app=self._flask_app,
context_vars=self._context_vars,
execution_context=self._execution_context,
stop_event=self._stop_event,
)
+26 -10
View File
@@ -235,7 +235,18 @@ class AgentNode(Node[AgentNodeData]):
0,
):
value_param = param.get("value", {})
params[key] = value_param.get("value", "") if value_param is not None else None
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
@@ -483,7 +494,7 @@ class AgentNode(Node[AgentNodeData]):
text = ""
files: list[File] = []
json_list: list[dict] = []
json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
@@ -557,13 +568,18 @@ class AgentNode(Node[AgentNodeData]):
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
if isinstance(message.message.json_object, dict):
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
else:
msg_metadata = {}
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
@@ -672,7 +688,7 @@ class AgentNode(Node[AgentNodeData]):
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
json_output: list[dict[str, Any] | list[Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
+2 -6
View File
@@ -469,12 +469,8 @@ class Node(Generic[NodeDataT]):
import core.workflow.nodes as _nodes_pkg
for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."):
# Avoid importing modules that depend on the registry to prevent circular imports
# e.g. node_factory imports node_mapping which builds the mapping here.
if _modname in {
"core.workflow.nodes.node_factory",
"core.workflow.nodes.node_mapping",
}:
# Avoid importing modules that depend on the registry to prevent circular imports.
if _modname == "core.workflow.nodes.node_mapping":
continue
importlib.import_module(_modname)
@@ -301,7 +301,7 @@ class DatasourceNode(Node[DatasourceNodeData]):
text = ""
files: list[File] = []
json: list[dict] = []
json: list[dict | list] = []
variables: dict[str, Any] = {}
@@ -17,6 +17,7 @@ from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
from ..protocols import FileManagerProtocol, HttpClientProtocol
from .entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
@@ -78,6 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
http_client: HttpClientProtocol = ssrf_proxy,
file_manager: FileManagerProtocol = file_manager,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@@ -104,6 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
self._http_client = http_client
self._file_manager = file_manager
# init template
self.variable_pool = variable_pool
@@ -200,7 +205,7 @@ class Executor:
if file_variable is None:
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
file = file_variable.value
self.content = file_manager.download(file)
self.content = self._file_manager.download(file)
case "x-www-form-urlencoded":
form_data = {
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
@@ -239,7 +244,7 @@ class Executor:
):
file_tuple = (
file.filename,
file_manager.download(file),
self._file_manager.download(file),
file.mime_type or "application/octet-stream",
)
if key not in files:
@@ -332,19 +337,18 @@ class Executor:
do http request depending on api bundle
"""
_METHOD_MAP = {
"get": ssrf_proxy.get,
"head": ssrf_proxy.head,
"post": ssrf_proxy.post,
"put": ssrf_proxy.put,
"delete": ssrf_proxy.delete,
"patch": ssrf_proxy.patch,
"get": self._http_client.get,
"head": self._http_client.head,
"post": self._http_client.post,
"put": self._http_client.put,
"delete": self._http_client.delete,
"patch": self._http_client.patch,
}
method_lc = self.method.lower()
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args = {
"url": self.url,
"data": self.data,
"files": self.files,
"json": self.json,
@@ -357,8 +361,12 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response: httpx.Response = _METHOD_MAP[method_lc](**request_args, max_retries=self.max_retries)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
response: httpx.Response = _METHOD_MAP[method_lc](
url=self.url,
**request_args,
max_retries=self.max_retries,
)
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response
+33 -4
View File
@@ -1,10 +1,11 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.file import File, FileTransferMethod
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -13,6 +14,7 @@ from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from factories import file_factory
from .entities import (
@@ -30,10 +32,35 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class HttpRequestNode(Node[HttpRequestNodeData]):
node_type = NodeType.HTTP_REQUEST
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
http_client: HttpClientProtocol = ssrf_proxy,
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
file_manager: FileManagerProtocol = file_manager,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._http_client = http_client
self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
return {
@@ -71,6 +98,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
http_client=self._http_client,
file_manager=self._file_manager,
)
process_data["request"] = http_executor.to_log()
@@ -199,7 +228,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
tool_file_manager = ToolFileManager()
tool_file_manager = self._tool_file_manager_factory()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
@@ -1,17 +1,15 @@
import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@@ -51,6 +48,7 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.context import IExecutionContext
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@@ -240,7 +238,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime,
list[GraphNodeEventBase],
object | None,
dict[str, VariableUnion],
dict[str, Variable],
LLMUsage,
]
],
@@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._execute_single_iteration_parallel,
index=index,
item=item,
flask_app=current_app._get_current_object(), # type: ignore
context_vars=contextvars.copy_context(),
execution_context=self._capture_execution_context(),
)
future_to_index[future] = index
@@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self,
index: int,
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
execution_context: "IExecutionContext",
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
with execution_context:
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
@@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
graph_engine.graph_runtime_state.llm_usage,
)
def _capture_execution_context(self) -> "IExecutionContext":
"""Capture current execution context for parallel iterations."""
from core.workflow.context import capture_current_context
return capture_current_context()
def _handle_iteration_success(
self,
started_at: datetime,
@@ -515,11 +517,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
@@ -586,11 +588,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
+1 -1
View File
@@ -413,11 +413,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
+29
View File
@@ -0,0 +1,29 @@
from typing import Protocol
import httpx
from core.file import File
class HttpClientProtocol(Protocol):
@property
def max_retries_exceeded_error(self) -> type[Exception]: ...
@property
def request_error(self) -> type[Exception]: ...
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
class FileManagerProtocol(Protocol):
def download(self, f: File, /) -> bytes: ...
+12 -16
View File
@@ -1,4 +1,3 @@
import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@@ -43,25 +42,22 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
# If no value provided, skip further processing for this key
if not value:
continue
if not isinstance(value, dict):
raise ValueError(f"JSON object for '{key}' must be an object")
# Overwrite with normalized dict to ensure downstream consistency
node_inputs[key] = value
# If schema exists, then validate against it
schema = variable.json_schema
if not schema:
continue
if not value:
continue
try:
json_schema = json.loads(schema)
except json.JSONDecodeError as e:
raise ValueError(f"{schema} must be a valid JSON object")
try:
json_value = json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"{value} must be a valid JSON object")
try:
Draft7Validator(json_schema).validate(json_value)
Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = json_value
+2 -2
View File
@@ -244,7 +244,7 @@ class ToolNode(Node[ToolNodeData]):
text = ""
files: list[File] = []
json: list[dict] = []
json: list[dict | list] = []
variables: dict[str, Any] = {}
@@ -400,7 +400,7 @@ class ToolNode(Node[ToolNodeData]):
message.message.metadata = dict_metadata
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
json_output: list[dict[str, Any] | list[Any]] = []
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -33,6 +33,15 @@ class VariableAssignerNode(Node[VariableAssignerData]):
graph_runtime_state=graph_runtime_state,
)
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
"""
Check if this Variable Assigner node blocks the output of specific variables.
Returns True if this node updates any of the requested conversation variables.
"""
assigned_selector = tuple(self.node_data.assigned_variable_selector)
return assigned_selector in variable_selectors
@classmethod
def version(cls) -> str:
return "1"
@@ -64,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:
@@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part
# Check if variable exists
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
@@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
@@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item(
self,
*,
variable: Variable,
variable: VariableBase,
operation: Operation,
value: Any,
):
+11 -11
View File
@@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@@ -32,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@@ -46,13 +46,13 @@ class VariablePool(BaseModel):
description="System variables",
default_factory=SystemVariable.empty,
)
environment_variables: Sequence[VariableUnion] = Field(
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
conversation_variables: Sequence[VariableUnion] = Field(
conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
@@ -105,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements"
)
if isinstance(value, Variable):
if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@@ -114,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
+4 -4
View File
@@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.variables import Variable
from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
@@ -26,7 +26,7 @@ class VariableLoader(Protocol):
"""
@abc.abstractmethod
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.
@@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
:return: a list of Variable objects that match the provided selectors.
:return: a list of VariableBase objects that match the provided selectors.
"""
pass
@@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed.
"""
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return []
+31 -11
View File
@@ -7,6 +7,7 @@ from typing import Any
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.file.models import File
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
@@ -136,13 +137,11 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
node_config = workflow.get_node_config_by_id(node_id)
node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {})
# Get node class
# Get node type
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state
graph_init_params = GraphInitParams(
@@ -158,12 +157,12 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(node_config)
node_cls = type(node)
try:
# variable selector to variable mapping
@@ -190,8 +189,7 @@ class WorkflowEntry:
)
try:
# run node
generator = node.run()
generator = cls._traced_node_run(node)
except Exception as e:
logger.exception(
"error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
@@ -324,8 +322,7 @@ class WorkflowEntry:
tenant_id=tenant_id,
)
# run node
generator = node.run()
generator = cls._traced_node_run(node)
return node, generator
except Exception as e:
@@ -431,3 +428,26 @@ class WorkflowEntry:
input_value = current_variable.value | input_value
variable_pool.add([variable_node_id] + variable_key_list, input_value)
@staticmethod
def _traced_node_run(node: Node) -> Generator[GraphNodeEventBase, None, None]:
"""
Wraps a node's run method with OpenTelemetry tracing and returns a generator.
"""
# Wrap node.run() with ObservabilityLayer hooks to produce node-level spans
layer = ObservabilityLayer()
layer.on_graph_start()
node.ensure_execution_id()
def _gen():
error: Exception | None = None
layer.on_node_run_start(node)
try:
yield from node.run()
except Exception as exc:
error = exc
raise
finally:
layer.on_node_run_end(node, error)
return _gen()