diff --git a/devel-common/src/tests_common/test_utils/version_compat.py b/devel-common/src/tests_common/test_utils/version_compat.py index 41326e951539d..7921f02529668 100644 --- a/devel-common/src/tests_common/test_utils/version_compat.py +++ b/devel-common/src/tests_common/test_utils/version_compat.py @@ -38,6 +38,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0) AIRFLOW_V_3_1_3_PLUS = get_base_airflow_version_tuple() >= (3, 1, 3) AIRFLOW_V_3_1_7_PLUS = get_base_airflow_version_tuple() >= (3, 1, 7) +AIRFLOW_V_3_1_9_PLUS = get_base_airflow_version_tuple() >= (3, 1, 9) AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0) if AIRFLOW_V_3_1_PLUS: diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py index 768f2d9dbf646..6ac9ce1902974 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py @@ -42,7 +42,11 @@ from sqlalchemy import select from airflow.executors.base_executor import BaseExecutor -from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS +from airflow.providers.celery.version_compat import ( + AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_1_9_PLUS, + AIRFLOW_V_3_2_PLUS, +) from airflow.providers.common.compat.sdk import AirflowException, AirflowTaskTimeout, Stats, conf, timeout from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.net import get_hostname @@ -189,6 +193,7 @@ def on_celery_worker_ready(*args, **kwargs): # and deserialization for us @app.task(name="execute_workload") def execute_workload(input: str) -> None: + from celery.exceptions import Ignore from pydantic import TypeAdapter from airflow.executors import workloads @@ -208,22 +213,35 @@ def execute_workload(input: str) -> None: base_url = f"http://localhost:8080{base_url}" default_execution_api_server = f"{base_url.rstrip('/')}/execution/" - if isinstance(workload, workloads.ExecuteTask): - supervise( - # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. - ti=workload.ti, # type: ignore[arg-type] - dag_rel_path=workload.dag_rel_path, - bundle_info=workload.bundle_info, - token=workload.token, - server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), - log_path=workload.log_path, - ) - elif isinstance(workload, workloads.ExecuteCallback): - success, error_msg = execute_callback_workload(workload.callback, log) - if not success: - raise RuntimeError(error_msg or "Callback execution failed") - else: - raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") + try: + if isinstance(workload, workloads.ExecuteTask): + supervise( + # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this. + ti=workload.ti, # type: ignore[arg-type] + dag_rel_path=workload.dag_rel_path, + bundle_info=workload.bundle_info, + token=workload.token, + server=conf.get("core", "execution_api_server_url", fallback=default_execution_api_server), + log_path=workload.log_path, + ) + elif isinstance(workload, workloads.ExecuteCallback): + success, error_msg = execute_callback_workload(workload.callback, log) + if not success: + raise RuntimeError(error_msg or "Callback execution failed") + else: + raise ValueError(f"CeleryExecutor does not know how to handle {type(workload)}") + except Exception as e: + if AIRFLOW_V_3_1_9_PLUS: + from airflow.sdk.exceptions import TaskAlreadyRunningError + + if isinstance(e, TaskAlreadyRunningError): + log.info("[%s] Task already running elsewhere, ignoring redelivered message", celery_task_id) + # Raise Ignore() so Celery does not record a FAILURE result for this duplicate + # delivery. Without this, the broker redelivering the message (e.g. after a + # visibility timeout) would cause Celery to mark the task as failed, even though + # the original worker is still executing it successfully. + raise Ignore() + raise if not AIRFLOW_V_3_0_PLUS: diff --git a/providers/celery/src/airflow/providers/celery/version_compat.py b/providers/celery/src/airflow/providers/celery/version_compat.py index 0b65e14199e6f..6d0c610745181 100644 --- a/providers/celery/src/airflow/providers/celery/version_compat.py +++ b/providers/celery/src/airflow/providers/celery/version_compat.py @@ -27,6 +27,7 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) +AIRFLOW_V_3_1_9_PLUS = get_base_airflow_version_tuple() >= (3, 1, 9) AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0) -__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_2_PLUS"] +__all__ = ["AIRFLOW_V_3_0_PLUS", "AIRFLOW_V_3_1_9_PLUS", "AIRFLOW_V_3_2_PLUS"] diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index 27328a78d061c..ff2c146f82874 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -45,7 +45,12 @@ from tests_common.test_utils.config import conf_vars from tests_common.test_utils.dag import sync_dag_to_db from tests_common.test_utils.taskinstance import create_task_instance -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_2_PLUS +from tests_common.test_utils.version_compat import ( + AIRFLOW_V_3_0_PLUS, + AIRFLOW_V_3_1_9_PLUS, + AIRFLOW_V_3_1_PLUS, + AIRFLOW_V_3_2_PLUS, +) if AIRFLOW_V_3_0_PLUS: from airflow.models.dag_version import DagVersion @@ -761,3 +766,51 @@ def test_celery_tasks_registered_on_import(): assert "execute_command" in registered_tasks, ( "execute_command must be registered for Airflow 2.x compatibility." ) + + +@pytest.mark.skipif(not AIRFLOW_V_3_1_9_PLUS, reason="TaskAlreadyRunningError requires Airflow 3.1.9+") +def test_execute_workload_ignores_already_running_task(): + """Test that execute_workload raises Celery Ignore when task is already running.""" + import importlib + + from celery.exceptions import Ignore + + from airflow.sdk.exceptions import TaskAlreadyRunningError + + importlib.reload(celery_executor_utils) + execute_workload_unwrapped = celery_executor_utils.execute_workload.__wrapped__ + + mock_current_task = mock.MagicMock() + mock_current_task.request.id = "test-celery-task-id" + mock_app = mock.MagicMock() + mock_app.current_task = mock_current_task + + with ( + mock.patch("airflow.sdk.execution_time.supervisor.supervise") as mock_supervise, + mock.patch.object(celery_executor_utils, "app", mock_app), + ): + mock_supervise.side_effect = TaskAlreadyRunningError("Task already running") + + workload_json = """ + { + "type": "ExecuteTask", + "token": "test-token", + "dag_rel_path": "test_dag.py", + "bundle_info": {"name": "test-bundle", "version": null}, + "log_path": "test.log", + "ti": { + "id": "019bdec0-d353-7b68-abe0-5ac20fa75ad0", + "dag_version_id": "019bdead-fdcd-78ab-a9f2-aba3b80fded2", + "task_id": "test_task", + "dag_id": "test_dag", + "run_id": "test_run", + "try_number": 1, + "map_index": -1, + "pool_slots": 1, + "queue": "default", + "priority_weight": 1 + } + } + """ + with pytest.raises(Ignore): + execute_workload_unwrapped(workload_json)