@@ -814,5 +814,78 @@ def test_with_var_input():
814814 _check_workload (te_slice_with_var_input , tir_slice_with_var_input , index_dtype_override = "int64" )
815815
816816
def test_loop_aware_initial_value():
    """Reduction whose initial value depends on the spatial iteration index.

    The TE workload accumulates each row of ``a`` starting from ``b[i]``
    instead of a constant. The TE builder and the hand-written TIR below are
    both handed to ``_check_workload``.
    """

    @T.prim_func
    def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle):
        T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
        a = T.match_buffer(var_a, (5, 5))
        b = T.match_buffer(var_b, (5,))
        sum_red = T.match_buffer(var_sum_red, (5,))
        for i, ax in T.grid(5, 5):
            with T.block("sum_red"):
                v_i, v_ax = T.axis.remap("SR", [i, ax])
                T.reads(b[v_i], a[v_i, v_ax])
                T.writes(sum_red[v_i])
                with T.init():
                    # The init statement reads ``b`` at the same spatial position.
                    sum_red[v_i] = b[v_i]
                sum_red[v_i] = sum_red[v_i] + a[v_i, v_ax]

    def te_workload():
        matrix = te.placeholder((5, 5), "float32", "a")
        seed = te.placeholder((5,), "float32", "b")
        red_axis = te.reduce_axis((0, 5), "ax")

        def row_sum(i):
            # The identity callback closes over ``i``, so the starting value
            # differs for every output element.
            reducer = te.comm_reducer(
                lambda x, y: x + y,
                lambda t: seed[i],
            )
            return reducer(matrix[i, red_axis], axis=[red_axis])

        result = te.compute((5,), row_sum, name="sum_red")
        return [matrix, seed, result]

    _check_workload(te_workload, tir_workload)
850+
851+
def test_loop_aware_reducer_combiner():
    """Reduction whose combiner depends on the spatial iteration index.

    The combiner indexes the input with the spatial index ``i`` and the
    running value, and falls back to the reduction axis itself. The TE
    builder and the hand-written TIR below are both handed to
    ``_check_workload``.
    """

    @T.prim_func
    def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle):
        T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
        a = T.match_buffer(var_a, (5, 5))
        b = T.match_buffer(var_b, (5,))
        sum_red = T.match_buffer(var_sum_red, (5,))
        for i, ax in T.grid(5, 5):
            with T.block("sum_red"):
                v_i = T.axis.spatial(5, i)
                v_ax = T.axis.reduce(5, ax)
                T.reads(a[v_i, 0:5])
                T.writes(sum_red[v_i])
                with T.init():
                    sum_red[v_i] = T.float32(0.0)
                # The update indexes ``a`` with the spatial var and the running value.
                sum_red[v_i] = T.if_then_else(
                    a[v_i, sum_red[v_i]] < a[v_i, v_ax], sum_red[v_i], T.Cast("float32", v_ax)
                )

    def te_workload():
        matrix = te.placeholder((5, 5), "float32", "a")
        # ``b`` is not read by the compute below; it is only returned as an argument.
        b_arg = te.placeholder((5,), "float32", "b")
        red_axis = te.reduce_axis((0, 5), "ax")

        def pick(i):
            # The combiner closes over ``i`` (spatial) and ``red_axis``.
            reducer = te.comm_reducer(
                lambda x, y: te.if_then_else(matrix[i, x] < y, x, red_axis),
                lambda _: te.const(0, "float32"),
            )
            return reducer(matrix[i, red_axis], axis=[red_axis])

        result = te.compute((5,), pick, name="sum_red")
        return [matrix, b_arg, result]

    _check_workload(te_workload, tir_workload)
888+
889+
if __name__ == "__main__":
    # Script entry point: delegate to TVM's testing main.
    tvm.testing.main()
0 commit comments