diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index d7503b8f4f9c..5cf6e5c7dc1b 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -136,7 +136,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op, Stmt body; if (const auto* reduce = expr_body.as()) { // Case 1. Reduce compute - block_name = compute_op->name; + block_name = info->GetUniqueName(compute_op->name); int n_buffers = buffers.size(); Array lhs; diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py index 48082c44a4ab..3eca60645411 100644 --- a/tests/python/unittest/test_te_create_primfunc.py +++ b/tests/python/unittest/test_te_create_primfunc.py @@ -22,7 +22,7 @@ import tvm.testing -def test_unique_name(): +def test_unique_name_complete_block(): A = te.placeholder((16, 16), name="A") B = te.compute((16, 16), lambda x, y: A[x, y] * 2, name="main") C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main") @@ -32,6 +32,18 @@ def test_unique_name(): assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef) +def test_unique_name_reduction_block(): + k1 = te.reduce_axis((0, 16), "k1") + k2 = te.reduce_axis((0, 16), "k2") + A = te.placeholder((16, 16), name="A") + B = te.compute((16,), lambda i: te.sum(A[i, k1], axis=k1), name="sum") + C = te.compute((), lambda: te.sum(B[k2], axis=k2), name="sum") + func = te.create_prim_func([A, C]) + s = tir.Schedule(func, debug_mask="all") + assert isinstance(s.get_sref(s.get_block("sum")), tir.schedule.StmtSRef) + assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef) + + def _check_workload(te_workload, tir_workload): func = te.create_prim_func(te_workload()) tvm.ir.assert_structural_equal(func, tir_workload) @@ -462,7 +474,8 @@ def test_argmax_val_idx(): if __name__ == "__main__": - test_unique_name() + test_unique_name_complete_block() + test_unique_name_reduction_block() test_matmul() test_element_wise() test_conv2d()