From 858818b1f8e5b5cdf02bec9af5197379f4aa770d Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 1 Apr 2026 18:20:16 +0100 Subject: [PATCH] Fix deferred task resume failure when worker is older than server Workers on task-sdk < 1.2 crash with `KeyError: __var` when resuming deferred tasks whose trigger has fired on a 3.2 server. The server's handle_event_submit re-serializes next_kwargs with SDK serde (plain dicts), but old workers expect BaseSerialization format (__type/__var wrapping). submit_failure and scheduler timeout paths also write plain dicts that old workers cannot parse. Add a Cadwyn response converter on TIRunContext that deserializes SDK serde then re-serializes with BaseSerialization for old API versions. This catches all next_kwargs writers at the single API read point. Short-circuits when data is already in BaseSerialization format. --- airflow-core/.pre-commit-config.yaml | 1 + .../execution_api/versions/v2026_04_06.py | 28 ++++ airflow-core/src/airflow/models/trigger.py | 8 +- .../v2026_04_06/test_task_instances.py | 127 ++++++++++++++++++ 4 files changed, 161 insertions(+), 3 deletions(-) diff --git a/airflow-core/.pre-commit-config.yaml b/airflow-core/.pre-commit-config.yaml index 121b51d4e8b55..5d2e3e7fe2b9b 100644 --- a/airflow-core/.pre-commit-config.yaml +++ b/airflow-core/.pre-commit-config.yaml @@ -313,6 +313,7 @@ repos: ^src/airflow/api_fastapi/core_api/services/ui/task_group.py$| ^src/airflow/api_fastapi/execution_api/routes/hitl\.py$| ^src/airflow/api_fastapi/execution_api/routes/task_instances\.py$| + ^src/airflow/api_fastapi/execution_api/versions/v2026_04_06\.py$| ^src/airflow/api_fastapi/logging/decorators\.py$| ^src/airflow/assets/evaluation\.py$| ^src/airflow/assets/manager\.py$| diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py index 85ec1f2a60899..59e671f0a24a2 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_06.py @@ -118,6 +118,34 @@ class ModifyDeferredTaskKwargsToJsonValue(VersionChange): schema(TIDeferredStatePayload).field("next_kwargs").had(type=dict[str, Any]), ) + @convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type] + def convert_next_kwargs_to_base_serialization(response: ResponseInfo) -> None: # type: ignore[misc] + """ + Convert next_kwargs from SDK serde format to BaseSerialization format for old workers. + + Old workers (task-sdk < 1.2) only know BaseSerialization.deserialize(), which requires + dicts wrapped as {"__type": "dict", "__var": {...}}. SDK serde produces plain dicts that + BaseSerialization cannot parse, causing KeyError on __var. + + We must deserialize SDK serde first to recover native Python objects (datetime, + timedelta, etc.), then re-serialize with BaseSerialization so old workers get + proper typed values instead of raw {"__classname__": ...} dicts. + """ + next_kwargs = response.body.get("next_kwargs") + if next_kwargs is None: + return + + from airflow.sdk.serde import deserialize + from airflow.serialization.serialized_objects import BaseSerialization + + try: + plain = deserialize(next_kwargs) + except (ImportError, KeyError, AttributeError, TypeError): + # Already in BaseSerialization format (rolling upgrade, old data in DB) + return + + response.body["next_kwargs"] = BaseSerialization.serialize(plain) + class RemoveUpstreamMapIndexesField(VersionChange): """Remove upstream_map_indexes field from TIRunContext - now computed by Task SDK.""" diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py index da78eede343dd..d2c0fde3c8977 100644 --- a/airflow-core/src/airflow/models/trigger.py +++ b/airflow-core/src/airflow/models/trigger.py @@ -477,13 +477,15 @@ def handle_event_submit(event: TriggerEvent, *, task_instance: TaskInstance, ses next_kwargs = BaseSerialization.deserialize(next_kwargs_raw) - # Add event to the plain dict, then serialize everything together. This ensures that the event is properly - # nested inside __var__ in the final serde serialized structure. + # Add event to the plain dict, then serialize everything together so nested + # non-primitive values get proper serde encoding. if TYPE_CHECKING: assert isinstance(next_kwargs, dict) next_kwargs["event"] = event.payload - # re-serialize the entire dict using serde to ensure consistent structure + # Re-serialize using serde. The Execution API version converter + # (ModifyDeferredTaskKwargsToJsonValue) handles converting this to + # BaseSerialization format when serving old workers. task_instance.next_kwargs = serialize(next_kwargs) # Remove ourselves as its trigger diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_06/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_06/test_task_instances.py index 8117ac6b69c61..a914ac6c6e563 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_06/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_04_06/test_task_instances.py @@ -20,6 +20,7 @@ import pytest from airflow._shared.timezones import timezone +from airflow.serialization.serialized_objects import BaseSerialization from airflow.utils.state import DagRunState, State from tests_common.test_utils.db import clear_db_runs @@ -125,3 +126,129 @@ def test_old_version_preserves_real_start_date( assert response.status_code == 200 assert dag_run["start_date"] is not None, "start_date should not be None when DagRun has started" assert dag_run["start_date"] == TIMESTAMP.isoformat().replace("+00:00", "Z") + + +class TestNextKwargsBackwardCompat: + """Old workers only know BaseSerialization.deserialize -- SDK serde plain dicts cause KeyError.""" + + @pytest.fixture(autouse=True) + def _freeze_time(self, time_machine): + time_machine.move_to(TIMESTAMP_STR, tick=False) + + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + def test_old_version_gets_base_serialization_format(self, old_ver_client, session, create_task_instance): + """Old API version receives next_kwargs wrapped in __type/__var so BaseSerialization can parse it.""" + ti = create_task_instance( + task_id="test_next_kwargs_compat", + state=State.QUEUED, + session=session, + start_date=TIMESTAMP, + ) + # Store SDK serde format (plain dict) in DB -- this is what trigger.py handle_event_submit produces + ti.next_method = "execute_complete" + ti.next_kwargs = {"cheesecake": True, "event": "payload"} + session.commit() + + response = old_ver_client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY) + + assert response.status_code == 200 + next_kwargs = response.json()["next_kwargs"] + # Old workers call BaseSerialization.deserialize on this -- verify it works + result = BaseSerialization.deserialize(next_kwargs) + assert result == {"cheesecake": True, "event": "payload"} + + def test_old_version_deserializes_complex_types(self, old_ver_client, session, create_task_instance): + """Non-primitive values (datetime) must round-trip through serde -> BaseSerialization correctly.""" + from airflow.sdk.serde import serialize as serde_serialize + + original = {"event": TIMESTAMP, "simple": True} + # Store SDK serde format with a datetime -- this is what handle_event_submit produces + # when the trigger payload contains a datetime (e.g. DateTimeSensorAsync) + serde_encoded = serde_serialize(original) + + ti = create_task_instance( + task_id="test_next_kwargs_datetime", + state=State.QUEUED, + session=session, + start_date=TIMESTAMP, + ) + ti.next_method = "execute_complete" + ti.next_kwargs = serde_encoded + session.commit() + + response = old_ver_client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY) + + assert response.status_code == 200 + next_kwargs = response.json()["next_kwargs"] + result = BaseSerialization.deserialize(next_kwargs) + assert result["simple"] is True + # datetime must come back as a datetime, not a {"__classname__": ...} dict + assert result["event"] == TIMESTAMP + + def test_old_version_handles_already_base_serialization_in_db( + self, old_ver_client, session, create_task_instance + ): + """Rolling upgrade: DB still has BaseSerialization format from old handle_event_submit.""" + ti = create_task_instance( + task_id="test_next_kwargs_already_base", + state=State.QUEUED, + session=session, + start_date=TIMESTAMP, + ) + ti.next_method = "execute_complete" + # Pre-upgrade data: BaseSerialization format already in DB + ti.next_kwargs = BaseSerialization.serialize({"cheesecake": True, "event": "payload"}) + session.commit() + + response = old_ver_client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY) + + assert response.status_code == 200 + next_kwargs = response.json()["next_kwargs"] + # Should still be parseable by old workers + result = BaseSerialization.deserialize(next_kwargs) + assert result == {"cheesecake": True, "event": "payload"} + + def test_old_version_handles_submit_failure_plain_dict( + self, old_ver_client, session, create_task_instance + ): + """submit_failure and scheduler timeout write raw plain dicts -- converter must handle those too.""" + ti = create_task_instance( + task_id="test_next_kwargs_failure", + state=State.QUEUED, + session=session, + start_date=TIMESTAMP, + ) + ti.next_method = "__fail__" + # This is what submit_failure / scheduler timeout writes -- plain dict, no wrapping + ti.next_kwargs = {"error": "Trigger timeout"} + session.commit() + + response = old_ver_client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY) + + assert response.status_code == 200 + next_kwargs = response.json()["next_kwargs"] + result = BaseSerialization.deserialize(next_kwargs) + assert result == {"error": "Trigger timeout"} + + def test_head_version_returns_raw_serde_format(self, client, session, create_task_instance): + """Head API version returns next_kwargs as-is (SDK serde format).""" + ti = create_task_instance( + task_id="test_next_kwargs_head", + state=State.QUEUED, + session=session, + start_date=TIMESTAMP, + ) + ti.next_method = "execute_complete" + ti.next_kwargs = {"cheesecake": True, "event": "payload"} + session.commit() + + response = client.patch(f"/execution/task-instances/{ti.id}/run", json=RUN_PATCH_BODY) + + assert response.status_code == 200 + # Head version gets the plain dict directly -- no BaseSerialization wrapping + assert response.json()["next_kwargs"] == {"cheesecake": True, "event": "payload"}