Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions devel-common/src/tests_common/test_utils/version_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_1_3_PLUS = get_base_airflow_version_tuple() >= (3, 1, 3)
AIRFLOW_V_3_1_7_PLUS = get_base_airflow_version_tuple() >= (3, 1, 7)
AIRFLOW_V_3_1_9_PLUS = get_base_airflow_version_tuple() >= (3, 1, 9)
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)

if AIRFLOW_V_3_1_PLUS:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
from sqlalchemy import select

from airflow.executors.base_executor import BaseExecutor
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS
from airflow.providers.celery.version_compat import (
AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_9_PLUS,
AIRFLOW_V_3_2_PLUS,
)
from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout, Stats, conf, timeout
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
Expand Down Expand Up @@ -189,6 +193,7 @@ def on_celery_worker_ready(*args, **kwargs):
# and deserialization for us
@app.task(name="execute_workload")
def execute_workload(input: str) -> None:
from celery.exceptions import Ignore
from pydantic import TypeAdapter

from airflow.executors import workloads
Expand All @@ -208,22 +213,35 @@ def execute_workload(input: str) -> None:
base_url = f"http://localhost:8080{base_url}"
default_execution_api_server = f"{base_url.rstrip('/')}/execution/"

if isinstance(workload, workloads.ExecuteTask):
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
ti=workload.ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
log_path=workload.log_path,
)
elif isinstance(workload, workloads.ExecuteCallback):
success, error_msg = execute_callback_workload(workload.callback, log)
if not success:
raise RuntimeError(error_msg or "Callback execution failed")
else:
raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}")
try:
if isinstance(workload, workloads.ExecuteTask):
supervise(
# This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
ti=workload.ti, # type: ignore[arg-type]
dag_rel_path=workload.dag_rel_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server),
log_path=workload.log_path,
)
elif isinstance(workload, workloads.ExecuteCallback):
success, error_msg = execute_callback_workload(workload.callback, log)
if not success:
raise RuntimeError(error_msg or "Callback execution failed")
else:
raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}")
except Exception as e:
if AIRFLOW_V_3_1_9_PLUS:
from airflow.sdk.exceptions import TaskAlreadyRunningError

if isinstance(e, TaskAlreadyRunningError):
log.info("[%s] Task already running elsewhere, ignoring redelivered message", celery_task_id)
# Raise Ignore() so Celery does not record a FAILURE result for this duplicate
# delivery. Without this, the broker redelivering the message (e.g. after a
# visibility timeout) would cause Celery to mark the task as failed, even though
# the original worker is still executing it successfully.
raise Ignore()
raise


if not AIRFLOW_V_3_0_PLUS:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:


AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_9_PLUS = get_base_airflow_version_tuple() >= (3, 1, 9)
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)

__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_2_PLUS"]
__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_9_PLUS", "AIRFLOW_V_3_2_PLUS"]
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.dag import sync_dag_to_db
from tests_common.test_utils.taskinstance import create_task_instance
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS
from tests_common.test_utils.version_compat import (
AIRFLOW_V_3_0_PLUS,
AIRFLOW_V_3_1_9_PLUS,
AIRFLOW_V_3_1_PLUS,
AIRFLOW_V_3_2_PLUS,
)

if AIRFLOW_V_3_0_PLUS:
from airflow.models.dag_version import DagVersion
Expand Down Expand Up @@ -761,3 +766,51 @@ def test_celery_tasks_registered_on_import():
assert "execute_command" in registered_tasks, (
"execute_command must be registered for Airflow 2.x compatibility."
)


@pytest.mark.skipif(not AIRFLOW_V_3_1_9_PLUS, reason="TaskAlreadyRunningError requires Airflow 3.1.9+")
def test_execute_workload_ignores_already_running_task():
    """Verify execute_workload converts TaskAlreadyRunningError into celery.exceptions.Ignore.

    Raising ``Ignore`` (rather than letting the exception propagate) prevents
    Celery from recording a FAILURE result for a redelivered message whose task
    is already being executed by another worker.
    """
    import importlib

    from celery.exceptions import Ignore

    from airflow.sdk.exceptions import TaskAlreadyRunningError

    # Reload the module so its top-level @app.task registration runs again in
    # this test environment; the undecorated function is then reachable via
    # the decorator's __wrapped__ attribute, letting us call it directly
    # without a Celery worker.
    importlib.reload(celery_executor_utils)
    execute_workload_unwrapped = celery_executor_utils.execute_workload.__wrapped__

    # Stub Celery app exposing a current task with a request id — presumably
    # the source of the celery_task_id used in the function's log message
    # (TODO confirm against execute_workload's implementation).
    mock_current_task = mock.MagicMock()
    mock_current_task.request.id = "test-celery-task-id"
    mock_app = mock.MagicMock()
    mock_app.current_task = mock_current_task

    with (
        mock.patch("airflow.sdk.execution_time.supervisor.supervise") as mock_supervise,
        mock.patch.object(celery_executor_utils, "app", mock_app),
    ):
        # Simulate the supervisor discovering the task instance is already
        # running elsewhere (e.g. after broker visibility-timeout redelivery).
        mock_supervise.side_effect = TaskAlreadyRunningError("Task already running")

        # Minimal serialized ExecuteTask workload; execute_workload
        # deserializes this JSON (via pydantic) before dispatching.
        workload_json = """
        {
            "type": "ExecuteTask",
            "token": "test-token",
            "dag_rel_path": "test_dag.py",
            "bundle_info": {"name": "test-bundle", "version": null},
            "log_path": "test.log",
            "ti": {
                "id": "019bdec0-d353-7b68-abe0-5ac20fa75ad0",
                "dag_version_id": "019bdead-fdcd-78ab-a9f2-aba3b80fded2",
                "task_id": "test_task",
                "dag_id": "test_dag",
                "run_id": "test_run",
                "try_number": 1,
                "map_index": -1,
                "pool_slots": 1,
                "queue": "default",
                "priority_weight": 1
            }
        }
        """
        # The duplicate-delivery path must surface as Celery's Ignore, not as
        # a task failure.
        with pytest.raises(Ignore):
            execute_workload_unwrapped(workload_json)
Loading