Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ repos:
^src/airflow/api_fastapi/core_api/services/ui/task_group.py$|
^src/airflow/api_fastapi/execution_api/routes/hitl\.py$|
^src/airflow/api_fastapi/execution_api/routes/task_instances\.py$|
^src/airflow/api_fastapi/execution_api/versions/v2026_04_06\.py$|
^src/airflow/api_fastapi/logging/decorators\.py$|
^src/airflow/assets/evaluation\.py$|
^src/airflow/assets/manager\.py$|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,34 @@ class ModifyDeferredTaskKwargsToJsonValue(VersionChange):
schema(TIDeferredStatePayload).field("next_kwargs").had(type=dict[str, Any]),
)

@convert_response_to_previous_version_for(TIRunContext)  # type: ignore[arg-type]
def convert_next_kwargs_to_base_serialization(response: ResponseInfo) -> None:  # type: ignore[misc]
    """
    Re-encode ``next_kwargs`` from SDK serde into BaseSerialization for old workers.

    Workers running task-sdk < 1.2 can only call ``BaseSerialization.deserialize()``,
    and that expects dicts in the ``{"__type": "dict", "__var": {...}}`` envelope.
    SDK serde emits plain dicts, which BaseSerialization cannot parse (it fails
    with ``KeyError`` on ``__var``).

    The payload is first decoded with SDK serde so that native Python objects
    (datetime, timedelta, etc.) are recovered, and only then encoded again with
    BaseSerialization. Old workers therefore receive properly typed values rather
    than raw ``{"__classname__": ...}`` dicts.

    The response body is modified in place; nothing is returned. When there is no
    ``next_kwargs`` value, or SDK serde cannot decode it, the body is left untouched.
    """
    serde_payload = response.body.get("next_kwargs")
    if serde_payload is None:
        return

    # Imported lazily, and only once we know there is something to convert.
    from airflow.sdk.serde import deserialize
    from airflow.serialization.serialized_objects import BaseSerialization

    try:
        native_kwargs = deserialize(serde_payload)
    except (ImportError, KeyError, AttributeError, TypeError):
        # Already in BaseSerialization format (rolling upgrade, old data in DB)
        return

    legacy_payload = BaseSerialization.serialize(native_kwargs)
    response.body["next_kwargs"] = legacy_payload


class RemoveUpstreamMapIndexesField(VersionChange):
"""Remove upstream_map_indexes field from TIRunContext - now computed by Task SDK."""
Expand Down
8 changes: 5 additions & 3 deletions airflow-core/src/airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,13 +477,15 @@ def handle_event_submit(event: TriggerEvent, *, task_instance: TaskInstance, ses

next_kwargs = BaseSerialization.deserialize(next_kwargs_raw)

# Add event to the plain dict, then serialize everything together. This ensures that the event is properly
# nested inside __var__ in the final serde serialized structure.
# Add event to the plain dict, then serialize everything together so nested
# non-primitive values get proper serde encoding.
if TYPE_CHECKING:
assert isinstance(next_kwargs, dict)
next_kwargs["event"] = event.payload

# re-serialize the entire dict using serde to ensure consistent structure
# Re-serialize using serde. The Execution API version converter
# (ModifyDeferredTaskKwargsToJsonValue) handles converting this to
# BaseSerialization format when serving old workers.
task_instance.next_kwargs = serialize(next_kwargs)

# Remove ourselves as its trigger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest

from airflow._shared.timezones import timezone
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.utils.state import DagRunState, State

from tests_common.test_utils.db import clear_db_runs
Expand Down Expand Up @@ -125,3 +126,129 @@ def test_old_version_preserves_real_start_date(
assert response.status_code == 200
assert dag_run["start_date"] is not None, "start_date should not be None when DagRun has started"
assert dag_run["start_date"] == TIMESTAMP.isoformat().replace("+00:00", "Z")


class TestNextKwargsBackwardCompat:
    """Old workers only know BaseSerialization.deserialize -- SDK serde plain dicts cause KeyError."""

    @pytest.fixture(autouse=True)
    def _freeze_time(self, time_machine):
        # Pin the clock so every test in this class sees the same timestamp.
        time_machine.move_to(TIMESTAMP_STR, tick=False)

    def setup_method(self):
        clear_db_runs()

    def teardown_method(self):
        clear_db_runs()

    @staticmethod
    def _run_and_get_next_kwargs(
        api_client, session, create_task_instance, *, task_id, next_method, stored_kwargs
    ):
        """
        Create a queued TI carrying ``stored_kwargs``, PATCH it to running, return ``next_kwargs``.

        ``stored_kwargs`` is written to ``ti.next_kwargs`` verbatim, so the caller
        decides which on-disk format is being simulated. The response must be a 200.
        """
        task_instance = create_task_instance(
            task_id=task_id,
            state=State.QUEUED,
            session=session,
            start_date=TIMESTAMP,
        )
        task_instance.next_method = next_method
        task_instance.next_kwargs = stored_kwargs
        session.commit()

        response = api_client.patch(
            f"/execution/task-instances/{task_instance.id}/run", json=RUN_PATCH_BODY
        )

        assert response.status_code == 200
        return response.json()["next_kwargs"]

    def test_old_version_gets_base_serialization_format(self, old_ver_client, session, create_task_instance):
        """Old API version receives next_kwargs wrapped in __type/__var so BaseSerialization can parse it."""
        # SDK serde format (plain dict) in the DB -- what trigger.py handle_event_submit produces.
        wire_kwargs = self._run_and_get_next_kwargs(
            old_ver_client,
            session,
            create_task_instance,
            task_id="test_next_kwargs_compat",
            next_method="execute_complete",
            stored_kwargs={"cheesecake": True, "event": "payload"},
        )

        # Old workers feed this straight into BaseSerialization.deserialize -- it has to work.
        decoded = BaseSerialization.deserialize(wire_kwargs)
        assert decoded == {"cheesecake": True, "event": "payload"}

    def test_old_version_deserializes_complex_types(self, old_ver_client, session, create_task_instance):
        """Non-primitive values (datetime) must round-trip through serde -> BaseSerialization correctly."""
        from airflow.sdk.serde import serialize as serde_serialize

        # SDK serde format holding a datetime -- what handle_event_submit produces when
        # the trigger payload contains one (e.g. DateTimeSensorAsync).
        original = {"event": TIMESTAMP, "simple": True}
        serde_encoded = serde_serialize(original)

        wire_kwargs = self._run_and_get_next_kwargs(
            old_ver_client,
            session,
            create_task_instance,
            task_id="test_next_kwargs_datetime",
            next_method="execute_complete",
            stored_kwargs=serde_encoded,
        )

        decoded = BaseSerialization.deserialize(wire_kwargs)
        assert decoded["simple"] is True
        # The datetime has to come back as a datetime, not a {"__classname__": ...} dict.
        assert decoded["event"] == TIMESTAMP

    def test_old_version_handles_already_base_serialization_in_db(
        self, old_ver_client, session, create_task_instance
    ):
        """Rolling upgrade: DB still has BaseSerialization format from old handle_event_submit."""
        # Pre-upgrade data: the row already holds BaseSerialization format.
        wire_kwargs = self._run_and_get_next_kwargs(
            old_ver_client,
            session,
            create_task_instance,
            task_id="test_next_kwargs_already_base",
            next_method="execute_complete",
            stored_kwargs=BaseSerialization.serialize({"cheesecake": True, "event": "payload"}),
        )

        # Must remain parseable by old workers.
        decoded = BaseSerialization.deserialize(wire_kwargs)
        assert decoded == {"cheesecake": True, "event": "payload"}

    def test_old_version_handles_submit_failure_plain_dict(
        self, old_ver_client, session, create_task_instance
    ):
        """submit_failure and scheduler timeout write raw plain dicts -- converter must handle those too."""
        # What submit_failure / scheduler timeout writes -- a plain dict with no wrapping.
        wire_kwargs = self._run_and_get_next_kwargs(
            old_ver_client,
            session,
            create_task_instance,
            task_id="test_next_kwargs_failure",
            next_method="__fail__",
            stored_kwargs={"error": "Trigger timeout"},
        )

        decoded = BaseSerialization.deserialize(wire_kwargs)
        assert decoded == {"error": "Trigger timeout"}

    def test_head_version_returns_raw_serde_format(self, client, session, create_task_instance):
        """Head API version returns next_kwargs as-is (SDK serde format)."""
        wire_kwargs = self._run_and_get_next_kwargs(
            client,
            session,
            create_task_instance,
            task_id="test_next_kwargs_head",
            next_method="execute_complete",
            stored_kwargs={"cheesecake": True, "event": "payload"},
        )

        # Head version hands back the plain dict directly -- no BaseSerialization wrapping.
        assert wire_kwargs == {"cheesecake": True, "event": "payload"}
Loading