
Commit 173c2a1

Recover stuck TIs when direct terminal-state API call fails (#66574)
* Recover stuck TIs when direct terminal-state API call fails

  The supervisor's _handle_request for SucceedTask, RetryTask, DeferTask, and RescheduleTask set _terminal_state BEFORE calling the matching client.task_instances.{succeed,retry,defer,reschedule}() API. If that API call raised (transient network blip, server 5xx, etc.), _terminal_state was set on the supervisor but the server never saw the transition. The supervisor's update_task_state_if_needed then saw final_state in STATES_SENT_DIRECTLY and short-circuited the recovery finish() call -- leaving the TaskInstance stuck RUNNING on the server forever, blocking downstream dependencies and triggering false alerts.

  Two-part fix:

  1. Make the direct API call FIRST. Only set _terminal_state and the new _terminal_state_synced_to_server flag after the call returns successfully. If the API raises, both stay unset and the exception propagates to handle_requests, where the existing catch-all sends an ErrorResponse to the task subprocess.
  2. Have update_task_state_if_needed always call finish() when _terminal_state_synced_to_server is False, regardless of what final_state happens to return. The finish() API takes the state value, so a SUCCESS / DEFERRED / etc. transition that originally failed is re-attempted via finish() on subprocess exit.

  Pre-existing semantics for the no-direct-API states (FAILED, UP_FOR_RETRY without RetryTask, etc.) preserved -- those land in the same finish() branch.

  Tests added:

  - _terminal_state not set when succeed() raises.
  - update_task_state_if_needed calls finish() when synced flag is False, even with final_state == SUCCESS.
  - update_task_state_if_needed skips finish() when synced flag is True (preserves the existing happy-path optimisation).

  Reported by the L3 ASVS sweep at apache/tooling-agents#24 (FINDING-007).
* Refactor terminal-state dispatch and parametrize tests across all 4 states

  Address review feedback on #66574:

  - Extract `_send_terminal_state_msg` helper so the per-msg-type dispatch for succeed / retry / defer / reschedule lives in one place. Both `_handle_request` and `_replay_pending_terminal_state_msg` now go through it instead of duplicating the four-branch isinstance chain.
  - Parametrize the two recovery tests over all four terminal-state message types (was only Succeed + Defer); add UP_FOR_RETRY and UP_FOR_RESCHEDULE coverage.

* Narrow _pending_terminal_state_msg type to satisfy mypy

  The field was annotated as BaseModel | None, but _send_terminal_state_msg expects SucceedTask | RetryTask | DeferTask | RescheduleTask. mypy couldn't prove the narrowing at the _replay_pending_terminal_state_msg call site. Tighten the field type to the exact union the setter assigns and the consumer accepts.

---------

Co-authored-by: vatsrahul1001 <rah.sharma11@gmail.com>
Co-authored-by: Rahul Vats <43964496+vatsrahul1001@users.noreply.github.com>
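
The ordering described above (record the pending message, call the API, and only then record the terminal state) can be illustrated with a minimal, self-contained sketch. Everything here is hypothetical and is not the Airflow API: `FlakyClient`, `MiniSupervisor`, `send_success`, and `on_exit` are made-up names that only mirror the shape of the change, assuming a client whose first call fails with a transient error.

```python
from dataclasses import dataclass, field
from typing import List, Optional


class FlakyClient:
    """Hypothetical stand-in for an API client that fails its first N calls."""

    def __init__(self, failures: int) -> None:
        self.failures = failures
        self.delivered: List[str] = []

    def succeed(self, ti_id: str) -> None:
        if self.failures > 0:
            self.failures -= 1
            raise ConnectionError("transient network failure")
        self.delivered.append(ti_id)


@dataclass
class MiniSupervisor:
    """Hypothetical supervisor that tracks an undelivered transition."""

    client: FlakyClient
    ti_id: str = "ti-1"
    terminal_state: Optional[str] = None
    pending: Optional[str] = None  # transition that has not reached the server yet
    errors: List[str] = field(default_factory=list)

    def send_success(self) -> None:
        # Capture the request first so it survives a failed call.
        self.pending = "success"
        self.client.succeed(self.ti_id)  # may raise; state below stays unset
        self.terminal_state = "success"
        self.pending = None

    def on_exit(self) -> None:
        # Best-effort replay of whatever could not be delivered earlier.
        if self.pending is None:
            return
        try:
            self.send_success()
        except ConnectionError as exc:
            self.errors.append(str(exc))


sup = MiniSupervisor(FlakyClient(failures=1))
try:
    sup.send_success()
except ConnectionError:
    pass  # the caller would report the error to the task subprocess

state_after_failure = sup.terminal_state
pending_after_failure = sup.pending

sup.on_exit()  # the second attempt goes through
```

After the failed first call the sketch holds no terminal state but still remembers the pending transition, which is what lets the exit hook retry the same call rather than assume the server already knows.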
1 parent f9faf65 commit 173c2a1

2 files changed

Lines changed: 219 additions & 24 deletions


task-sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 88 additions & 24 deletions
@@ -1152,6 +1152,18 @@ class ActivitySubprocess(WatchedSubprocess):
 
     _terminal_state: str | None = attrs.field(default=None, init=False)
     _final_state: str | None = attrs.field(default=None, init=False)
+    # The terminal-state message currently being processed by `_handle_request`,
+    # captured BEFORE the dedicated API call (succeed / retry / defer /
+    # reschedule). If the API call raises (network blip, server 5xx, etc.),
+    # this attribute stays set and the dispatcher in
+    # `update_task_state_if_needed` re-issues the matching API call on
+    # subprocess exit — re-attempting the original transition rather than
+    # falling back to `finish()`, which doesn't accept SUCCESS / DEFERRED /
+    # SERVER_TERMINATED on the server side. Cleared (and `_terminal_state`
+    # set) only after the API call returns successfully.
+    _pending_terminal_state_msg: SucceedTask | RetryTask | DeferTask | RescheduleTask | None = attrs.field(
+        default=None, init=False
+    )
 
     _last_successful_heartbeat: float = attrs.field(default=0, init=False)
     _last_heartbeat_attempt: float = attrs.field(default=0, init=False)
@@ -1269,10 +1281,23 @@ def wait(self) -> int:
         return self._exit_code
 
     def update_task_state_if_needed(self):
-        # If the process has finished non-directly patched state (directly means deferred, reschedule, etc.),
-        # update the state of the TaskInstance to reflect the final state of the process.
-        # For states like `deferred`, `up_for_reschedule`, the process will exit with 0, but the state will be updated
-        # by the subprocess in the `handle_requests` method.
+        # If a direct-state API call (succeed / retry / defer / reschedule)
+        # was attempted but raised, `_pending_terminal_state_msg` still holds
+        # the original request. Re-issue the matching dedicated API call so
+        # the server learns the terminal state we couldn't deliver earlier.
+        # Without this recovery, a transient API failure during the direct
+        # call would leave the TI stuck RUNNING on the server — `finish()`
+        # cannot substitute because the server-side `finish` endpoint does
+        # not accept SUCCESS / DEFERRED / SERVER_TERMINATED transitions.
+        if self._pending_terminal_state_msg is not None:
+            self._replay_pending_terminal_state_msg()
+            return
+
+        # If the process has finished a non-directly-patched state (e.g.
+        # FAILED, UP_FOR_RETRY without RetryTask), `finish()` is the
+        # dedicated endpoint for those transitions. For states already in
+        # STATES_SENT_DIRECTLY whose direct API call succeeded, no further
+        # action is needed.
         if self.final_state not in STATES_SENT_DIRECTLY:
             self.client.task_instances.finish(
                 id=self.id,
@@ -1281,6 +1306,58 @@ def update_task_state_if_needed(self):
                 rendered_map_index=self._rendered_map_index,
             )
 
+    def _send_terminal_state_msg(self, msg: SucceedTask | RetryTask | DeferTask | RescheduleTask) -> None:
+        # Capture the message BEFORE the API call so the recovery dispatcher
+        # in `update_task_state_if_needed` can re-issue it if the call raises
+        # (network blip, transient server 5xx). Clear the pending slot and
+        # record the resulting state only after the call returns successfully.
+        self._pending_terminal_state_msg = msg
+        if isinstance(msg, SucceedTask):
+            self.client.task_instances.succeed(
+                id=self.id,
+                when=msg.end_date,
+                task_outlets=msg.task_outlets,
+                outlet_events=msg.outlet_events,
+                rendered_map_index=self._rendered_map_index,
+            )
+            self._terminal_state = msg.state
+        elif isinstance(msg, RetryTask):
+            self.client.task_instances.retry(
+                id=self.id,
+                end_date=msg.end_date,
+                rendered_map_index=self._rendered_map_index,
+                retry_delay_seconds=getattr(msg, "retry_delay_seconds", None),
+                retry_reason=getattr(msg, "retry_reason", None),
+            )
+            self._terminal_state = msg.state
+        elif isinstance(msg, DeferTask):
+            self.client.task_instances.defer(self.id, msg)
+            self._terminal_state = TaskInstanceState.DEFERRED
+        elif isinstance(msg, RescheduleTask):
+            self.client.task_instances.reschedule(self.id, msg)
+            self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
+        self._pending_terminal_state_msg = None
+
+    def _replay_pending_terminal_state_msg(self) -> None:
+        """
+        Re-issue the dedicated API call for an unsynced terminal-state msg.
+
+        Best-effort — if the second attempt also fails the exception is
+        logged and we move on; the supervisor's overall failure handling
+        (heartbeat, exit-code reporting) will eventually surface the issue.
+        """
+        msg = self._pending_terminal_state_msg
+        if msg is None:
+            return
+        try:
+            self._send_terminal_state_msg(msg)
+        except Exception:
+            log.exception(
+                "Recovery retry of terminal-state API call failed; TI may be stuck on the server",
+                ti_id=self.id,
+                msg_type=type(msg).__name__,
+            )
+
     def _upload_logs(self):
         """
         Upload all log files found to the remote storage.
@@ -1452,31 +1529,20 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
         resp: BaseModel | None = None
         dump_opts: dict[str, bool] = {}
         if isinstance(msg, TaskState):
+            # No direct API call here — the recovery path in
+            # `update_task_state_if_needed` will call `finish()` for
+            # non-direct states (FAILED, etc.) once the subprocess exits.
             self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
             self._rendered_map_index = msg.rendered_map_index
         elif isinstance(msg, SucceedTask):
-            self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
             self._rendered_map_index = msg.rendered_map_index
-            self.client.task_instances.succeed(
-                id=self.id,
-                when=msg.end_date,
-                task_outlets=msg.task_outlets,
-                outlet_events=msg.outlet_events,
-                rendered_map_index=self._rendered_map_index,
-            )
+            self._send_terminal_state_msg(msg)
         elif isinstance(msg, RetryTask):
-            self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
             self._rendered_map_index = msg.rendered_map_index
-            self.client.task_instances.retry(
-                id=self.id,
-                end_date=msg.end_date,
-                rendered_map_index=self._rendered_map_index,
-                retry_delay_seconds=getattr(msg, "retry_delay_seconds", None),
-                retry_reason=getattr(msg, "retry_reason", None),
-            )
+            self._send_terminal_state_msg(msg)
         elif isinstance(msg, GetConnection):
             resp, dump_opts = handle_get_connection(self.client, msg)
         elif isinstance(msg, GetVariable):
@@ -1512,12 +1578,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
             )
             resp = XComSequenceSliceResult.from_response(xcoms)
         elif isinstance(msg, DeferTask):
-            self._terminal_state = TaskInstanceState.DEFERRED
             self._rendered_map_index = msg.rendered_map_index
-            self.client.task_instances.defer(self.id, msg)
+            self._send_terminal_state_msg(msg)
         elif isinstance(msg, RescheduleTask):
-            self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
-            self.client.task_instances.reschedule(self.id, msg)
+            self._send_terminal_state_msg(msg)
         elif isinstance(msg, SkipDownstreamTasks):
             self.client.task_instances.skip_downstream_tasks(self.id, msg)
         elif isinstance(msg, SetXCom):
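
The decision order that `update_task_state_if_needed` follows after this change can be summarised as a small pure function. This is only a sketch under stated assumptions: `Action`, `choose_exit_action`, and `DIRECT_STATES` are hypothetical names, and the contents of `DIRECT_STATES` are illustrative because the diff does not show the real `STATES_SENT_DIRECTLY` definition.

```python
from enum import Enum
from typing import Optional


class Action(Enum):
    """Hypothetical labels for what the supervisor does on subprocess exit."""

    REPLAY = "replay"    # re-issue the dedicated API call that raised earlier
    FINISH = "finish"    # report the final state through the generic endpoint
    NOTHING = "nothing"  # the direct call already delivered the state


# Illustrative only: the real STATES_SENT_DIRECTLY is not shown in the diff.
DIRECT_STATES = frozenset({"success", "deferred", "up_for_reschedule"})


def choose_exit_action(pending_msg: Optional[object], final_state: str) -> Action:
    # A pending message wins: the server never saw that transition.
    if pending_msg is not None:
        return Action.REPLAY
    # States without a dedicated direct call are reported via finish().
    if final_state not in DIRECT_STATES:
        return Action.FINISH
    # Direct call succeeded earlier, so there is nothing left to send.
    return Action.NOTHING


replay = choose_exit_action(object(), "success")
finish = choose_exit_action(None, "failed")
nothing = choose_exit_action(None, "success")
```

Checking the pending slot first is the key design choice: a `final_state` that looks like a directly sent state no longer implies that the server actually received it.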

task-sdk/tests/task_sdk/execution_time/test_supervisor.py

Lines changed: 131 additions & 0 deletions
@@ -3079,6 +3079,137 @@ def test_handle_requests_network_exception_does_not_crash_loop(self, watched_sub
         # Should not raise StopIteration (which would mean the loop crashed).
         generator.send(req2)
 
+    @pytest.mark.parametrize(
+        ("msg", "api_method", "expected_state"),
+        [
+            pytest.param(
+                SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                "succeed",
+                TaskInstanceState.SUCCESS,
+                id="succeed",
+            ),
+            pytest.param(
+                RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                "retry",
+                TaskInstanceState.UP_FOR_RETRY,
+                id="retry",
+            ),
+            pytest.param(
+                DeferTask(
+                    next_method="execute_complete",
+                    classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger",
+                    trigger_kwargs={},
+                ),
+                "defer",
+                TaskInstanceState.DEFERRED,
+                id="defer",
+            ),
+            pytest.param(
+                RescheduleTask(
+                    reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
+                    end_date=timezone.parse("2024-10-31T12:00:00Z"),
+                ),
+                "reschedule",
+                TaskInstanceState.UP_FOR_RESCHEDULE,
+                id="reschedule",
+            ),
+        ],
+    )
+    def test_terminal_state_not_set_when_direct_api_fails(
+        self, watched_subprocess, mocker, msg, api_method, expected_state
+    ):
+        """`_terminal_state` must NOT be set when the dedicated terminal-state
+        API raises.
+
+        The original message is captured in `_pending_terminal_state_msg`
+        BEFORE the API call so the recovery dispatcher in
+        `update_task_state_if_needed` can re-issue it on subprocess exit.
+        Covers all four terminal-state message types.
+        """
+        watched_subprocess, _ = watched_subprocess
+        setattr(
+            watched_subprocess.client.task_instances,
+            api_method,
+            mocker.Mock(side_effect=httpx.ConnectError("connection refused")),
+        )
+
+        with pytest.raises(httpx.ConnectError):
+            watched_subprocess._handle_request(msg, mocker.Mock(), req_id=1)
+
+        assert watched_subprocess._terminal_state is None
+        # Pending msg preserved so the recovery dispatcher can re-issue.
+        assert watched_subprocess._pending_terminal_state_msg is msg
+
+    @pytest.mark.parametrize(
+        ("msg", "api_method", "expected_state"),
+        [
+            pytest.param(
+                SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                "succeed",
+                TaskInstanceState.SUCCESS,
+                id="succeed",
+            ),
+            pytest.param(
+                RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                "retry",
+                TaskInstanceState.UP_FOR_RETRY,
+                id="retry",
+            ),
+            pytest.param(
+                DeferTask(
+                    next_method="execute_complete",
+                    classpath="airflow.providers.standard.triggers.external_task.WorkflowTrigger",
+                    trigger_kwargs={},
+                ),
+                "defer",
+                TaskInstanceState.DEFERRED,
+                id="defer",
+            ),
+            pytest.param(
+                RescheduleTask(
+                    reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
+                    end_date=timezone.parse("2024-10-31T12:00:00Z"),
+                ),
+                "reschedule",
+                TaskInstanceState.UP_FOR_RESCHEDULE,
+                id="reschedule",
+            ),
+        ],
+    )
+    def test_update_task_state_replays_pending_terminal_state_call(
+        self, watched_subprocess, mocker, msg, api_method, expected_state
+    ):
+        """If a direct terminal-state API call was attempted and raised, the
+        recovery dispatcher must re-issue the dedicated endpoint (not
+        `finish()`, which the server-side endpoint refuses for SUCCESS /
+        DEFERRED / SERVER_TERMINATED). Covers all four message types.
+        """
+        watched_subprocess, _ = watched_subprocess
+        watched_subprocess._exit_code = 0
+        # Simulate the failure scenario: original API call raised, msg preserved.
+        watched_subprocess._pending_terminal_state_msg = msg
+
+        watched_subprocess.update_task_state_if_needed()
+
+        # Recovery re-issues the dedicated endpoint, NOT finish().
+        getattr(watched_subprocess.client.task_instances, api_method).assert_called_once()
+        watched_subprocess.client.task_instances.finish.assert_not_called()
+        assert watched_subprocess._terminal_state == expected_state
+        assert watched_subprocess._pending_terminal_state_msg is None
+
+    def test_update_task_state_no_recovery_without_pending_msg(self, watched_subprocess, mocker):
+        """No replay when nothing was pending — preserves the original
+        STATES_SENT_DIRECTLY short-circuit for the happy path."""
+        watched_subprocess, _ = watched_subprocess
+        watched_subprocess._exit_code = 0
+        watched_subprocess._terminal_state = TaskInstanceState.SUCCESS
+        watched_subprocess._pending_terminal_state_msg = None
+
+        watched_subprocess.update_task_state_if_needed()
+
+        watched_subprocess.client.task_instances.finish.assert_not_called()
+        watched_subprocess.client.task_instances.succeed.assert_not_called()
+
 
 class TestSetSupervisorComms:
     class DummyComms:
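
The tests above rely on a mock whose `side_effect` raises, followed by assertions on the state left behind. A stripped-down version of that technique, using only the standard library's `unittest.mock`, might look like the sketch below. `send_then_record` and the `holder` dict are hypothetical stand-ins for the supervisor method and its attributes, not Airflow code.

```python
from unittest import mock


def send_then_record(client, holder: dict, msg: str) -> None:
    """Hypothetical helper: remember msg, call the API, then record state."""
    holder["pending"] = msg
    client.succeed(msg)  # raises while the mock's side_effect is an exception
    holder["state"] = "success"
    holder["pending"] = None


client = mock.Mock()
client.succeed.side_effect = ConnectionError("connection refused")
holder = {"state": None, "pending": None}

# First attempt: the mocked API raises, so no state may be recorded.
raised = False
try:
    send_then_record(client, holder, "msg-1")
except ConnectionError:
    raised = True
snapshot_after_failure = dict(holder)

# Recovery: clear the failure and replay the preserved message.
client.succeed.side_effect = None
send_then_record(client, holder, holder["pending"])
```

Taking a snapshot right after the failure keeps the two phases separately checkable, in the same way the parametrized tests split the failure case and the replay case into two functions.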
