Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
e3bad52
Use serde for next_kwargs for deferred tasks
amoghrajesh Dec 22, 2025
8ad4b7c
remove field serializer actually
amoghrajesh Dec 22, 2025
88aa2c5
remove field serializer actually
amoghrajesh Dec 22, 2025
fb077d9
remove field serializer actually
amoghrajesh Dec 22, 2025
13ffed4
remove field serializer actually
amoghrajesh Dec 22, 2025
fc88026
remove field serializer actually
amoghrajesh Dec 22, 2025
df6b61f
remove field serializer actually
amoghrajesh Dec 22, 2025
81d4218
remove field serializer actually
amoghrajesh Dec 22, 2025
75ff08b
handling compat: old worker, new api server
amoghrajesh Dec 23, 2025
4ecc73c
why did I remove these?
amoghrajesh Dec 23, 2025
2180e1b
move over the kwargs to JsonValue
amoghrajesh Dec 23, 2025
936b402
move over the kwargs to JsonValue
amoghrajesh Dec 23, 2025
43a23ab
move over the kwargs to JsonValue
amoghrajesh Dec 23, 2025
4784c19
move over the kwargs to JsonValue
amoghrajesh Dec 23, 2025
1989dd9
move over the kwargs to JsonValue
amoghrajesh Dec 23, 2025
bd77075
move over the kwargs to JsonValue
amoghrajesh Dec 23, 2025
377c7ac
fixing unit tests
amoghrajesh Dec 23, 2025
0647c5c
no need for the cast
amoghrajesh Dec 23, 2025
e26e20f
no need for the cast
amoghrajesh Dec 23, 2025
41d1f2d
fixing unit tests
amoghrajesh Dec 23, 2025
5085323
fixing dag.test command
amoghrajesh Dec 23, 2025
a797be0
fixing dag.test command
amoghrajesh Dec 23, 2025
82e94f1
fixing dag.test command
amoghrajesh Dec 23, 2025
177b0a4
fixing dag.test command
amoghrajesh Dec 23, 2025
28f186f
comments from kaxil
amoghrajesh Dec 24, 2025
bf18255
fixing tests
amoghrajesh Dec 24, 2025
a09ea92
fixing tests
amoghrajesh Dec 24, 2025
8c14feb
fixing inline triggers
amoghrajesh Dec 24, 2025
f8d5a04
changing to dict[str,JsonValue] for kwargs
amoghrajesh Dec 24, 2025
c73862f
adding a cadwyn migration
amoghrajesh Dec 24, 2025
7471764
Merge branch 'main' into use-serde-for-next-kwargs
amoghrajesh Dec 24, 2025
6a2044c
fixing provider tests
amoghrajesh Dec 29, 2025
b304cac
better fixing mypy issues
amoghrajesh Dec 29, 2025
a446ece
better fixing mypy issues
amoghrajesh Dec 29, 2025
85e7e38
better fixing mypy issues
amoghrajesh Dec 29, 2025
46f3f8f
use literal asserts
amoghrajesh Dec 29, 2025
cd6dd67
fixing upgrade scenario
amoghrajesh Dec 30, 2025
a07754a
fixing upgrade scenario
amoghrajesh Dec 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ repos:
^src/airflow/models/taskmap\.py$|
^src/airflow/models/taskmixin\.py$|
^src/airflow/models/taskreschedule\.py$|
^src/airflow/models/trigger\.py$|
^src/airflow/models/variable\.py$|
^src/airflow/models/xcom\.py$|
^src/airflow/models/xcom_arg\.py$|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class TIDeferredStatePayload(StrictBaseModel):
),
]
classpath: str
trigger_kwargs: Annotated[dict[str, Any] | str, Field(default_factory=dict)]
trigger_kwargs: Annotated[dict[str, JsonValue] | str, Field(default_factory=dict)]
"""
Kwargs to pass to the trigger constructor, either a plain dict or an encrypted string.

Expand All @@ -139,7 +139,7 @@ class TIDeferredStatePayload(StrictBaseModel):
trigger_timeout: timedelta | None = None
next_method: str
"""The name of the method on the operator to call in the worker after the trigger has fired."""
next_kwargs: Annotated[dict[str, Any], Field(default_factory=dict)]
next_kwargs: Annotated[dict[str, JsonValue], Field(default_factory=dict)]
"""
Kwargs to pass to the above method, either a plain dict or an encrypted string.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,19 +507,12 @@ def _create_ti_state_update_query_and_update_state(

query = update(TI).where(TI.id == ti_id_str)

# This is slightly inefficient as we deserialize it to then right again serialize it in the sqla
# TypeAdapter.
next_kwargs = None
if ti_patch_payload.next_kwargs:
from airflow.serialization.serialized_objects import BaseSerialization

next_kwargs = BaseSerialization.deserialize(ti_patch_payload.next_kwargs)

# Store next_kwargs directly (already serialized by worker)
query = query.values(
state=TaskInstanceState.DEFERRED,
trigger_id=trigger_row.id,
next_method=ti_patch_payload.next_method,
next_kwargs=next_kwargs,
next_kwargs=ti_patch_payload.next_kwargs,
trigger_timeout=timeout,
)
updated_state = TaskInstanceState.DEFERRED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@
AddDagRunDetailEndpoint,
MovePreviousRunEndpoint,
)
from airflow.api_fastapi.execution_api.versions.v2026_03_31 import ModifyDeferredTaskKwargsToJsonValue

bundle = VersionBundle(
HeadVersion(),
Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue),
Version("2025-12-08", MovePreviousRunEndpoint, AddDagRunDetailEndpoint),
Version("2025-11-07", AddPartitionKeyField),
Version("2025-11-05", AddTriggeringUserNameField),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Any

from cadwyn import VersionChange, schema

from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIDeferredStatePayload


class ModifyDeferredTaskKwargsToJsonValue(VersionChange):
    """Change the types of `trigger_kwargs` and `next_kwargs` in TIDeferredStatePayload to JsonValue."""

    # Human-readable description surfaced by cadwyn; reuses the class docstring.
    description = __doc__

    # Cadwyn applies these instructions when translating requests/responses for
    # clients pinned to an API version *older* than this change: both kwargs
    # fields revert to the looser ``Any``-based types they had before the
    # JsonValue migration (``trigger_kwargs`` may also be an encrypted string,
    # hence the ``| str`` in its previous type).
    instructions_to_migrate_to_previous_version = (
        schema(TIDeferredStatePayload).field("trigger_kwargs").had(type=dict[str, Any] | str),
        schema(TIDeferredStatePayload).field("next_kwargs").had(type=dict[str, Any]),
    )
5 changes: 4 additions & 1 deletion airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import lazy_object_proxy
import uuid6
from sqlalchemy import (
JSON,
Float,
ForeignKey,
ForeignKeyConstraint,
Expand Down Expand Up @@ -436,7 +437,9 @@ class TaskInstance(Base, LoggingMixin):
# The method to call next, and any extra arguments to pass to it.
# Usually used when resuming from DEFERRED.
next_method: Mapped[str | None] = mapped_column(String(1000), nullable=True)
next_kwargs: Mapped[dict | None] = mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
next_kwargs: Mapped[dict | None] = mapped_column(
MutableDict.as_mutable(JSON().with_variant(postgresql.JSONB, "postgresql")), nullable=True
)

_task_display_property_value: Mapped[str | None] = mapped_column(
"task_display_name", String(2000), nullable=True
Expand Down
5 changes: 4 additions & 1 deletion airflow-core/src/airflow/models/taskinstancehistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import dill
from sqlalchemy import (
JSON,
DateTime,
Float,
ForeignKeyConstraint,
Expand Down Expand Up @@ -109,7 +110,9 @@ class TaskInstanceHistory(Base):
trigger_id: Mapped[int | None] = mapped_column(Integer, nullable=True)
trigger_timeout: Mapped[DateTime | None] = mapped_column(DateTime, nullable=True)
next_method: Mapped[str | None] = mapped_column(String(1000), nullable=True)
next_kwargs: Mapped[dict | None] = mapped_column(MutableDict.as_mutable(ExtendedJSON), nullable=True)
next_kwargs: Mapped[dict | None] = mapped_column(
MutableDict.as_mutable(JSON().with_variant(postgresql.JSONB, "postgresql")), nullable=True
)

task_display_name: Mapped[str | None] = mapped_column(String(2000), nullable=True)
dag_version_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), nullable=True)
Expand Down
39 changes: 30 additions & 9 deletions airflow-core/src/airflow/models/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ def encrypt_kwargs(kwargs: dict[str, Any]) -> str:
import json

from airflow.models.crypto import get_fernet
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.sdk.serde import serialize

serialized_kwargs = BaseSerialization.serialize(kwargs)
serialized_kwargs = serialize(kwargs)
return get_fernet().encrypt(json.dumps(serialized_kwargs).encode("utf-8")).decode("utf-8")

@staticmethod
Expand All @@ -153,7 +153,7 @@ def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]:
import json

from airflow.models.crypto import get_fernet
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.sdk.serde import deserialize

# We weren't able to encrypt the kwargs in all migration paths,
# so we need to handle the case where they are not encrypted.
Expand All @@ -165,7 +165,16 @@ def _decrypt_kwargs(encrypted_kwargs: str) -> dict[str, Any]:
get_fernet().decrypt(encrypted_kwargs.encode("utf-8")).decode("utf-8")
)

return BaseSerialization.deserialize(decrypted_kwargs)
try:
result = deserialize(decrypted_kwargs)
if TYPE_CHECKING:
assert isinstance(result, dict)
return result
except (ImportError, KeyError, AttributeError, TypeError):
# Backward compatibility: fall back to BaseSerialization for old format
from airflow.serialization.serialized_objects import BaseSerialization

return BaseSerialization.deserialize(decrypted_kwargs)

def rotate_fernet_key(self):
"""Encrypts data with a new key. See: :ref:`security/fernet`."""
Expand Down Expand Up @@ -417,16 +426,28 @@ def handle_event_submit(event: TriggerEvent, *, task_instance: TaskInstance, ses
:param task_instance: The task instance to handle the submit event for.
:param session: The session to be used for the database callback sink.
"""
from airflow.sdk.serde import deserialize, serialize
from airflow.utils.state import TaskInstanceState

# Get the next kwargs of the task instance, or an empty dictionary if it doesn't exist
next_kwargs = task_instance.next_kwargs or {}
next_kwargs_raw = task_instance.next_kwargs or {}

# Deserialize first to provide a compat layer for mixed serialized data (BaseSerialization and serde),
# which can happen if a deferred task resumes after an upgrade.
try:
next_kwargs = deserialize(next_kwargs_raw)
except (ImportError, KeyError, AttributeError, TypeError):
from airflow.serialization.serialized_objects import BaseSerialization

next_kwargs = BaseSerialization.deserialize(next_kwargs_raw)

# Add the event's payload into the kwargs for the task
# Add event to the plain dict, then serialize everything together. This ensures that the event is properly
# nested inside __var__ in the final serde serialized structure.
if TYPE_CHECKING:
assert isinstance(next_kwargs, dict)
next_kwargs["event"] = event.payload

# Update the next kwargs of the task instance
task_instance.next_kwargs = next_kwargs
# re-serialize the entire dict using serde to ensure consistent structure
task_instance.next_kwargs = serialize(next_kwargs)

# Remove ourselves as its trigger
task_instance.trigger_id = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,8 +579,23 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance,
)

ti.next_method = "execute_complete"
# ti.next_kwargs under the hood applies the serde encoding for us
ti.next_kwargs = {"moment": instant}
# Explicitly use the serde-serialized value before assigning, since the column is plain JSON/JSONB now
# and this value arrives serde-serialized from the worker.
expected_next_kwargs = {
"moment": {
"__classname__": "pendulum.datetime.DateTime",
"__version__": 2,
"__data__": {
"timestamp": 1727697600.0,
"tz": {
"__classname__": "builtins.tuple",
"__version__": 1,
"__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
},
},
}
}
ti.next_kwargs = expected_next_kwargs

session.commit()

Expand All @@ -606,10 +621,7 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance,
"connections": [],
"xcom_keys_to_clear": [],
"next_method": "execute_complete",
"next_kwargs": {
"__type": "dict",
"__var": {"moment": {"__type": "datetime", "__var": 1727697600.0}},
},
"next_kwargs": expected_next_kwargs,
}

@pytest.mark.parametrize("resume", [True, False])
Expand All @@ -632,14 +644,26 @@ def test_next_kwargs_determines_start_date_update(self, client, session, create_
second_start_time = orig_task_start_time.add(seconds=30)
second_start_time_str = second_start_time.isoformat()

# ti.next_kwargs under the hood applies the serde encoding for us
# Explicitly serialize using serde before assigning, since the column is plain JSON/JSONB now
# and this value arrives serde-serialized from the worker.
if resume:
ti.next_kwargs = {"moment": second_start_time}
expected_start_date = orig_task_start_time
# expected format is now in serde serialized format
expected_next_kwargs = {
"__type": "dict",
"__var": {"moment": {"__type": "datetime", "__var": second_start_time.timestamp()}},
"moment": {
"__classname__": "pendulum.datetime.DateTime",
"__version__": 2,
"__data__": {
"timestamp": 1727697635.0,
"tz": {
"__classname__": "builtins.tuple",
"__version__": 1,
"__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
},
},
}
}
ti.next_kwargs = expected_next_kwargs
expected_start_date = orig_task_start_time
else:
expected_start_date = second_start_time
expected_next_kwargs = None
Expand Down Expand Up @@ -1123,17 +1147,40 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance

payload = {
"state": "deferred",
# Raw payload is already "encoded", but not encrypted
# expected format is now in serde serialized format
"trigger_kwargs": {
"__type": "dict",
"__var": {"key": "value", "moment": {"__type": "datetime", "__var": 1734480001.0}},
"key": "value",
"moment": {
"__classname__": "datetime.datetime",
"__version__": 2,
"__data__": {
"timestamp": 1734480001.0,
"tz": {
"__classname__": "builtins.tuple",
"__version__": 1,
"__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
},
},
},
},
"trigger_timeout": "P1D", # 1 day
"classpath": "my-classpath",
"next_method": "execute_callback",
# expected format is now in serde serialized format
"next_kwargs": {
"__type": "dict",
"__var": {"foo": {"__type": "datetime", "__var": 1734480000.0}, "bar": "abc"},
"foo": {
"__classname__": "datetime.datetime",
"__version__": 2,
"__data__": {
"timestamp": 1734480000.0,
"tz": {
"__classname__": "builtins.tuple",
"__version__": 1,
"__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
},
},
},
"bar": "abc",
},
}

Expand All @@ -1149,9 +1196,21 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance

assert tis[0].state == TaskInstanceState.DEFERRED
assert tis[0].next_method == "execute_callback"

assert tis[0].next_kwargs == {
"foo": {
"__classname__": "datetime.datetime",
"__version__": 2,
"__data__": {
"timestamp": 1734480000.0,
"tz": {
"__classname__": "builtins.tuple",
"__version__": 1,
"__data__": ["UTC", "pendulum.tz.timezone.Timezone", 1, True],
},
},
},
"bar": "abc",
"foo": datetime(2024, 12, 18, 00, 00, 00, tzinfo=timezone.utc),
}
assert tis[0].trigger_timeout == timezone.make_aware(datetime(2024, 11, 23), timezone=timezone.utc)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

from uuid import UUID

import pytest

from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
Expand Down Expand Up @@ -252,8 +254,12 @@ def test_execute(self, dag_maker: DagMaker, session: Session) -> None:
assert notifier.called is True

expected_params_in_trigger_kwargs: dict[str, dict[str, Any]]
# trigger_kwargs are encoded via BaseSerialization in versions < 3.2
expected_ti_id = ti.id
if AIRFLOW_V_3_2_PLUS:
expected_params_in_trigger_kwargs = expected_params
# trigger_kwargs are encoded via serde from task sdk in versions >= 3.2
expected_ti_id = UUID(ti.id)
else:
expected_params_in_trigger_kwargs = {"input_1": {"value": 1, "description": None, "schema": {}}}

Expand All @@ -262,7 +268,7 @@ def test_execute(self, dag_maker: DagMaker, session: Session) -> None:
)
assert registered_trigger is not None
assert registered_trigger.kwargs == {
"ti_id": ti.id,
"ti_id": expected_ti_id,
"options": ["1", "2", "3", "4", "5"],
"defaults": ["1"],
"params": expected_params_in_trigger_kwargs,
Expand Down
6 changes: 3 additions & 3 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel

API_VERSION: Final[str] = "2025-12-08"
API_VERSION: Final[str] = "2026-03-31"


class AssetAliasReferenceAssetEventDagRun(BaseModel):
Expand Down Expand Up @@ -198,10 +198,10 @@ class TIDeferredStatePayload(BaseModel):
)
state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred"
classpath: Annotated[str, Field(title="Classpath")]
trigger_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Trigger Kwargs")] = None
trigger_kwargs: Annotated[dict[str, JsonValue] | str | None, Field(title="Trigger Kwargs")] = None
trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None
next_method: Annotated[str, Field(title="Next Method")]
next_kwargs: Annotated[dict[str, Any] | None, Field(title="Next Kwargs")] = None
next_kwargs: Annotated[dict[str, JsonValue] | None, Field(title="Next Kwargs")] = None
rendered_map_index: Annotated[str | None, Field(title="Rendered Map Index")] = None


Expand Down
Loading
Loading