|
9 | 9 | from enum import StrEnum, auto |
10 | 10 | from typing import Any, TypeAlias |
11 | 11 |
|
12 | | -from pydantic import BaseModel, Field |
| 12 | +from pydantic import BaseModel, Field, model_validator |
13 | 13 |
|
14 | 14 | from core.file import File |
15 | | -from core.variables import Segment, Variable |
| 15 | +from core.variables import Segment, SegmentType, Variable |
16 | 16 |
|
17 | 17 |
|
18 | 18 | class CommandType(StrEnum): |
@@ -51,8 +51,31 @@ class VariableUpdate(BaseModel): |
51 | 51 | """Represents a single variable update instruction.""" |
52 | 52 |
|
53 | 53 | selector: tuple[str, str] = Field(description="Variable selector (node_id, variable_name)") |
| 54 | + value_type: SegmentType = Field(description="Variable value type") |
54 | 55 | value: VariableUpdateValue = Field(description="New variable value") |
55 | 56 |
|
| 57 | + @model_validator(mode="after") |
| 58 | + def _validate_value_type(self) -> "VariableUpdate": |
| 59 | + value_type = self.value_type |
| 60 | + value = self.value |
| 61 | + |
| 62 | + if isinstance(value, Variable | Segment): |
| 63 | + if value.value_type != value_type: |
| 64 | + raise ValueError(f"value type mismatch: expected {value_type}, got {value.value_type}") |
| 65 | + return self |
| 66 | + |
| 67 | + if isinstance(value, File): |
| 68 | + if value_type != SegmentType.FILE: |
| 69 | + raise ValueError(f"value type mismatch: expected {value_type}, got {SegmentType.FILE}") |
| 70 | + return self |
| 71 | + |
| 72 | + casted_value = SegmentType.cast_value(value, value_type) |
| 73 | + if not value_type.is_valid(casted_value): |
| 74 | + raise ValueError(f"value type mismatch: expected {value_type}, got {type(value).__name__}") |
| 75 | + |
| 76 | + self.value = casted_value |
| 77 | + return self |
| 78 | + |
56 | 79 |
|
57 | 80 | class UpdateVariablesCommand(GraphEngineCommand): |
58 | 81 | """Command to update a group of variables in the variable pool.""" |
|
0 commit comments