feat: Add conversation variable persistence layer (#30531)

This commit is contained in:
-LAN-
2026-01-06 14:05:33 +08:00
committed by GitHub
parent b2124a7358
commit d6e9c3310f
14 changed files with 305 additions and 104 deletions
@@ -0,0 +1,144 @@
from collections.abc import Sequence
from datetime import datetime
from unittest.mock import Mock
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.variables import StringVariable
from core.variables.segments import Segment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.system_variable import SystemVariable
class MockReadOnlyVariablePool:
def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None:
self._variables = variables or {}
def get(self, selector: Sequence[str]) -> Segment | None:
if len(selector) < 2:
return None
return self._variables.get((selector[0], selector[1]))
def get_all_by_node(self, node_id: str) -> dict[str, object]:
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
def get_by_prefix(self, prefix: str) -> dict[str, object]:
return {key: value for (nid, key), value in self._variables.items() if nid == prefix}
def _build_graph_runtime_state(
variable_pool: MockReadOnlyVariablePool,
conversation_id: str | None = None,
) -> ReadOnlyGraphRuntimeState:
graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState)
graph_runtime_state.variable_pool = variable_pool
graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view()
return graph_runtime_state
def _build_node_run_succeeded_event(
*,
node_type: NodeType,
outputs: dict[str, object] | None = None,
process_data: dict[str, object] | None = None,
) -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
id="node-exec-id",
node_id="assigner",
node_type=node_type,
start_at=datetime.utcnow(),
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs or {},
process_data=process_data or {},
),
)
def test_persists_conversation_variables_from_assigner_output():
conversation_id = "conv-123"
variable = StringVariable(
id="var-1",
name="name",
value="updated",
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
)
process_data = common_helpers.set_updated_variables(
{}, [common_helpers.variable_to_processed_data(variable.selector, variable)]
)
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
updater = Mock()
layer = ConversationVariablePersistenceLayer(updater)
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
layer.on_event(event)
updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable)
updater.flush.assert_called_once()
def test_skips_when_outputs_missing():
conversation_id = "conv-456"
variable = StringVariable(
id="var-2",
name="name",
value="updated",
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
)
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
updater = Mock()
layer = ConversationVariablePersistenceLayer(updater)
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER)
layer.on_event(event)
updater.update.assert_not_called()
updater.flush.assert_not_called()
def test_skips_non_assigner_nodes():
updater = Mock()
layer = ConversationVariablePersistenceLayer(updater)
layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel))
event = _build_node_run_succeeded_event(node_type=NodeType.LLM)
layer.on_event(event)
updater.update.assert_not_called()
updater.flush.assert_not_called()
def test_skips_non_conversation_variables():
conversation_id = "conv-789"
non_conversation_variable = StringVariable(
id="var-3",
name="name",
value="updated",
selector=["environment", "name"],
)
process_data = common_helpers.set_updated_variables(
{}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)]
)
variable_pool = MockReadOnlyVariablePool()
updater = Mock()
layer = ConversationVariablePersistenceLayer(updater)
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
layer.on_event(event)
updater.update.assert_not_called()
updater.flush.assert_not_called()
@@ -1,4 +1,5 @@
import json
from collections.abc import Sequence
from time import time
from unittest.mock import Mock
@@ -67,8 +68,10 @@ class MockReadOnlyVariablePool:
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
self._variables = variables or {}
def get(self, node_id: str, variable_key: str) -> Segment | None:
value = self._variables.get((node_id, variable_key))
def get(self, selector: Sequence[str]) -> Segment | None:
if len(selector) < 2:
return None
value = self._variables.get((selector[0], selector[1]))
if value is None:
return None
mock_segment = Mock(spec=Segment)
@@ -1,14 +1,14 @@
import time
import uuid
from unittest import mock
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
from core.workflow.runtime import GraphRuntimeState, VariablePool
@@ -86,9 +86,6 @@ def test_overwrite_string_variable():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
@@ -104,20 +101,14 @@ def test_overwrite_string_variable():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=input_variable.value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == input_variable.value
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@@ -191,9 +182,6 @@ def test_append_variable_to_array():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
@@ -209,22 +197,14 @@ def test_append_variable_to_array():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=expected_value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == ["the first value", "the second value"]
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@@ -287,9 +267,6 @@ def test_clear_array():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
@@ -305,20 +282,14 @@ def test_clear_array():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=[],
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == []
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@@ -390,3 +390,42 @@ def test_remove_last_from_empty_array():
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == []
def test_node_factory_creates_variable_assigner_node():
graph_config = {
"edges": [],
"nodes": [
{
"data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []},
"id": "assigner",
},
],
}
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(conversation_id="conversation_id"),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(graph_config["nodes"][0])
assert isinstance(node, VariableAssignerNode)