Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
XComSequenceSliceResponse,
)
from airflow.sdk.configuration import conf
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
from airflow.sdk.execution_time.comms import (
CreateHITLDetailPayload,
DRCount,
Expand Down Expand Up @@ -216,7 +216,18 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
"""Tell the API server that this TI has started running."""
body = TIEnterRunningPayload(pid=pid, hostname=get_hostname(), unixname=getuser(), start_date=when)

resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
try:
resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.CONFLICT:
detail = e.detail
if (
isinstance(detail, dict)
and detail.get("reason") == "invalid_state"
and detail.get("previous_state") == "running"
):
raise TaskAlreadyRunningError(f"Task instance {id} is already running") from e
raise
Comment thread
anishgirianish marked this conversation as resolved.
return TIRunContext.model_validate_json(resp.read())

def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: datetime, rendered_map_index):
Expand Down Expand Up @@ -1034,7 +1045,7 @@ def dags(self) -> DagsOperations:

# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
detail: list[RemoteValidationError] | str
detail: list[RemoteValidationError] | dict[str, Any] | str

def __repr__(self):
return repr(self.detail)
Expand Down Expand Up @@ -1068,6 +1079,9 @@ def from_response(cls, response: httpx.Response) -> ServerResponseError | None:
if isinstance(body.detail, list):
detail = body.detail
msg = "Remote server returned validation error"
elif isinstance(body.detail, dict):
detail = body.detail
msg = "Server returned error"
else:
msg = body.detail or "Un-parseable error"
except Exception:
Expand Down
4 changes: 4 additions & 0 deletions task-sdk/src/airflow/sdk/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,10 @@ class TaskNotFound(AirflowException):
"""Raise when a Task is not available in the system."""


class TaskAlreadyRunningError(AirflowException):
"""Raised when a task is already running on another worker."""


class FailFastDagInvalidTriggerRule(AirflowException):
"""Raise when a dag has 'fail_fast' enabled yet has a non-default trigger rule."""

Expand Down
59 changes: 52 additions & 7 deletions task-sdk/tests/task_sdk/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
VariableResponse,
XComResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.exceptions import ErrorType, TaskAlreadyRunningError
from airflow.sdk.execution_time.comms import (
DeferTask,
ErrorResponse,
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_server_response_error_pickling(self):

err = exc_info.value
assert err.args == ("Server returned error",)
assert err.detail == {"detail": {"message": "Invalid input"}}
assert err.detail == {"message": "Invalid input"}

# Check that the error is picklable
pickled = pickle.dumps(err)
Expand All @@ -171,7 +171,7 @@ def test_server_response_error_pickling(self):

# Test that unpickled error has the same attributes as the original
assert unpickled.response.json() == {"detail": {"message": "Invalid input"}}
assert unpickled.detail == {"detail": {"message": "Invalid input"}}
assert unpickled.detail == {"message": "Invalid input"}
assert unpickled.response.status_code == 404
assert unpickled.request.url == "http://error"

Expand Down Expand Up @@ -333,6 +333,53 @@ def handle_request(request: httpx.Request) -> httpx.Response:
assert resp == ti_context
assert call_count == 3

def test_task_instance_start_already_running(self):
"""Test that start() raises TaskAlreadyRunningError when TI is already running."""
ti_id = uuid6.uuid7()

def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti_id}/run":
return httpx.Response(
409,
json={
"detail": {
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": "running",
}
},
)
return httpx.Response(status_code=204)

client = make_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(TaskAlreadyRunningError, match="already running"):
client.task_instances.start(ti_id, 100, datetime(2024, 10, 31, tzinfo=timezone.utc))

@pytest.mark.parametrize("previous_state", ["failed", "success", "skipped"])
def test_task_instance_start_other_invalid_states(self, previous_state):
"""Test that start() raises ServerResponseError for non-running invalid states."""
ti_id = uuid6.uuid7()

def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti_id}/run":
return httpx.Response(
409,
json={
"detail": {
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": previous_state,
}
},
)
return httpx.Response(status_code=204)

client = make_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(ServerResponseError):
client.task_instances.start(ti_id, 100, datetime(2024, 10, 31, tzinfo=timezone.utc))

@pytest.mark.parametrize(
"state", [state for state in TerminalTIState if state != TerminalTIState.SUCCESS]
)
Expand Down Expand Up @@ -1627,10 +1674,8 @@ def handle_request(request: httpx.Request) -> httpx.Response:

assert exc_info.value.response.status_code == 404
assert exc_info.value.detail == {
"detail": {
"message": "The Dag with dag_id: `missing_dag` was not found",
"reason": "not_found",
}
"message": "The Dag with dag_id: `missing_dag` was not found",
"reason": "not_found",
}

def test_get_server_error(self):
Expand Down
70 changes: 35 additions & 35 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
TaskInstance,
TaskInstanceState,
)
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType, TaskAlreadyRunningError
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.comms import (
AssetEventsResult,
Expand Down Expand Up @@ -731,40 +731,6 @@ def mock_monotonic():
"task_instance_id": str(ti.id),
} in captured_logs

def test_supervisor_handles_already_running_task(self):
Comment thread
anishgirianish marked this conversation as resolved.
"""Test that Supervisor prevents starting a Task Instance that is already running."""
ti = TaskInstance(
id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
)

# Mock API Server response indicating the TI is already running
# The API Server would return a 409 Conflict status code if the TI is not
# in a "queued" state.
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti.id}/run":
return httpx.Response(
409,
json={
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": "running",
},
)

return httpx.Response(status_code=204)

client = make_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(ServerResponseError, match="Server returned error") as err:
ActivitySubprocess.start(dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, what=ti, client=client)

assert err.value.response.status_code == 409
assert err.value.detail == {
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": "running",
}

@pytest.mark.parametrize("captured_logs", [logging.ERROR], indirect=True, ids=["log_level=error"])
def test_state_conflict_on_heartbeat(self, captured_logs, monkeypatch, mocker, make_ti_context_dict):
Comment thread
anishgirianish marked this conversation as resolved.
"""
Expand Down Expand Up @@ -865,6 +831,40 @@ def handle_request(request: httpx.Request) -> httpx.Response:
},
]

def test_start_raises_task_already_running_and_kills_subprocess(self):
"""Test that ActivitySubprocess.start() raises TaskAlreadyRunningError and kills the child
when the API returns 409 with previous_state='running'."""
ti_id = uuid7()

def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti_id}/run":
return httpx.Response(
409,
json={
"detail": {
"reason": "invalid_state",
"message": "TI was not in a state where it could be marked as running",
"previous_state": "running",
}
},
)
return httpx.Response(status_code=204)

def subprocess_main():
# Ensure we follow the "protocol" and get the startup message before we do anything
CommsDecoder()._get_response()

with pytest.raises(TaskAlreadyRunningError, match="already running"):
ActivitySubprocess.start(
dag_rel_path=os.devnull,
bundle_info=FAKE_BUNDLE,
what=TaskInstance(
id=ti_id, task_id="b", dag_id="c", run_id="d", try_number=1, dag_version_id=uuid7()
),
client=make_client(transport=httpx.MockTransport(handle_request)),
target=subprocess_main,
)

@pytest.mark.parametrize("captured_logs", [logging.WARNING], indirect=True)
def test_heartbeat_failures_handling(self, monkeypatch, mocker, captured_logs, time_machine):
"""
Expand Down
Loading