2222import tvm .testing
2323
2424
25- def test_unique_name ():
25+ def test_unique_name_complete_block ():
2626 A = te .placeholder ((16 , 16 ), name = "A" )
2727 B = te .compute ((16 , 16 ), lambda x , y : A [x , y ] * 2 , name = "main" )
2828 C = te .compute ((16 , 16 ), lambda x , y : B [x , y ] + 1 , name = "main" )
@@ -32,6 +32,18 @@ def test_unique_name():
3232 assert isinstance (s .get_sref (s .get_block ("main_1" )), tir .schedule .StmtSRef )
3333
3434
35+ def test_unique_name_reduction_block ():
36+ k1 = te .reduce_axis ((0 , 16 ), "k1" )
37+ k2 = te .reduce_axis ((0 , 16 ), "k2" )
38+ A = te .placeholder ((16 , 16 ), name = "A" )
39+ B = te .compute ((16 ,), lambda i : te .sum (A [i , k1 ], axis = k1 ), name = "sum" )
40+ C = te .compute ((), lambda : te .sum (B [k2 ], axis = k2 ), name = "sum" )
41+ func = te .create_prim_func ([A , C ])
42+ s = tir .Schedule (func , debug_mask = "all" )
43+ assert isinstance (s .get_sref (s .get_block ("sum" )), tir .schedule .StmtSRef )
44+ assert isinstance (s .get_sref (s .get_block ("sum_1" )), tir .schedule .StmtSRef )
45+
46+
3547def _check_workload (te_workload , tir_workload ):
3648 func = te .create_prim_func (te_workload ())
3749 tvm .ir .assert_structural_equal (func , tir_workload )
@@ -462,7 +474,8 @@ def test_argmax_val_idx():
462474
463475
464476if __name__ == "__main__" :
465- test_unique_name ()
477+ test_unique_name_complete_block ()
478+ test_unique_name_reduction_block ()
466479 test_matmul ()
467480 test_element_wise ()
468481 test_conv2d ()
0 commit comments