mirror of
https://github.com/YFGaia/dify-plus.git
synced 2026-06-12 18:11:42 +08:00
feat: Add conversation variable persistence layer (#30531)
This commit is contained in:
@@ -0,0 +1,144 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.variables import StringVariable
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
class MockReadOnlyVariablePool:
|
||||
def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None:
|
||||
self._variables = variables or {}
|
||||
|
||||
def get(self, selector: Sequence[str]) -> Segment | None:
|
||||
if len(selector) < 2:
|
||||
return None
|
||||
return self._variables.get((selector[0], selector[1]))
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
|
||||
|
||||
def get_by_prefix(self, prefix: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == prefix}
|
||||
|
||||
|
||||
def _build_graph_runtime_state(
|
||||
variable_pool: MockReadOnlyVariablePool,
|
||||
conversation_id: str | None = None,
|
||||
) -> ReadOnlyGraphRuntimeState:
|
||||
graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState)
|
||||
graph_runtime_state.variable_pool = variable_pool
|
||||
graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view()
|
||||
return graph_runtime_state
|
||||
|
||||
|
||||
def _build_node_run_succeeded_event(
|
||||
*,
|
||||
node_type: NodeType,
|
||||
outputs: dict[str, object] | None = None,
|
||||
process_data: dict[str, object] | None = None,
|
||||
) -> NodeRunSucceededEvent:
|
||||
return NodeRunSucceededEvent(
|
||||
id="node-exec-id",
|
||||
node_id="assigner",
|
||||
node_type=node_type,
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs or {},
|
||||
process_data=process_data or {},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_persists_conversation_variables_from_assigner_output():
|
||||
conversation_id = "conv-123"
|
||||
variable = StringVariable(
|
||||
id="var-1",
|
||||
name="name",
|
||||
value="updated",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
|
||||
)
|
||||
process_data = common_helpers.set_updated_variables(
|
||||
{}, [common_helpers.variable_to_processed_data(variable.selector, variable)]
|
||||
)
|
||||
|
||||
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
|
||||
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable)
|
||||
updater.flush.assert_called_once()
|
||||
|
||||
|
||||
def test_skips_when_outputs_missing():
|
||||
conversation_id = "conv-456"
|
||||
variable = StringVariable(
|
||||
id="var-2",
|
||||
name="name",
|
||||
value="updated",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
|
||||
)
|
||||
|
||||
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
|
||||
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_not_called()
|
||||
updater.flush.assert_not_called()
|
||||
|
||||
|
||||
def test_skips_non_assigner_nodes():
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.LLM)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_not_called()
|
||||
updater.flush.assert_not_called()
|
||||
|
||||
|
||||
def test_skips_non_conversation_variables():
|
||||
conversation_id = "conv-789"
|
||||
non_conversation_variable = StringVariable(
|
||||
id="var-3",
|
||||
name="name",
|
||||
value="updated",
|
||||
selector=["environment", "name"],
|
||||
)
|
||||
process_data = common_helpers.set_updated_variables(
|
||||
{}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)]
|
||||
)
|
||||
|
||||
variable_pool = MockReadOnlyVariablePool()
|
||||
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_not_called()
|
||||
updater.flush.assert_not_called()
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from time import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
@@ -67,8 +68,10 @@ class MockReadOnlyVariablePool:
|
||||
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
|
||||
self._variables = variables or {}
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
value = self._variables.get((node_id, variable_key))
|
||||
def get(self, selector: Sequence[str]) -> Segment | None:
|
||||
if len(selector) < 2:
|
||||
return None
|
||||
value = self._variables.get((selector[0], selector[1]))
|
||||
if value is None:
|
||||
return None
|
||||
mock_segment = Mock(spec=Segment)
|
||||
|
||||
+20
-49
@@ -1,14 +1,14 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import ArrayStringVariable, StringVariable
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
@@ -86,9 +86,6 @@ def test_overwrite_string_variable():
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@@ -104,20 +101,14 @@ def test_overwrite_string_variable():
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
list(node.run())
|
||||
expected_var = StringVariable(
|
||||
id=conversation_variable.id,
|
||||
name=conversation_variable.name,
|
||||
description=conversation_variable.description,
|
||||
selector=conversation_variable.selector,
|
||||
value_type=conversation_variable.value_type,
|
||||
value=input_variable.value,
|
||||
)
|
||||
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
|
||||
mock_conv_var_updater.flush.assert_called_once()
|
||||
events = list(node.run())
|
||||
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
|
||||
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
|
||||
assert updated_variables is not None
|
||||
assert updated_variables[0].name == conversation_variable.name
|
||||
assert updated_variables[0].new_value == input_variable.value
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
@@ -191,9 +182,6 @@ def test_append_variable_to_array():
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@@ -209,22 +197,14 @@ def test_append_variable_to_array():
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
list(node.run())
|
||||
expected_value = list(conversation_variable.value)
|
||||
expected_value.append(input_variable.value)
|
||||
expected_var = ArrayStringVariable(
|
||||
id=conversation_variable.id,
|
||||
name=conversation_variable.name,
|
||||
description=conversation_variable.description,
|
||||
selector=conversation_variable.selector,
|
||||
value_type=conversation_variable.value_type,
|
||||
value=expected_value,
|
||||
)
|
||||
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
|
||||
mock_conv_var_updater.flush.assert_called_once()
|
||||
events = list(node.run())
|
||||
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
|
||||
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
|
||||
assert updated_variables is not None
|
||||
assert updated_variables[0].name == conversation_variable.name
|
||||
assert updated_variables[0].new_value == ["the first value", "the second value"]
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
@@ -287,9 +267,6 @@ def test_clear_array():
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@@ -305,20 +282,14 @@ def test_clear_array():
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
list(node.run())
|
||||
expected_var = ArrayStringVariable(
|
||||
id=conversation_variable.id,
|
||||
name=conversation_variable.name,
|
||||
description=conversation_variable.description,
|
||||
selector=conversation_variable.selector,
|
||||
value_type=conversation_variable.value_type,
|
||||
value=[],
|
||||
)
|
||||
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
|
||||
mock_conv_var_updater.flush.assert_called_once()
|
||||
events = list(node.run())
|
||||
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
|
||||
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
|
||||
assert updated_variables is not None
|
||||
assert updated_variables[0].name == conversation_variable.name
|
||||
assert updated_variables[0].new_value == []
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
|
||||
+39
@@ -390,3 +390,42 @@ def test_remove_last_from_empty_array():
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == []
|
||||
|
||||
|
||||
def test_node_factory_creates_variable_assigner_node():
|
||||
graph_config = {
|
||||
"edges": [],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(conversation_id="conversation_id"),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
node = node_factory.create_node(graph_config["nodes"][0])
|
||||
|
||||
assert isinstance(node, VariableAssignerNode)
|
||||
|
||||
Reference in New Issue
Block a user