Skip to content

Commit 7a6281e

Browse files
authored
[BugFix] Generate unique names for reduction blocks (#10726)
1 parent 3bd52e6 commit 7a6281e

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

src/te/operation/create_primfunc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
136136
Stmt body;
137137
if (const auto* reduce = expr_body.as<ReduceNode>()) {
138138
// Case 1. Reduce compute
139-
block_name = compute_op->name;
139+
block_name = info->GetUniqueName(compute_op->name);
140140
int n_buffers = buffers.size();
141141

142142
Array<PrimExpr> lhs;

tests/python/unittest/test_te_create_primfunc.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import tvm.testing
2323

2424

25-
def test_unique_name():
25+
def test_unique_name_complete_block():
2626
A = te.placeholder((16, 16), name="A")
2727
B = te.compute((16, 16), lambda x, y: A[x, y] * 2, name="main")
2828
C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main")
@@ -32,6 +32,18 @@ def test_unique_name():
3232
assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef)
3333

3434

35+
def test_unique_name_reduction_block():
36+
k1 = te.reduce_axis((0, 16), "k1")
37+
k2 = te.reduce_axis((0, 16), "k2")
38+
A = te.placeholder((16, 16), name="A")
39+
B = te.compute((16,), lambda i: te.sum(A[i, k1], axis=k1), name="sum")
40+
C = te.compute((), lambda: te.sum(B[k2], axis=k2), name="sum")
41+
func = te.create_prim_func([A, C])
42+
s = tir.Schedule(func, debug_mask="all")
43+
assert isinstance(s.get_sref(s.get_block("sum")), tir.schedule.StmtSRef)
44+
assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef)
45+
46+
3547
def _check_workload(te_workload, tir_workload):
3648
func = te.create_prim_func(te_workload())
3749
tvm.ir.assert_structural_equal(func, tir_workload)
@@ -462,7 +474,8 @@ def test_argmax_val_idx():
462474

463475

464476
if __name__ == "__main__":
465-
test_unique_name()
477+
test_unique_name_complete_block()
478+
test_unique_name_reduction_block()
466479
test_matmul()
467480
test_element_wise()
468481
test_conv2d()

0 commit comments

Comments
 (0)