@@ -742,10 +742,6 @@ def validate_deserialized_task(
742742 # We store the string, real dag has the actual code
743743 "_pre_execute_hook" ,
744744 "_post_execute_hook" ,
745- "on_execute_callback" ,
746- "on_failure_callback" ,
747- "on_success_callback" ,
748- "on_retry_callback" ,
749745 # Checked separately
750746 "resources" ,
751747 "on_failure_fail_dagrun" ,
@@ -811,11 +807,23 @@ def validate_deserialized_task(
811807 default_partial_kwargs = (
812808 BaseOperator .partial (task_id = "_" )._expand (EXPAND_INPUT_EMPTY , strict = False ).partial_kwargs
813809 )
810+
811+ # These are added in `_TaskDecorator` e.g. when @setup or @teardown task is passed
812+ default_decorator_partial_kwargs = {
813+ "is_setup" : False ,
814+ "is_teardown" : False ,
815+ "on_failure_fail_dagrun" : False ,
816+ }
814817 serialized_partial_kwargs = {
815818 ** default_partial_kwargs ,
819+ ** default_decorator_partial_kwargs ,
816820 ** serialized_task .partial_kwargs ,
817821 }
818- original_partial_kwargs = {** default_partial_kwargs , ** task .partial_kwargs }
822+ original_partial_kwargs = {
823+ ** default_partial_kwargs ,
824+ ** default_decorator_partial_kwargs ,
825+ ** task .partial_kwargs ,
826+ }
819827 assert serialized_partial_kwargs == original_partial_kwargs
820828
821829 # ExpandInputs have different classes between scheduler and definition
@@ -1415,6 +1423,11 @@ def test_no_new_fields_added_to_base_operator(self):
14151423 "execution_timeout" : None ,
14161424 "executor" : None ,
14171425 "executor_config" : {},
1426+ "has_on_execute_callback" : False ,
1427+ "has_on_failure_callback" : False ,
1428+ "has_on_retry_callback" : False ,
1429+ "has_on_skipped_callback" : False ,
1430+ "has_on_success_callback" : False ,
14181431 "ignore_first_depends_on_past" : False ,
14191432 "is_setup" : False ,
14201433 "is_teardown" : False ,
@@ -1423,12 +1436,7 @@ def test_no_new_fields_added_to_base_operator(self):
14231436 "max_active_tis_per_dag" : None ,
14241437 "max_active_tis_per_dagrun" : None ,
14251438 "max_retry_delay" : None ,
1426- "on_execute_callback" : [],
14271439 "on_failure_fail_dagrun" : False ,
1428- "on_failure_callback" : [],
1429- "on_retry_callback" : [],
1430- "on_skipped_callback" : [],
1431- "on_success_callback" : [],
14321440 "outlets" : [],
14331441 "owner" : "airflow" ,
14341442 "params" : {},
@@ -3011,6 +3019,8 @@ def operator_extra_links(self):
30113019 assert mapped_task .extra_links == sorted ({"airflow" , "github" })
30123020
30133021
3022+ # TODO: Remove xfail
3023+ @pytest .mark .xfail (reason = "TODO: Need to add support for v1 & v2 to v3" )
30143024def test_handle_v1_serdag ():
30153025 v1 = {
30163026 "__version" : 1 ,
@@ -3296,3 +3306,129 @@ def test_handle_v1_serdag():
32963306 del expected ["dag" ]["tasks" ][1 ]["__var" ]["_operator_extra_links" ]
32973307
32983308 assert v1 == expected
3309+
3310+
3311+ def dummy_callback ():
3312+ pass
3313+
3314+
3315+ @pytest .mark .parametrize (
3316+ "callback_config,expected_flags,is_mapped" ,
3317+ [
3318+ # Regular operator tests
3319+ (
3320+ {
3321+ "on_failure_callback" : dummy_callback ,
3322+ "on_retry_callback" : [dummy_callback , dummy_callback ],
3323+ "on_success_callback" : dummy_callback ,
3324+ },
3325+ {"has_on_failure_callback" : True , "has_on_retry_callback" : True , "has_on_success_callback" : True },
3326+ False ,
3327+ ),
3328+ (
3329+ {}, # No callbacks
3330+ {
3331+ "has_on_failure_callback" : False ,
3332+ "has_on_retry_callback" : False ,
3333+ "has_on_success_callback" : False ,
3334+ },
3335+ False ,
3336+ ),
3337+ (
3338+ {"on_failure_callback" : [], "on_success_callback" : None }, # Empty callbacks
3339+ {"has_on_failure_callback" : False , "has_on_success_callback" : False },
3340+ False ,
3341+ ),
3342+ # Mapped operator tests
3343+ (
3344+ {"on_failure_callback" : dummy_callback , "on_success_callback" : [dummy_callback , dummy_callback ]},
3345+ {"has_on_failure_callback" : True , "has_on_success_callback" : True },
3346+ True ,
3347+ ),
3348+ (
3349+ {}, # Mapped operator without callbacks
3350+ {"has_on_failure_callback" : False , "has_on_success_callback" : False },
3351+ True ,
3352+ ),
3353+ ],
3354+ )
3355+ def test_task_callback_boolean_optimization (callback_config , expected_flags , is_mapped ):
3356+ """Test that task callbacks are optimized using has_on_*_callback boolean flags."""
3357+ dag = DAG (dag_id = "test_callback_dag" , start_date = datetime (2020 , 1 , 1 ))
3358+
3359+ if is_mapped :
3360+ # Create mapped operator
3361+ task = BashOperator .partial (task_id = "test_task" , dag = dag , ** callback_config ).expand (
3362+ bash_command = ["echo 1" , "echo 2" ]
3363+ )
3364+
3365+ # Serialize and deserialize
3366+ serialized = BaseSerialization .serialize (task )
3367+ deserialized = BaseSerialization .deserialize (serialized )
3368+
3369+ # For mapped operators, check partial_kwargs
3370+ serialized_data = serialized .get ("__var" , {}).get ("partial_kwargs" , {})
3371+
3372+ # Test serialization
3373+ for flag , expected in expected_flags .items ():
3374+ if expected :
3375+ assert flag in serialized_data
3376+ assert serialized_data [flag ] is True
3377+ else :
3378+ assert serialized_data .get (flag , False ) is False
3379+
3380+ # Test deserialized properties
3381+ for flag , expected in expected_flags .items ():
3382+ assert getattr (deserialized , flag ) is expected
3383+
3384+ else :
3385+ # Create regular operator
3386+ task = BashOperator (task_id = "test_task" , bash_command = "echo test" , dag = dag , ** callback_config )
3387+
3388+ # Serialize and deserialize
3389+ serialized = BaseSerialization .serialize (task )
3390+ deserialized = BaseSerialization .deserialize (serialized )
3391+
3392+ # For regular operators, check top-level
3393+ serialized_data = serialized .get ("__var" , {})
3394+
3395+ # Test serialization (only True values are stored)
3396+ for flag , expected in expected_flags .items ():
3397+ if expected :
3398+ assert serialized_data .get (flag , False ) is True
3399+ else :
3400+ assert serialized_data .get (flag , False ) is False
3401+
3402+ # Test deserialized properties
3403+ for flag , expected in expected_flags .items ():
3404+ assert getattr (deserialized , flag ) is expected
3405+
3406+
3407+ def test_task_callback_properties_exist ():
3408+ """Test that all callback boolean properties exist on both regular and mapped operators."""
3409+ dag = DAG (dag_id = "test_dag" , start_date = datetime (2020 , 1 , 1 ))
3410+
3411+ # Regular operator
3412+ regular_task = BashOperator (task_id = "regular" , bash_command = "echo test" , dag = dag )
3413+
3414+ # Mapped operator
3415+ mapped_task = BashOperator .partial (task_id = "mapped" , dag = dag ).expand (bash_command = ["echo 1" ])
3416+
3417+ callback_properties = [
3418+ "has_on_execute_callback" ,
3419+ "has_on_failure_callback" ,
3420+ "has_on_success_callback" ,
3421+ "has_on_retry_callback" ,
3422+ "has_on_skipped_callback" ,
3423+ ]
3424+
3425+ for prop in callback_properties :
3426+ assert hasattr (regular_task , prop ), f"Regular operator missing { prop } "
3427+ assert hasattr (mapped_task , prop ), f"Mapped operator missing { prop } "
3428+
3429+ # Serialize and check deserialized versions too
3430+ serialized_regular = BaseSerialization .deserialize (BaseSerialization .serialize (regular_task ))
3431+ serialized_mapped = BaseSerialization .deserialize (BaseSerialization .serialize (mapped_task ))
3432+
3433+ assert hasattr (serialized_regular , prop ), f"Deserialized regular operator missing { prop } "
3434+ assert hasattr (serialized_mapped , prop ), f"Deserialized mapped operator missing { prop } "
0 commit comments