Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComSelectSequence, XComModel
from airflow.settings import task_instance_mutation_hook
from airflow.task.priority_strategy import validate_and_load_priority_weight_strategy
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
Expand Down Expand Up @@ -693,7 +694,10 @@ def insert_mapping(

:meta private:
"""
priority_weight = task.weight_rule.get_weight(
weight_rule = task.weight_rule
if not hasattr(weight_rule, "get_weight"):
weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
priority_weight = weight_rule.get_weight(
TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
)
context_carrier = new_task_run_carrier(dag_run.context_carrier)
Expand Down Expand Up @@ -874,7 +878,10 @@ def refresh_from_task(self, task: Operator, pool_override: str | None = None) ->
self.queue = task.queue
self.pool = pool_override or task.pool
self.pool_slots = task.pool_slots
self.priority_weight = self.task.weight_rule.get_weight(self)
weight_rule = self.task.weight_rule
if not hasattr(weight_rule, "get_weight"):
weight_rule = validate_and_load_priority_weight_strategy(weight_rule)
self.priority_weight = weight_rule.get_weight(self)
self.run_as_user = task.run_as_user
# Do not set max_tries to task.retries here because max_tries is a cumulative
# value that needs to be stored in the db.
Expand Down
20 changes: 20 additions & 0 deletions airflow-core/tests/unit/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2653,6 +2653,26 @@ def mock_policy(task_instance: TaskInstance):
assert ti.max_tries == expected_max_tries


@pytest.mark.parametrize(
("weight_rule", "expected_weight"),
[
pytest.param("downstream", 10 + 5, id="downstream-sums-descendants"),
pytest.param("upstream", 10, id="upstream-no-ancestors"),
pytest.param("absolute", 10, id="absolute-self-only"),
],
)
def test_refresh_from_task_with_non_serialized_operator(weight_rule, expected_weight):
"""Regression: TaskInstance must work with non-serialized operators whose weight_rule is a WeightRule enum."""
with DAG(dag_id="test_dag"):
root = EmptyOperator(task_id="root", priority_weight=10, weight_rule=weight_rule)
child = EmptyOperator(task_id="child", priority_weight=5)
root >> child

ti = TI(root, run_id=None, dag_version_id=mock.MagicMock())

assert ti.priority_weight == expected_weight


def test_defer_task_returns_false_when_no_start_from_trigger(create_task_instance):
session = mock.MagicMock()
ti = create_task_instance(
Expand Down
Loading