Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion airflow/callbacks/callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ class BaseCallbackRequest(BaseModel):
:param msg: Additional Message that can be used for logging
"""

full_filepath: str
filepath: str
"""File Path to use to run the callback"""
bundle_name: str
bundle_version: str | None
msg: str | None = None
"""Additional Message that can be used for logging to determine failure/zombie"""

Expand Down
40 changes: 18 additions & 22 deletions airflow/dag_processing/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class DagFileInfo:
rel_path: Path
bundle_name: str
bundle_path: Path | None = field(compare=False, default=None)
bundle_version: str | None = None

@property
def absolute_path(self) -> Path:
Expand Down Expand Up @@ -191,7 +192,7 @@ class DagFileProcessorManager:
_parsing_start_time: float = attrs.field(init=False)
_num_run: int = attrs.field(default=0, init=False)

_callback_to_execute: dict[str, list[CallbackRequest]] = attrs.field(
_callback_to_execute: dict[DagFileInfo, list[CallbackRequest]] = attrs.field(
factory=lambda: defaultdict(list), init=False
)

Expand Down Expand Up @@ -407,18 +408,21 @@ def _fetch_callbacks(

def _add_callback_to_queue(self, request: CallbackRequest):
self.log.debug("Queuing %s CallbackRequest: %s", type(request).__name__, request)
self.log.warning("Callbacks are not implemented yet!")
# TODO: AIP-66 make callbacks bundle aware
return
self._callback_to_execute[request.full_filepath].append(request)
if request.full_filepath in self._file_queue:
# Remove file paths matching request.full_filepath from self._file_queue
# Since we are already going to use that filepath to run callback,
# there is no need to have same file path again in the queue
# todo (AIP-66): update re bundle and rel loc
self._file_queue = deque(f for f in self._file_queue if f != request.full_filepath)
# todo (AIP-66): update re bundle and rel loc
self._add_files_to_queue([request.full_filepath], True)
try:
bundle = DagBundlesManager().get_bundle(name=request.bundle_name, version=request.bundle_version)
except ValueError:
# Bundle no longer configured
self.log.error("Bundle %s no longer configured, skipping callback", request.bundle_name)
return None

file_info = DagFileInfo(
rel_path=Path(request.filepath),
bundle_path=bundle.path,
bundle_name=request.bundle_name,
bundle_version=request.bundle_version,
)
self._callback_to_execute[file_info].append(request)
self._add_files_to_queue([file_info], True)
Stats.incr("dag_processing.other_callback_count")

@classmethod
Expand Down Expand Up @@ -686,15 +690,8 @@ def set_files(self, files: list[DagFileInfo]):
"""
self._files = files

# remove from queue any files no longer in the _files list
self._file_queue = deque(x for x in self._file_queue if x in files)
Stats.gauge("dag_processing.file_path_queue_size", len(self._file_queue))

# TODO: AIP-66 make callbacks bundle aware
# callback_paths_to_del = [x for x in self._callback_to_execute if x not in new_file_paths]
# for path_to_del in callback_paths_to_del:
# del self._callback_to_execute[path_to_del]

# Stop processors that are working on deleted files
filtered_processors = {}
for file, processor in self._processors.items():
Expand Down Expand Up @@ -785,8 +782,7 @@ def _get_logger_for_dag_file(self, dag_file: DagFileInfo):
def _create_process(self, dag_file: DagFileInfo) -> DagFileProcessorProcess:
id = uuid7()

# callback_to_execute_for_file = self._callback_to_execute.pop(file_path, [])
callback_to_execute_for_file: list[CallbackRequest] = []
callback_to_execute_for_file = self._callback_to_execute.pop(dag_file, [])

return DagFileProcessorProcess.start(
id=id,
Expand Down
13 changes: 8 additions & 5 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ def _parse_file_entrypoint():
log = structlog.get_logger(logger_name="task")

result = _parse_file(msg, log)
comms_decoder.send_request(log, result)
if result is not None:
comms_decoder.send_request(log, result)


def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult:
def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileParsingResult | None:
# TODO: Set known_pool names on DagBag!
bag = DagBag(
dag_folder=msg.file,
Expand All @@ -79,6 +80,11 @@ def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileP
safe_mode=True,
load_op_links=False,
)
if msg.callback_requests:
Comment thread
ephraimbuddy marked this conversation as resolved.
# If the request is for callback, we shouldn't serialize the DAGs
_execute_callbacks(bag, msg.callback_requests, log)
return None

serialized_dags, serialization_import_errors = _serialize_dags(bag, log)
bag.import_errors.update(serialization_import_errors)
dags = [LazyDeserializedDAG(data=serdag) for serdag in serialized_dags]
Expand All @@ -89,9 +95,6 @@ def _parse_file(msg: DagFileParseRequest, log: FilteringBoundLogger) -> DagFileP
# TODO: Make `bag.dag_warnings` not return SQLA model objects
warnings=[],
)

if msg.callback_requests:
_execute_callbacks(bag, msg.callback_requests, log)
return result


Expand Down
31 changes: 21 additions & 10 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from deprecated import deprecated
from sqlalchemy import and_, delete, exists, func, select, text, tuple_, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import lazyload, load_only, make_transient, selectinload
from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload
from sqlalchemy.sql import expression

from airflow import settings
Expand Down Expand Up @@ -756,7 +756,12 @@ def process_executor_events(

# Check state of finished tasks
filter_for_tis = TI.filter_for_tis(tis_with_right_state)
query = select(TI).where(filter_for_tis).options(selectinload(TI.dag_model))
query = (
select(TI)
.where(filter_for_tis)
.options(selectinload(TI.dag_model))
.options(joinedload(TI.dag_version))
)
# row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have
# multi-schedulers
tis_query: Query = with_row_locks(query, of=TI, session=session, skip_locked=True)
Expand Down Expand Up @@ -853,7 +858,9 @@ def process_executor_events(
# too, which would lead to double logging
cls.logger().error(msg)
request = TaskCallbackRequest(
full_filepath=ti.dag_model.fileloc,
filepath=ti.dag_model.relative_fileloc,
bundle_name=ti.dag_version.bundle_name,
bundle_version=ti.dag_version.bundle_version,
ti=ti,
msg=msg,
)
Expand Down Expand Up @@ -1627,9 +1634,11 @@ def _schedule_dag_run(
dag_model.calculate_dagrun_date_fields(dag, dag.get_run_data_interval(dag_run))

callback_to_execute = DagCallbackRequest(
full_filepath=dag.fileloc,
filepath=dag_model.relative_fileloc,
dag_id=dag.dag_id,
run_id=dag_run.run_id,
bundle_name=dag_model.bundle_name,
bundle_version=dag_run.bundle_version,
is_failure_callback=True,
msg="timed_out",
)
Expand Down Expand Up @@ -1991,11 +2000,11 @@ def _find_and_purge_zombies(self) -> None:
if zombies := self._find_zombies(session=session):
self._purge_zombies(zombies, session=session)

def _find_zombies(self, *, session: Session) -> list[tuple[TI, str]]:
def _find_zombies(self, *, session: Session) -> list[TI]:
self.log.debug("Finding 'running' jobs without a recent heartbeat")
limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs)
zombies = session.execute(
select(TI, DM.fileloc)
zombies = session.scalars(
select(TI)
.with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
.join(DM, TI.dag_id == DM.dag_id)
.where(
Expand All @@ -2008,11 +2017,13 @@ def _find_zombies(self, *, session: Session) -> list[tuple[TI, str]]:
self.log.warning("Failing %s TIs without heartbeat after %s", len(zombies), limit_dttm)
return zombies

def _purge_zombies(self, zombies: list[tuple[TI, str]], *, session: Session) -> None:
for ti, file_loc in zombies:
def _purge_zombies(self, zombies: list[TI], *, session: Session) -> None:
for ti in zombies:
zombie_message_details = self._generate_zombie_message_details(ti)
request = TaskCallbackRequest(
full_filepath=file_loc,
filepath=ti.dag_model.relative_fileloc,
bundle_name=ti.dag_version.bundle_name,
bundle_version=ti.dag_run.bundle_version,
ti=ti,
msg=str(zombie_message_details),
)
Expand Down
12 changes: 9 additions & 3 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,9 +922,11 @@ def recalculate(self) -> _UnfinishedStates:
dag.handle_callback(self, success=False, reason="task_failure", session=session)
elif dag.has_on_failure_callback:
callback = DagCallbackRequest(
full_filepath=dag.fileloc,
filepath=self.dag_model.relative_fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
bundle_name=self.dag_version.bundle_name,
bundle_version=self.bundle_version,
is_failure_callback=True,
msg="task_failure",
)
Expand All @@ -949,9 +951,11 @@ def recalculate(self) -> _UnfinishedStates:
dag.handle_callback(self, success=True, reason="success", session=session)
elif dag.has_on_success_callback:
callback = DagCallbackRequest(
full_filepath=dag.fileloc,
filepath=self.dag_model.relative_fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
bundle_name=self.dag_version.bundle_name,
bundle_version=self.bundle_version,
is_failure_callback=False,
msg="success",
)
Expand All @@ -966,9 +970,11 @@ def recalculate(self) -> _UnfinishedStates:
dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session)
elif dag.has_on_failure_callback:
callback = DagCallbackRequest(
full_filepath=dag.fileloc,
filepath=self.dag_model.relative_fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
bundle_name=self.dag_version.bundle_name,
bundle_version=self.bundle_version,
is_failure_callback=True,
msg="all_tasks_deadlocked",
)
Expand Down
4 changes: 3 additions & 1 deletion airflow/triggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,11 @@ def _submit_callback_if_necessary(self, *, task_instance: TaskInstance, session)
"""Submit a callback request if the task state is SUCCESS or FAILED."""
if self.task_instance_state in (TaskInstanceState.SUCCESS, TaskInstanceState.FAILED):
request = TaskCallbackRequest(
full_filepath=task_instance.dag_model.fileloc,
filepath=task_instance.dag_model.relative_fileloc,
ti=task_instance,
task_callback_type=self.task_instance_state,
bundle_name=task_instance.dag_model.bundle_name,
Comment thread
jedcunningham marked this conversation as resolved.
bundle_version=task_instance.dag_run.bundle_version,
)
log.info("Sending callback: %s", request)
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ def test_send_callback(self):
cel_k8s_exec.callback_sink = mock.MagicMock()

if AIRFLOW_V_3_0_PLUS:
callback = DagCallbackRequest(full_filepath="fake", dag_id="fake", run_id="fake")
callback = DagCallbackRequest(
filepath="fake", dag_id="fake", run_id="fake", bundle_name="testing", bundle_version=None
)
else:
callback = CallbackRequest(full_filepath="fake")
cel_k8s_exec.send_callback(callback)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def test_send_callback(self):
local_k8s_exec.callback_sink = mock.MagicMock()

if AIRFLOW_V_3_0_PLUS:
callback = DagCallbackRequest(full_filepath="fake", dag_id="fake", run_id="fake")
callback = DagCallbackRequest(
filepath="fake", dag_id="fake", run_id="fake", bundle_name="fake", bundle_version=None
)
else:
callback = CallbackRequest(full_filepath="fake")
local_k8s_exec.send_callback(callback)
Expand Down
12 changes: 5 additions & 7 deletions tests/callbacks/test_callback_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ class TestCallbackRequest:
),
(
DagCallbackRequest(
full_filepath="filepath",
filepath="filepath",
dag_id="fake_dag",
run_id="fake_run",
is_failure_callback=False,
bundle_name="testing",
bundle_version=None,
),
DagCallbackRequest,
),
Expand All @@ -66,8 +68,7 @@ def test_from_json(self, input, request_class):
)

input = TaskCallbackRequest(
full_filepath="filepath",
ti=ti,
filepath="filepath", ti=ti, bundle_name="testing", bundle_version=None
)
json_str = input.to_json()
result = request_class.from_json(json_str)
Expand All @@ -79,10 +80,7 @@ def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create
ti.end_date = timezone.utcnow()
session.merge(ti)
session.flush()
input = TaskCallbackRequest(
full_filepath="filepath",
ti=ti,
)
input = TaskCallbackRequest(filepath="filepath", ti=ti, bundle_name="testing", bundle_version=None)
json_str = input.to_json()
result = TaskCallbackRequest.from_json(json_str)
assert input == result
Loading