diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 09044b5aaf2cc..964c345a3bf6c 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -88,6 +88,7 @@ from airflow.models.taskreschedule import TaskReschedule from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComSelectSequence, XComModel from airflow.settings import task_instance_mutation_hook +from airflow.task.priority_strategy import validate_and_load_priority_weight_strategy from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep @@ -693,7 +694,10 @@ def insert_mapping( :meta private: """ - priority_weight = task.weight_rule.get_weight( + weight_rule = task.weight_rule + if not hasattr(weight_rule, "get_weight"): + weight_rule = validate_and_load_priority_weight_strategy(weight_rule) + priority_weight = weight_rule.get_weight( TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id) ) context_carrier = new_task_run_carrier(dag_run.context_carrier) @@ -874,7 +878,10 @@ def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> self.queue = task.queue self.pool = pool_override or task.pool self.pool_slots = task.pool_slots - self.priority_weight = self.task.weight_rule.get_weight(self) + weight_rule = self.task.weight_rule + if not hasattr(weight_rule, "get_weight"): + weight_rule = validate_and_load_priority_weight_strategy(weight_rule) + self.priority_weight = weight_rule.get_weight(self) self.run_as_user = task.run_as_user # Do not set max_tries to task.retries here because max_tries is a cumulative # value that needs to be stored in the db. diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index cd1dd3c337a12..9eba07daaa130 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -2653,6 +2653,26 @@ def mock_policy(task_instance: TaskInstance): assert ti.max_tries == expected_max_tries +@pytest.mark.parametrize( + ("weight_rule", "expected_weight"), + [ + pytest.param("downstream", 10 + 5, id="downstream-sums-descendants"), + pytest.param("upstream", 10, id="upstream-no-ancestors"), + pytest.param("absolute", 10, id="absolute-self-only"), + ], +) +def test_refresh_from_task_with_non_serialized_operator(weight_rule, expected_weight): + """Regression: TaskInstance must work with non-serialized operators whose weight_rule is a WeightRule enum.""" + with DAG(dag_id="test_dag"): + root = EmptyOperator(task_id="root", priority_weight=10, weight_rule=weight_rule) + child = EmptyOperator(task_id="child", priority_weight=5) + root >> child + + ti = TI(root, run_id=None, dag_version_id=mock.MagicMock()) + + assert ti.priority_weight == expected_weight + + def test_defer_task_returns_false_when_no_start_from_trigger(create_task_instance): session = mock.MagicMock() ti = create_task_instance(