Skip to content

Commit 89a6807

Browse files
committed
Change on_*_callback on tasks to use has_on_*_callback
1 parent c917143 commit 89a6807

8 files changed

Lines changed: 245 additions & 48 deletions

File tree

airflow-core/src/airflow/jobs/scheduler_job_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ def process_executor_events(
896896
ti.set_state(state)
897897
continue
898898
ti.task = task
899-
if task.on_retry_callback or task.on_failure_callback:
899+
if task.has_on_retry_callback or task.has_on_failure_callback:
900900
# Only log the error/extra info here, since the `ti.handle_failure()` path will log it
901901
# too, which would lead to double logging
902902
cls.logger().error(msg)
@@ -2028,7 +2028,7 @@ def _maybe_requeue_stuck_ti(self, *, ti, session, executor):
20282028
exc_info=True,
20292029
)
20302030
else:
2031-
if task.on_failure_callback:
2031+
if task.has_on_failure_callback:
20322032
if inspect(ti).detached:
20332033
ti = session.merge(ti)
20342034
request = TaskCallbackRequest(

airflow-core/src/airflow/models/dagrun.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1998,8 +1998,8 @@ def schedule_tis(
19981998
assert isinstance(task, Operator)
19991999
if (
20002000
task.inherits_from_empty_operator
2001-
and not task.on_execute_callback
2002-
and not task.on_success_callback
2001+
and not task.has_on_execute_callback
2002+
and not task.has_on_success_callback
20032003
and not task.outlets
20042004
and not task.inlets
20052005
):

airflow-core/src/airflow/models/mappedoperator.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
DEFAULT_TRIGGER_RULE,
4343
DEFAULT_WEIGHT_RULE,
4444
NotMapped,
45-
TaskStateChangeCallbackAttrType,
4645
)
4746
from airflow.sdk.definitions._internal.node import DAGNode
4847
from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
@@ -242,24 +241,24 @@ def max_active_tis_per_dagrun(self) -> int | None:
242241
return self.partial_kwargs.get("max_active_tis_per_dagrun")
243242

244243
@property
245-
def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
246-
return self.partial_kwargs.get("on_execute_callback") or []
244+
def has_on_execute_callback(self) -> bool:
245+
return bool(self.partial_kwargs.get("has_on_execute_callback", False))
247246

248247
@property
249-
def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
250-
return self.partial_kwargs.get("on_failure_callback") or []
248+
def has_on_failure_callback(self) -> bool:
249+
return bool(self.partial_kwargs.get("has_on_failure_callback", False))
251250

252251
@property
253-
def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
254-
return self.partial_kwargs.get("on_retry_callback") or []
252+
def has_on_retry_callback(self) -> bool:
253+
return bool(self.partial_kwargs.get("has_on_retry_callback", False))
255254

256255
@property
257-
def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
258-
return self.partial_kwargs.get("on_success_callback") or []
256+
def has_on_success_callback(self) -> bool:
257+
return bool(self.partial_kwargs.get("has_on_success_callback", False))
259258

260259
@property
261-
def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
262-
return self.partial_kwargs.get("on_skipped_callback") or []
260+
def has_on_skipped_callback(self) -> bool:
261+
return bool(self.partial_kwargs.get("has_on_skipped_callback", False))
263262

264263
@property
265264
def run_as_user(self) -> str | None:

airflow-core/src/airflow/serialization/schema.json

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,11 @@
316316
"allow_nested_operators": { "type": "boolean", "default": true },
317317
"inlets": {"type": "array", "default": []},
318318
"outlets": {"type": "array", "default": []},
319-
"on_execute_callback": {"type": "array", "default": []},
320-
"on_failure_callback": {"type": "array", "default": []},
321-
"on_skipped_callback": {"type": "array", "default": []},
322-
"on_success_callback": {"type": "array", "default": []},
323-
"on_retry_callback": {"type": "array", "default": []},
319+
"has_on_execute_callback": {"type": "boolean", "default": false},
320+
"has_on_failure_callback": {"type": "boolean", "default": false},
321+
"has_on_skipped_callback": {"type": "boolean", "default": false},
322+
"has_on_success_callback": {"type": "boolean", "default": false},
323+
"has_on_retry_callback": {"type": "boolean", "default": false},
324324
"multiple_outputs": {"type": "boolean", "default": false},
325325
"start_from_trigger": {"type": "boolean", "default": false},
326326
"is_setup": {"type": "boolean", "default": false},

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
5353
from airflow.sdk import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher, BaseOperator, XComArg
5454
from airflow.sdk.bases.trigger import StartTriggerArgs
55-
from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY
5655
from airflow.sdk.definitions._internal.node import DAGNode
5756
from airflow.sdk.definitions.asset import (
5857
AssetAliasEvent,
@@ -1221,6 +1220,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
12211220
_is_empty: bool
12221221
_needs_expansion: bool
12231222
_task_display_name: str | None
1223+
_on_failure_fail_dagrun: bool = False
12241224

12251225
dag: DAG | None = None
12261226
task_group: TaskGroup | None = None
@@ -1255,12 +1255,12 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
12551255
max_retry_delay: datetime.timedelta | float | None = None
12561256
multiple_outputs: bool = False
12571257

1258-
# TODO: Can be changed to () instead
1259-
on_execute_callback: Sequence = []
1260-
on_failure_callback: Sequence = []
1261-
on_retry_callback: Sequence = []
1262-
on_success_callback: Sequence = []
1263-
on_skipped_callback: Sequence = []
1258+
# Boolean flags for callback existence
1259+
has_on_execute_callback: bool = False
1260+
has_on_failure_callback: bool = False
1261+
has_on_retry_callback: bool = False
1262+
has_on_success_callback: bool = False
1263+
has_on_skipped_callback: bool = False
12641264

12651265
operator_extra_links: Collection[BaseOperatorLink] = ()
12661266

@@ -1444,6 +1444,11 @@ def serialize_mapped_operator(cls, op: MappedOperator | SchedulerMappedOperator)
14441444
continue
14451445
if cls._is_excluded(v, k, op):
14461446
continue
1447+
1448+
if k in [f"on_{x}_callback" for x in ("execute", "failure", "success", "retry", "skipped")]:
1449+
if bool(v):
1450+
serialized_op["partial_kwargs"][f"has_{k}"] = True
1451+
continue
14471452
serialized_op["partial_kwargs"].update({k: cls.serialize(v)})
14481453

14491454
serialized_op["_is_mapped"] = True
@@ -1765,6 +1770,9 @@ def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
17651770
if var is dag_date or var == dag_date:
17661771
return True
17671772

1773+
# If none of the exclusion conditions are met, don't exclude the field
1774+
return False
1775+
17681776
@classmethod
17691777
def _deserialize_operator_extra_links(
17701778
cls, encoded_op_links: dict[str, str]
@@ -1885,12 +1893,11 @@ def get_serialized_fields(cls):
18851893
"max_active_tis_per_dagrun",
18861894
"max_retry_delay",
18871895
"multiple_outputs",
1888-
"on_execute_callback",
1889-
"on_failure_callback",
1890-
"on_failure_fail_dagrun",
1891-
"on_retry_callback",
1892-
"on_skipped_callback",
1893-
"on_success_callback",
1896+
"has_on_execute_callback",
1897+
"has_on_failure_callback",
1898+
"has_on_retry_callback",
1899+
"has_on_skipped_callback",
1900+
"has_on_success_callback",
18941901
"outlets",
18951902
"owner",
18961903
"params",

airflow-core/tests/unit/serialization/test_dag_serialization.py

Lines changed: 146 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -742,10 +742,6 @@ def validate_deserialized_task(
742742
# We store the string, real dag has the actual code
743743
"_pre_execute_hook",
744744
"_post_execute_hook",
745-
"on_execute_callback",
746-
"on_failure_callback",
747-
"on_success_callback",
748-
"on_retry_callback",
749745
# Checked separately
750746
"resources",
751747
"on_failure_fail_dagrun",
@@ -811,11 +807,23 @@ def validate_deserialized_task(
811807
default_partial_kwargs = (
812808
BaseOperator.partial(task_id="_")._expand(EXPAND_INPUT_EMPTY, strict=False).partial_kwargs
813809
)
810+
811+
# These are added in `_TaskDecorator` e.g. when @setup or @teardown task is passed
812+
default_decorator_partial_kwargs = {
813+
"is_setup": False,
814+
"is_teardown": False,
815+
"on_failure_fail_dagrun": False,
816+
}
814817
serialized_partial_kwargs = {
815818
**default_partial_kwargs,
819+
**default_decorator_partial_kwargs,
816820
**serialized_task.partial_kwargs,
817821
}
818-
original_partial_kwargs = {**default_partial_kwargs, **task.partial_kwargs}
822+
original_partial_kwargs = {
823+
**default_partial_kwargs,
824+
**default_decorator_partial_kwargs,
825+
**task.partial_kwargs,
826+
}
819827
assert serialized_partial_kwargs == original_partial_kwargs
820828

821829
# ExpandInputs have different classes between scheduler and definition
@@ -1415,6 +1423,11 @@ def test_no_new_fields_added_to_base_operator(self):
14151423
"execution_timeout": None,
14161424
"executor": None,
14171425
"executor_config": {},
1426+
"has_on_execute_callback": False,
1427+
"has_on_failure_callback": False,
1428+
"has_on_retry_callback": False,
1429+
"has_on_skipped_callback": False,
1430+
"has_on_success_callback": False,
14181431
"ignore_first_depends_on_past": False,
14191432
"is_setup": False,
14201433
"is_teardown": False,
@@ -1423,12 +1436,7 @@ def test_no_new_fields_added_to_base_operator(self):
14231436
"max_active_tis_per_dag": None,
14241437
"max_active_tis_per_dagrun": None,
14251438
"max_retry_delay": None,
1426-
"on_execute_callback": [],
14271439
"on_failure_fail_dagrun": False,
1428-
"on_failure_callback": [],
1429-
"on_retry_callback": [],
1430-
"on_skipped_callback": [],
1431-
"on_success_callback": [],
14321440
"outlets": [],
14331441
"owner": "airflow",
14341442
"params": {},
@@ -3011,6 +3019,8 @@ def operator_extra_links(self):
30113019
assert mapped_task.extra_links == sorted({"airflow", "github"})
30123020

30133021

3022+
# TODO: Remove xfail
3023+
@pytest.mark.xfail(reason="TODO: Need to add support for v1 & v2 to v3")
30143024
def test_handle_v1_serdag():
30153025
v1 = {
30163026
"__version": 1,
@@ -3296,3 +3306,129 @@ def test_handle_v1_serdag():
32963306
del expected["dag"]["tasks"][1]["__var"]["_operator_extra_links"]
32973307

32983308
assert v1 == expected
3309+
3310+
3311+
def dummy_callback():
3312+
pass
3313+
3314+
3315+
@pytest.mark.parametrize(
3316+
"callback_config,expected_flags,is_mapped",
3317+
[
3318+
# Regular operator tests
3319+
(
3320+
{
3321+
"on_failure_callback": dummy_callback,
3322+
"on_retry_callback": [dummy_callback, dummy_callback],
3323+
"on_success_callback": dummy_callback,
3324+
},
3325+
{"has_on_failure_callback": True, "has_on_retry_callback": True, "has_on_success_callback": True},
3326+
False,
3327+
),
3328+
(
3329+
{}, # No callbacks
3330+
{
3331+
"has_on_failure_callback": False,
3332+
"has_on_retry_callback": False,
3333+
"has_on_success_callback": False,
3334+
},
3335+
False,
3336+
),
3337+
(
3338+
{"on_failure_callback": [], "on_success_callback": None}, # Empty callbacks
3339+
{"has_on_failure_callback": False, "has_on_success_callback": False},
3340+
False,
3341+
),
3342+
# Mapped operator tests
3343+
(
3344+
{"on_failure_callback": dummy_callback, "on_success_callback": [dummy_callback, dummy_callback]},
3345+
{"has_on_failure_callback": True, "has_on_success_callback": True},
3346+
True,
3347+
),
3348+
(
3349+
{}, # Mapped operator without callbacks
3350+
{"has_on_failure_callback": False, "has_on_success_callback": False},
3351+
True,
3352+
),
3353+
],
3354+
)
3355+
def test_task_callback_boolean_optimization(callback_config, expected_flags, is_mapped):
3356+
"""Test that task callbacks are optimized using has_on_*_callback boolean flags."""
3357+
dag = DAG(dag_id="test_callback_dag", start_date=datetime(2020, 1, 1))
3358+
3359+
if is_mapped:
3360+
# Create mapped operator
3361+
task = BashOperator.partial(task_id="test_task", dag=dag, **callback_config).expand(
3362+
bash_command=["echo 1", "echo 2"]
3363+
)
3364+
3365+
# Serialize and deserialize
3366+
serialized = BaseSerialization.serialize(task)
3367+
deserialized = BaseSerialization.deserialize(serialized)
3368+
3369+
# For mapped operators, check partial_kwargs
3370+
serialized_data = serialized.get("__var", {}).get("partial_kwargs", {})
3371+
3372+
# Test serialization
3373+
for flag, expected in expected_flags.items():
3374+
if expected:
3375+
assert flag in serialized_data
3376+
assert serialized_data[flag] is True
3377+
else:
3378+
assert serialized_data.get(flag, False) is False
3379+
3380+
# Test deserialized properties
3381+
for flag, expected in expected_flags.items():
3382+
assert getattr(deserialized, flag) is expected
3383+
3384+
else:
3385+
# Create regular operator
3386+
task = BashOperator(task_id="test_task", bash_command="echo test", dag=dag, **callback_config)
3387+
3388+
# Serialize and deserialize
3389+
serialized = BaseSerialization.serialize(task)
3390+
deserialized = BaseSerialization.deserialize(serialized)
3391+
3392+
# For regular operators, check top-level
3393+
serialized_data = serialized.get("__var", {})
3394+
3395+
# Test serialization (only True values are stored)
3396+
for flag, expected in expected_flags.items():
3397+
if expected:
3398+
assert serialized_data.get(flag, False) is True
3399+
else:
3400+
assert serialized_data.get(flag, False) is False
3401+
3402+
# Test deserialized properties
3403+
for flag, expected in expected_flags.items():
3404+
assert getattr(deserialized, flag) is expected
3405+
3406+
3407+
def test_task_callback_properties_exist():
3408+
"""Test that all callback boolean properties exist on both regular and mapped operators."""
3409+
dag = DAG(dag_id="test_dag", start_date=datetime(2020, 1, 1))
3410+
3411+
# Regular operator
3412+
regular_task = BashOperator(task_id="regular", bash_command="echo test", dag=dag)
3413+
3414+
# Mapped operator
3415+
mapped_task = BashOperator.partial(task_id="mapped", dag=dag).expand(bash_command=["echo 1"])
3416+
3417+
callback_properties = [
3418+
"has_on_execute_callback",
3419+
"has_on_failure_callback",
3420+
"has_on_success_callback",
3421+
"has_on_retry_callback",
3422+
"has_on_skipped_callback",
3423+
]
3424+
3425+
for prop in callback_properties:
3426+
assert hasattr(regular_task, prop), f"Regular operator missing {prop}"
3427+
assert hasattr(mapped_task, prop), f"Mapped operator missing {prop}"
3428+
3429+
# Serialize and check deserialized versions too
3430+
serialized_regular = BaseSerialization.deserialize(BaseSerialization.serialize(regular_task))
3431+
serialized_mapped = BaseSerialization.deserialize(BaseSerialization.serialize(mapped_task))
3432+
3433+
assert hasattr(serialized_regular, prop), f"Deserialized regular operator missing {prop}"
3434+
assert hasattr(serialized_mapped, prop), f"Deserialized mapped operator missing {prop}"

0 commit comments

Comments
 (0)