Skip to content

Commit 0ef0a6b

Browse files
authored
Fix BranchPythonOperator failure when callable returns None (#54991)
`BranchPythonOperator` now properly handles callables that return None by skipping all downstream tasks, instead of throwing an execution error. This restores the expected behavior for users who rely on None returns to skip branches conditionally. Fixes #54340
1 parent 6dd0169 commit 0ef0a6b

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

providers/standard/src/airflow/providers/standard/operators/branch.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,17 @@
3737
class BranchMixIn(SkipMixin):
3838
"""Utility helper which handles the branching as one-liner."""
3939

40-
def do_branch(self, context: Context, branches_to_execute: str | Iterable[str]) -> str | Iterable[str]:
40+
def do_branch(
41+
self, context: Context, branches_to_execute: str | Iterable[str] | None
42+
) -> str | Iterable[str] | None:
4143
"""Implement the handling of branching including logging."""
4244
self.log.info("Branch into %s", branches_to_execute)
43-
branch_task_ids = self._expand_task_group_roots(context["ti"], branches_to_execute)
44-
self.skip_all_except(context["ti"], branch_task_ids)
45+
if branches_to_execute is None:
46+
# When None is returned, skip all downstream tasks
47+
self.skip_all_except(context["ti"], None)
48+
else:
49+
branch_task_ids = self._expand_task_group_roots(context["ti"], branches_to_execute)
50+
self.skip_all_except(context["ti"], branch_task_ids)
4551
return branches_to_execute
4652

4753
def _expand_task_group_roots(
@@ -86,13 +92,13 @@ class BaseBranchOperator(BaseOperator, BranchMixIn):
8692

8793
inherits_from_skipmixin = True
8894

89-
def choose_branch(self, context: Context) -> str | Iterable[str]:
95+
def choose_branch(self, context: Context) -> str | Iterable[str] | None:
9096
"""
9197
Abstract method to choose which branch to run.
9298
9399
Subclasses should implement this, running whatever logic is
94100
necessary to choose a branch and returning a task_id or list of
95-
task_ids.
101+
task_ids. If None is returned, all downstream tasks will be skipped.
96102
97103
:param context: Context dictionary as passed to execute()
98104
"""

providers/standard/tests/unit/standard/operators/test_python.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,33 @@ def f():
552552
):
553553
ti.run()
554554

555+
def test_none_return_value_should_skip_all_downstream(self):
556+
"""Test that returning None from callable should skip all downstream tasks."""
557+
clear_db_runs()
558+
with self.dag_maker(self.dag_id, serialized=True):
559+
560+
def return_none():
561+
return None
562+
563+
branch_op = self.opcls(task_id=self.task_id, python_callable=return_none, **self.default_kwargs())
564+
branch_op >> [self.branch_1, self.branch_2]
565+
566+
dr = self.dag_maker.create_dagrun()
567+
if AIRFLOW_V_3_0_1:
568+
from airflow.exceptions import DownstreamTasksSkipped
569+
570+
with pytest.raises(DownstreamTasksSkipped) as dts:
571+
self.dag_maker.run_ti(self.task_id, dr)
572+
573+
# When None is returned, all downstream tasks should be skipped
574+
expected_skipped = {("branch_1", -1), ("branch_2", -1)}
575+
assert set(dts.value.tasks) == expected_skipped
576+
else:
577+
self.dag_maker.run_ti(self.task_id, dr)
578+
self.assert_expected_task_states(
579+
dr, {self.task_id: State.SUCCESS, "branch_1": State.SKIPPED, "branch_2": State.SKIPPED}
580+
)
581+
555582
@pytest.mark.parametrize(
556583
"choice,expected_states",
557584
[

0 commit comments

Comments
 (0)