refactor(variables): clarify base vs union type naming (#30634)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
-LAN-
2026-01-13 22:39:34 +08:00
committed by GitHub
parent 91da784f84
commit 206706987d
22 changed files with 124 additions and 125 deletions
@@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory from core.db.session_factory import session_factory
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion from core.variables.variables import Variable
from core.workflow.enums import WorkflowType from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs, system_variables=system_inputs,
user_inputs=inputs, user_inputs=inputs,
environment_variables=self._workflow.environment_variables, environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`, # Based on the definition of `Variable`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
) )
@@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager, trace_manager=app_generate_entity.trace_manager,
) )
def _initialize_conversation_variables(self) -> list[VariableUnion]: def _initialize_conversation_variables(self) -> list[Variable]:
""" """
Initialize conversation variables for the current conversation. Initialize conversation variables for the current conversation.
@@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables] conversation_variables = [var.to_variable() for var in existing_variables]
session.commit() session.commit()
return cast(list[VariableUnion], conversation_variables) return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]: def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
""" """
@@ -1,6 +1,6 @@
import logging import logging
from core.variables import Variable from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType from core.workflow.enums import NodeType
@@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID: if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue continue
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable): if not isinstance(variable, VariableBase):
logger.warning( logger.warning(
"Conversation variable not found in variable pool. selector=%s", "Conversation variable not found in variable pool. selector=%s",
selector, selector,
+2
View File
@@ -30,6 +30,7 @@ from .variables import (
SecretVariable, SecretVariable,
StringVariable, StringVariable,
Variable, Variable,
VariableBase,
) )
__all__ = [ __all__ = [
@@ -62,4 +63,5 @@ __all__ = [
"StringSegment", "StringSegment",
"StringVariable", "StringVariable",
"Variable", "Variable",
"VariableBase",
] ]
+1 -1
View File
@@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class. # - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except: # - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool. # - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`. # - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[ SegmentUnion: TypeAlias = Annotated[
( (
Annotated[NoneSegment, Tag(SegmentType.NONE)] Annotated[NoneSegment, Tag(SegmentType.NONE)]
+14 -14
View File
@@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType from .types import SegmentType
class Variable(Segment): class VariableBase(Segment):
""" """
A variable is a segment that has a name. A variable is a segment that has a name.
@@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list) selector: Sequence[str] = Field(default_factory=list)
class StringVariable(StringSegment, Variable): class StringVariable(StringSegment, VariableBase):
pass pass
class FloatVariable(FloatSegment, Variable): class FloatVariable(FloatSegment, VariableBase):
pass pass
class IntegerVariable(IntegerSegment, Variable): class IntegerVariable(IntegerSegment, VariableBase):
pass pass
class ObjectVariable(ObjectSegment, Variable): class ObjectVariable(ObjectSegment, VariableBase):
pass pass
class ArrayVariable(ArraySegment, Variable): class ArrayVariable(ArraySegment, VariableBase):
pass pass
@@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value) return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable): class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE value_type: SegmentType = SegmentType.NONE
value: None = None value: None = None
class FileVariable(FileSegment, Variable): class FileVariable(FileSegment, VariableBase):
pass pass
class BooleanVariable(BooleanSegment, Variable): class BooleanVariable(BooleanSegment, VariableBase):
pass pass
@@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any value: Any
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. # The `Variable` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required. # Use `VariableBase` for type hinting when serialization is not required.
# #
# Note: # Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class. # - All variants in `Variable` must inherit from the `VariableBase` class.
# - The union must include all non-abstract subclasses of `Segment`, except: # - The union must include all non-abstract subclasses of `VariableBase`.
VariableUnion: TypeAlias = Annotated[ Variable: TypeAlias = Annotated[
( (
Annotated[NoneVariable, Tag(SegmentType.NONE)] Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)] | Annotated[StringVariable, Tag(SegmentType.STRING)]
@@ -1,7 +1,7 @@
import abc import abc
from typing import Protocol from typing import Protocol
from core.variables import Variable from core.variables import VariableBase
class ConversationVariableUpdater(Protocol): class ConversationVariableUpdater(Protocol):
@@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
""" """
@abc.abstractmethod @abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable"): def update(self, conversation_id: str, variable: "VariableBase"):
""" """
Updates the value of the specified conversation variable in the underlying storage. Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value. :param variable: The `VariableBase` instance containing the updated value.
""" """
pass pass
@@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion from core.variables.variables import Variable
class CommandType(StrEnum): class CommandType(StrEnum):
@@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel): class VariableUpdate(BaseModel):
"""Represents a single variable update instruction.""" """Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value") value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand): class UpdateVariablesCommand(GraphEngineCommand):
@@ -11,7 +11,7 @@ from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import ( from core.workflow.enums import (
NodeExecutionType, NodeExecutionType,
@@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime, datetime,
list[GraphNodeEventBase], list[GraphNodeEventBase],
object | None, object | None,
dict[str, VariableUnion], dict[str, Variable],
LLMUsage, LLMUsage,
] ]
], ],
@@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
item: object, item: object,
flask_app: Flask, flask_app: Flask,
context_vars: contextvars.Context, context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]: ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results.""" """Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars): with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None) iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping return variable_mapping
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]: def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None: def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -73,7 +73,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable): if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found") raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode: match self.node_data.write_mode:
@@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part # ==================== Validation Part
# Check if variable exists # Check if variable exists
if not isinstance(variable, Variable): if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector) raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported # Check if operation is supported
@@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
for selector in updated_variable_selectors: for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable): if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector) raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value process_data[variable.name] = variable.value
@@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item( def _handle_item(
self, self,
*, *,
variable: Variable, variable: VariableBase,
operation: Operation, operation: Operation,
value: Any, value: Any,
): ):
+11 -11
View File
@@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import ( from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID,
@@ -32,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary. # The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one. # elements of the selector except the first one.
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field( variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping", description="Variables mapping",
default=defaultdict(dict), default=defaultdict(dict),
) )
@@ -46,13 +46,13 @@ class VariablePool(BaseModel):
description="System variables", description="System variables",
default_factory=SystemVariable.empty, default_factory=SystemVariable.empty,
) )
environment_variables: Sequence[VariableUnion] = Field( environment_variables: Sequence[Variable] = Field(
description="Environment variables.", description="Environment variables.",
default_factory=list[VariableUnion], default_factory=list[Variable],
) )
conversation_variables: Sequence[VariableUnion] = Field( conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.", description="Conversation variables.",
default_factory=list[VariableUnion], default_factory=list[Variable],
) )
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.", description="RAG pipeline variables.",
@@ -105,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements" f"got {len(selector)} elements"
) )
if isinstance(value, Variable): if isinstance(value, VariableBase):
variable = value variable = value
elif isinstance(value, Segment): elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector) variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@@ -114,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector) variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector) node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`, # Based on the definition of `Variable`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `VariableBase` instances can be safely used as `Variable` since they are compatible.
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable) self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod @classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
+4 -4
View File
@@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Protocol from typing import Any, Protocol
from core.variables import Variable from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
@@ -26,7 +26,7 @@ class VariableLoader(Protocol):
""" """
@abc.abstractmethod @abc.abstractmethod
def load_variables(self, selectors: list[list[str]]) -> list[Variable]: def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty, """Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list. this method should return an empty list.
@@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements: :param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID, - the first element is the node ID,
- the second element is the variable name. - the second element is the variable name.
:return: a list of Variable objects that match the provided selectors. :return: a list of VariableBase objects that match the provided selectors.
""" """
pass pass
@@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed. Serves as a placeholder when no variable loading is needed.
""" """
def load_variables(self, selectors: list[list[str]]) -> list[Variable]: def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return [] return []
+10 -10
View File
@@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable, ObjectVariable,
SecretVariable, SecretVariable,
StringVariable, StringVariable,
Variable, VariableBase,
) )
from core.workflow.constants import ( from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID,
@@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
} }
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"): if not mapping.get("name"):
raise VariableError("missing name") raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]]) return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"): if not mapping.get("name"):
raise VariableError("missing name") raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"): if not mapping.get("variable"):
raise VariableError("missing variable") raise VariableError("missing variable")
return mapping["variable"] return mapping["variable"]
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
""" """
This factory function is used to create the environment variable or the conversation variable, This factory function is used to create the environment variable or the conversation variable,
not support the File type. not support the File type.
@@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None: if (value := mapping.get("value")) is None:
raise VariableError("missing value") raise VariableError("missing value")
result: Variable result: VariableBase
match value_type: match value_type:
case SegmentType.STRING: case SegmentType.STRING:
result = StringVariable.model_validate(mapping) result = StringVariable.model_validate(mapping)
@@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector: if not result.selector:
result = result.model_copy(update={"selector": selector}) result = result.model_copy(update={"selector": selector})
return cast(Variable, result) return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment: def build_segment(value: Any, /) -> Segment:
@@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None, id: str | None = None,
name: str | None = None, name: str | None = None,
description: str = "", description: str = "",
) -> Variable: ) -> VariableBase:
if isinstance(segment, Variable): if isinstance(segment, VariableBase):
return segment return segment
name = name or selector[-1] name = name or selector[-1]
id = id or str(uuid4()) id = id or str(uuid4())
@@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast( return cast(
Variable, VariableBase,
variable_class( variable_class(
id=id, id=id,
name=name, name=name,
+2 -2
View File
@@ -1,7 +1,7 @@
from flask_restx import fields from flask_restx import fields
from core.helper import encrypter from core.helper import encrypter
from core.variables import SecretVariable, SegmentType, Variable from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields from fields.member_fields import simple_account_fields
from libs.helper import TimestampField from libs.helper import TimestampField
@@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value, "value_type": value.value_type.value,
"description": value.description, "description": value.description,
} }
if isinstance(value, Variable): if isinstance(value, VariableBase):
return { return {
"id": value.id, "id": value.id,
"name": value.name, "name": value.name,
+30 -32
View File
@@ -1,11 +1,9 @@
from __future__ import annotations
import json import json
import logging import logging
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import TYPE_CHECKING, Any, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
import sqlalchemy as sa import sqlalchemy as sa
@@ -46,7 +44,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory from factories import variable_factory
from libs import helper from libs import helper
@@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline" RAG_PIPELINE = "rag-pipeline"
@classmethod @classmethod
def value_of(cls, value: str) -> WorkflowType: def value_of(cls, value: str) -> "WorkflowType":
""" """
Get value of given mode. Get value of given mode.
@@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}") raise ValueError(f"invalid workflow type value {value}")
@classmethod @classmethod
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType: def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
""" """
Get workflow type from app mode. Get workflow type from app mode.
@@ -178,12 +176,12 @@ class Workflow(Base): # bug
graph: str, graph: str,
features: str, features: str,
created_by: str, created_by: str,
environment_variables: Sequence[Variable], environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict], rag_pipeline_variables: list[dict],
marked_name: str = "", marked_name: str = "",
marked_comment: str = "", marked_comment: str = "",
) -> Workflow: ) -> "Workflow":
workflow = Workflow() workflow = Workflow()
workflow.id = str(uuid4()) workflow.id = str(uuid4())
workflow.tenant_id = tenant_id workflow.tenant_id = tenant_id
@@ -447,7 +445,7 @@ class Workflow(Base): # bug
# decrypt secret variables value # decrypt secret variables value
def decrypt_func( def decrypt_func(
var: Variable, var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable: ) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable): if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@@ -463,7 +461,7 @@ class Workflow(Base): # bug
return decrypted_results return decrypted_results
@environment_variables.setter @environment_variables.setter
def environment_variables(self, value: Sequence[Variable]): def environment_variables(self, value: Sequence[VariableBase]):
if not value: if not value:
self._environment_variables = "{}" self._environment_variables = "{}"
return return
@@ -487,7 +485,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name}) value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value # encrypt secret variables value
def encrypt_func(var: Variable) -> Variable: def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable): if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)}) return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else: else:
@@ -517,7 +515,7 @@ class Workflow(Base): # bug
return result return result
@property @property
def conversation_variables(self) -> Sequence[Variable]: def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created. # TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None: if self._conversation_variables is None:
self._conversation_variables = "{}" self._conversation_variables = "{}"
@@ -527,7 +525,7 @@ class Workflow(Base): # bug
return results return results
@conversation_variables.setter @conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]): def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps( self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value}, {var.name: var.model_dump() for var in value},
ensure_ascii=False, ensure_ascii=False,
@@ -622,7 +620,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime) finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True) exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
pause: Mapped[WorkflowPause | None] = orm.relationship( pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause", "WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)", primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False, uselist=False,
@@ -692,7 +690,7 @@ class WorkflowRun(Base):
} }
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun: def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls( return cls(
id=data.get("id"), id=data.get("id"),
tenant_id=data.get("tenant_id"), tenant_id=data.get("tenant_id"),
@@ -844,7 +842,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID) created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime) finished_at: Mapped[datetime | None] = mapped_column(DateTime)
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship( offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload", "WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)", primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True, uselist=True,
@@ -854,13 +852,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod @staticmethod
def preload_offload_data( def preload_offload_data(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
): ):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data)) return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod @staticmethod
def preload_offload_data_and_files( def preload_offload_data_and_files(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel], query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
): ):
return query.options( return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options( orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@@ -935,7 +933,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
) )
return extras return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None: def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None) return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property @property
@@ -1049,7 +1047,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data", back_populates="offload_data",
) )
file: Mapped[UploadFile | None] = orm.relationship( file: Mapped[Optional["UploadFile"]] = orm.relationship(
foreign_keys=[file_id], foreign_keys=[file_id],
lazy="raise", lazy="raise",
uselist=False, uselist=False,
@@ -1067,7 +1065,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app" INSTALLED_APP = "installed-app"
@classmethod @classmethod
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom: def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
""" """
Get value of given mode. Get value of given mode.
@@ -1184,7 +1182,7 @@ class ConversationVariable(TypeBase):
) )
@classmethod @classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable: def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls( obj = cls(
id=variable.id, id=variable.id,
app_id=app_id, app_id=app_id,
@@ -1193,7 +1191,7 @@ class ConversationVariable(TypeBase):
) )
return obj return obj
def to_variable(self) -> Variable: def to_variable(self) -> VariableBase:
mapping = json.loads(self.data) mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping) return variable_factory.build_conversation_variable_from_mapping(mapping)
@@ -1337,7 +1335,7 @@ class WorkflowDraftVariable(Base):
) )
# Relationship to WorkflowDraftVariableFile # Relationship to WorkflowDraftVariableFile
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship( variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id], foreign_keys=[file_id],
lazy="raise", lazy="raise",
uselist=False, uselist=False,
@@ -1507,7 +1505,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None, node_execution_id: str | None,
description: str = "", description: str = "",
file_id: str | None = None, file_id: str | None = None,
) -> WorkflowDraftVariable: ) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable() variable = WorkflowDraftVariable()
variable.id = str(uuid4()) variable.id = str(uuid4())
variable.created_at = naive_utc_now() variable.created_at = naive_utc_now()
@@ -1530,7 +1528,7 @@ class WorkflowDraftVariable(Base):
name: str, name: str,
value: Segment, value: Segment,
description: str = "", description: str = "",
) -> WorkflowDraftVariable: ) -> "WorkflowDraftVariable":
variable = cls._new( variable = cls._new(
app_id=app_id, app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID, node_id=CONVERSATION_VARIABLE_NODE_ID,
@@ -1551,7 +1549,7 @@ class WorkflowDraftVariable(Base):
value: Segment, value: Segment,
node_execution_id: str, node_execution_id: str,
editable: bool = False, editable: bool = False,
) -> WorkflowDraftVariable: ) -> "WorkflowDraftVariable":
variable = cls._new( variable = cls._new(
app_id=app_id, app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID, node_id=SYSTEM_VARIABLE_NODE_ID,
@@ -1574,7 +1572,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True, visible: bool = True,
editable: bool = True, editable: bool = True,
file_id: str | None = None, file_id: str | None = None,
) -> WorkflowDraftVariable: ) -> "WorkflowDraftVariable":
variable = cls._new( variable = cls._new(
app_id=app_id, app_id=app_id,
node_id=node_id, node_id=node_id,
@@ -1670,7 +1668,7 @@ class WorkflowDraftVariableFile(Base):
) )
# Relationship to UploadFile # Relationship to UploadFile
upload_file: Mapped[UploadFile] = orm.relationship( upload_file: Mapped["UploadFile"] = orm.relationship(
foreign_keys=[upload_file_id], foreign_keys=[upload_file_id],
lazy="raise", lazy="raise",
uselist=False, uselist=False,
@@ -1737,7 +1735,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False) state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun # Relationship to WorkflowRun
workflow_run: Mapped[WorkflowRun] = orm.relationship( workflow_run: Mapped["WorkflowRun"] = orm.relationship(
foreign_keys=[workflow_run_id], foreign_keys=[workflow_run_id],
# require explicit preloading. # require explicit preloading.
lazy="raise", lazy="raise",
@@ -1793,7 +1791,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
) )
@classmethod @classmethod
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason: def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired): if isinstance(pause_reason, HumanInputRequired):
return cls( return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
@@ -1,7 +1,7 @@
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable from core.variables.variables import VariableBase
from models import ConversationVariable from models import ConversationVariable
@@ -13,7 +13,7 @@ class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None: def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None: def update(self, conversation_id: str, variable: VariableBase) -> None:
stmt = select(ConversationVariable).where( stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
) )
+3 -3
View File
@@ -36,7 +36,7 @@ from core.rag.entities.event import (
) )
from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import ( from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution, WorkflowNodeExecution,
WorkflowNodeExecutionStatus, WorkflowNodeExecutionStatus,
@@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict, graph: dict,
unique_hash: str | None, unique_hash: str | None,
account: Account, account: Account,
environment_variables: Sequence[Variable], environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list, rag_pipeline_variables: list,
) -> Workflow: ) -> Workflow:
""" """
@@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File from core.file.models import File
from core.variables import Segment, StringSegment, Variable from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import ( from core.variables.segments import (
ArrayFileSegment, ArrayFileSegment,
@@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded. # Application ID for which variables are being loaded.
_app_id: str _app_id: str
_tenant_id: str _tenant_id: str
_fallback_variables: Sequence[Variable] _fallback_variables: Sequence[VariableBase]
def __init__( def __init__(
self, self,
engine: Engine, engine: Engine,
app_id: str, app_id: str,
tenant_id: str, tenant_id: str,
fallback_variables: Sequence[Variable] | None = None, fallback_variables: Sequence[VariableBase] | None = None,
): ):
self._engine = engine self._engine = engine
self._app_id = app_id self._app_id = app_id
@@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1]) return (selector[0], selector[1])
def load_variables(self, selectors: list[list[str]]) -> list[Variable]: def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors: if not selectors:
return [] return []
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance. # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
variable_by_selector: dict[tuple[str, str], Variable] = {} variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session: with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session) srv = WorkflowDraftVariableService(session)
@@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values()) return list(variable_by_selector.values())
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]: def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable` # This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it. # and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability. # Ideally, these should be co-located for better maintainability.
+8 -8
View File
@@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File from core.file import File
from core.repositories import DifyCoreRepositoryFactory from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable from core.variables import VariableBase
from core.variables.variables import VariableUnion from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
@@ -198,8 +198,8 @@ class WorkflowService:
features: dict, features: dict,
unique_hash: str | None, unique_hash: str | None,
account: Account, account: Account,
environment_variables: Sequence[Variable], environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[VariableBase],
) -> Workflow: ) -> Workflow:
""" """
Sync draft workflow Sync draft workflow
@@ -1044,7 +1044,7 @@ def _setup_variable_pool(
workflow: Workflow, workflow: Workflow,
node_type: NodeType, node_type: NodeType,
conversation_id: str, conversation_id: str,
conversation_variables: list[Variable], conversation_variables: list[VariableBase],
): ):
# Only inject system variables for START node type. # Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node: if node_type == NodeType.START or node_type.is_trigger_node:
@@ -1070,9 +1070,9 @@ def _setup_variable_pool(
system_variables=system_variable, system_variables=system_variable,
user_inputs=user_inputs, user_inputs=user_inputs,
environment_variables=workflow.environment_variables, environment_variables=workflow.environment_variables,
# Based on the definition of `VariableUnion`, # Based on the definition of `Variable`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), # conversation_variables=cast(list[Variable], conversation_variables), #
) )
return variable_pool return variable_pool
@@ -35,7 +35,6 @@ from core.variables.variables import (
SecretVariable, SecretVariable,
StringVariable, StringVariable,
Variable, Variable,
VariableUnion,
) )
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
@@ -96,7 +95,7 @@ class _Segments(BaseModel):
class _Variables(BaseModel): class _Variables(BaseModel):
variables: list[VariableUnion] variables: list[Variable]
def create_test_file( def create_test_file(
@@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad:
# Create one instance of each variable type # Create one instance of each variable type
test_file = create_test_file() test_file = create_test_file()
all_variables: list[VariableUnion] = [ all_variables: list[Variable] = [
NoneVariable(name="none_var"), NoneVariable(name="none_var"),
StringVariable(value="test string", name="string_var"), StringVariable(value="test string", name="string_var"),
IntegerVariable(value=42, name="int_var"), IntegerVariable(value=42, name="int_var"),
@@ -11,7 +11,7 @@ from core.variables import (
SegmentType, SegmentType,
StringVariable, StringVariable,
) )
from core.variables.variables import Variable from core.variables.variables import VariableBase
def test_frozen_variables(): def test_frozen_variables():
@@ -76,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object(): def test_variable_to_object():
var: Variable = StringVariable(name="text", value="text") var: VariableBase = StringVariable(name="text", value="text")
assert var.to_object() == "text" assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42) var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42 assert var.to_object() == 42
@@ -24,7 +24,7 @@ from core.variables.variables import (
IntegerVariable, IntegerVariable,
ObjectVariable, ObjectVariable,
StringVariable, StringVariable,
VariableUnion, Variable,
) )
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.runtime import VariablePool from core.workflow.runtime import VariablePool
@@ -160,7 +160,7 @@ class TestVariablePoolSerialization:
) )
# Create environment variables with all types including ArrayFileVariable # Create environment variables with all types including ArrayFileVariable
env_vars: list[VariableUnion] = [ env_vars: list[Variable] = [
StringVariable( StringVariable(
id="env_string_id", id="env_string_id",
name="env_string", name="env_string",
@@ -182,7 +182,7 @@ class TestVariablePoolSerialization:
] ]
# Create conversation variables with complex data # Create conversation variables with complex data
conv_vars: list[VariableUnion] = [ conv_vars: list[Variable] = [
StringVariable( StringVariable(
id="conv_string_id", id="conv_string_id",
name="conv_string", name="conv_string",