Enforce supervisor schema class name matches its type literal#66899
Conversation
c37f6da to
45c5d86
Compare
type literal
type literaltype literal
af83d31 to
e513938
Compare
|
Not sure if considered but besides making it static check at coding time, have we considered making this a runtime property handled by Pydantic? |
I will follow-up on this one. BTW, not sure if there's any reason to explictly specify the |
|
Tried it. Runtime-binding changes the schema and also breaks the tagged union itself, because Pydantic v2 needs the discriminator field to be a from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter
# --- Current pattern: Literal type field ---
class A1(BaseModel):
type: Literal["A1"] = "A1"
class A2(BaseModel):
type: Literal["A2"] = "A2"
print(A1.model_json_schema())
#. {'properties': {'type': {'const': 'A1', 'default': 'A1', 'title': 'Type', 'type': 'string'}}, 'title': 'A1', 'type': 'object'}
TypeAdapter(Annotated[Union[A1, A2], Field(discriminator="type")]) # OK
# --- Proposed pattern: runtime-bind in __init__ ---
class B(BaseModel):
type: str = "__undefined__"
def __init__(self, **kw):
super().__init__(**kw)
self.type = type(self).__name__
class B1(B): pass
class B2(B): pass
print(B1.model_json_schema())
# {'properties': {'type': {'default': '__undefined__', 'title': 'Type', 'type': 'string'}}, 'title': 'B1', 'type': 'object'}
print(B1().type) # 'B1' (instance state only; not reflected in schema)
TypeAdapter(Annotated[Union[B1, B2], Field(discriminator="type")])
# PydanticUserError: Model 'B1' needs field 'type' to be of type `Literal`So |
Why
The six supervisor discriminated unions (
ToTask,ToSupervisor,ToManager,ToDagProcessor,ToTriggerRunner,ToTriggerSupervisor) need every member's class__name__to equal itstype: Literal[...]value soCommsDecoderroutes wire frames to the right class — but nothing enforced that invariant, and two members had silently drifted onmain.What
task-sdk/tests/task_sdk/execution_time/test_supervisor_schemas_name_type_sync.py— a parametrized unit test (one case per union) that walks every member and asserts class__name__equals its single-valuetypeLiteral. Catches drift, a missingtypefield, and multi-value Literals. Runs as part of the task-sdk test suite.task-sdk/src/airflow/sdk/execution_time/comms.py:XComCountResponse:Literal["XComLengthResponse"]→Literal["XComCountResponse"]GetXComCount:Literal["GetNumberXComs"]→Literal["GetXComCount"]task-sdk/tests/task_sdk/execution_time/test_supervisor.pythat pinned the old wire string.Verfication
Was generative AI tooling used to co-author this PR?