Skip to content

Commit 3edd525

Browse files
committed
- feat(graph-engine): enforce variable update value types with SegmentType validation and casting
- test(graph-engine): update VariableUpdate usages to include value_type in command tests
1 parent 89e644f commit 3edd525

3 files changed

Lines changed: 31 additions & 6 deletions

File tree

api/core/workflow/graph_engine/entities/commands.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from enum import StrEnum, auto
1010
from typing import Any, TypeAlias
1111

12-
from pydantic import BaseModel, Field
12+
from pydantic import BaseModel, Field, model_validator
1313

1414
from core.file import File
15-
from core.variables import Segment, Variable
15+
from core.variables import Segment, SegmentType, Variable
1616

1717

1818
class CommandType(StrEnum):
@@ -51,8 +51,31 @@ class VariableUpdate(BaseModel):
5151
"""Represents a single variable update instruction."""
5252

5353
selector: tuple[str, str] = Field(description="Variable selector (node_id, variable_name)")
54+
value_type: SegmentType = Field(description="Variable value type")
5455
value: VariableUpdateValue = Field(description="New variable value")
5556

57+
@model_validator(mode="after")
58+
def _validate_value_type(self) -> "VariableUpdate":
59+
value_type = self.value_type
60+
value = self.value
61+
62+
if isinstance(value, Variable | Segment):
63+
if value.value_type != value_type:
64+
raise ValueError(f"value type mismatch: expected {value_type}, got {value.value_type}")
65+
return self
66+
67+
if isinstance(value, File):
68+
if value_type != SegmentType.FILE:
69+
raise ValueError(f"value type mismatch: expected {value_type}, got {SegmentType.FILE}")
70+
return self
71+
72+
casted_value = SegmentType.cast_value(value, value_type)
73+
if not value_type.is_valid(casted_value):
74+
raise ValueError(f"value type mismatch: expected {value_type}, got {type(value).__name__}")
75+
76+
self.value = casted_value
77+
return self
78+
5679

5780
class UpdateVariablesCommand(GraphEngineCommand):
5881
"""Command to update a group of variables in the variable pool."""

api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
from unittest.mock import MagicMock
55

6+
from core.variables import SegmentType
67
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
78
from core.workflow.graph_engine.entities.commands import (
89
AbortCommand,
@@ -169,8 +170,8 @@ def test_fetch_commands_with_update_variables_command(self):
169170

170171
update_command = UpdateVariablesCommand(
171172
updates=[
172-
VariableUpdate(selector=["node1", "foo"], value="bar"),
173-
VariableUpdate(selector=["node2", "baz"], value=123),
173+
VariableUpdate(selector=["node1", "foo"], value_type=SegmentType.STRING, value="bar"),
174+
VariableUpdate(selector=["node2", "baz"], value_type=SegmentType.INTEGER, value=123),
174175
]
175176
)
176177
command_json = json.dumps(update_command.model_dump())

api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from unittest.mock import MagicMock
55

66
from core.app.entities.app_invoke_entities import InvokeFrom
7+
from core.variables import SegmentType
78
from core.workflow.entities.graph_init_params import GraphInitParams
89
from core.workflow.entities.pause_reason import SchedulingPause
910
from core.workflow.graph import Graph
@@ -231,8 +232,8 @@ def test_update_variables_command_updates_pool():
231232

232233
update_command = UpdateVariablesCommand(
233234
updates=[
234-
VariableUpdate(selector=["node1", "foo"], value="new value"),
235-
VariableUpdate(selector=["node2", "bar"], value=123),
235+
VariableUpdate(selector=["node1", "foo"], value_type=SegmentType.STRING, value="new value"),
236+
VariableUpdate(selector=["node2", "bar"], value_type=SegmentType.INTEGER, value=123),
236237
]
237238
)
238239
command_channel.send_command(update_command)

0 commit comments

Comments
 (0)