Skip to content

Commit b899690

Browse files
Author: wrongtest
Committed: "fix create reduce block with spatial iter dependent init value"
1 parent: 541f9c2 · commit: b899690

File tree

2 files changed: +84 additions, −6 deletions

src/te/operation/create_primfunc.cc

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -228,6 +228,10 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
228228
}
229229

230230
// Step 4. Create block body.
231+
// helper to transform the expr and remap iters to the block domain
232+
auto f_transform_and_remap = [&](const PrimExpr& e) {
233+
return Substitute(info->transformer(e), var_map);
234+
};
231235
String block_name{nullptr};
232236
Optional<Stmt> init = NullOpt;
233237
Stmt body;
@@ -246,8 +250,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
246250
// - A RHS operand is the value to be reduced.
247251
for (int i = 0; i < n_buffers; ++i) {
248252
const PrimExpr& left = BufferLoad(buffers[i], indices);
249-
const PrimExpr& right =
250-
analyzer->Simplify(Substitute(info->transformer(reduce->source[i]), var_map));
253+
const PrimExpr& right = analyzer->Simplify(f_transform_and_remap(reduce->source[i]));
251254
lhs.push_back(left);
252255
rhs.push_back(right);
253256
ICHECK_EQ(left->dtype, right->dtype);
@@ -267,13 +270,15 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
267270
// then store the value of the variables into the target buffer positions.
268271
for (int i = 0; i < n_buffers; ++i) {
269272
const Buffer& buffer = buffers[i];
270-
init_stmts.push_back(BufferStore(buffer, reduce->combiner->identity_element[i], indices));
273+
PrimExpr identity = f_transform_and_remap(reduce->combiner->identity_element[i]);
274+
init_stmts.push_back(BufferStore(buffer, identity, indices));
271275
PrimExpr value{nullptr};
272276
if (n_buffers > 1) {
273277
temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
274278
value = temp_vars.back();
275279
} else {
276-
value = reduce->combiner.get()->operator()(lhs, rhs)[i];
280+
PrimExpr combined = reduce->combiner.get()->operator()(lhs, rhs)[i];
281+
value = f_transform_and_remap(combined);
277282
}
278283
body_stmts.push_back(BufferStore(buffer, value, indices));
279284
}
@@ -283,15 +288,15 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
283288
if (n_buffers > 1) {
284289
// When there are multiple buffers, we wrap the body with LetStmts.
285290
for (int i = n_buffers - 1; i >= 0; --i) {
286-
PrimExpr value = reduce->combiner.get()->operator()(lhs, rhs)[i];
291+
PrimExpr value = f_transform_and_remap(reduce->combiner.get()->operator()(lhs, rhs)[i]);
287292
body = LetStmt(temp_vars[i], std::move(value), std::move(body));
288293
}
289294
}
290295
} else {
291296
// Case 2. Data parallel compute
292297
ICHECK_EQ(tensors.size(), 1);
293298
block_name = info->FreshName(tensors[0]->GetNameHint());
294-
const PrimExpr& compute_body = Substitute(info->transformer(expr_body), var_map);
299+
const PrimExpr& compute_body = f_transform_and_remap(expr_body);
295300
body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
296301
}
297302

tests/python/te/test_te_create_primfunc.py

Lines changed: 73 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -814,5 +814,78 @@ def test_with_var_input():
814814
_check_workload(te_slice_with_var_input, tir_slice_with_var_input, index_dtype_override="int64")
815815

816816

817+
def test_loop_aware_initial_value():
818+
"""Test initial value aware of spatial iter position"""
819+
820+
@T.prim_func
821+
def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle):
822+
T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
823+
a = T.match_buffer(var_a, (5, 5))
824+
b = T.match_buffer(var_b, (5,))
825+
sum_red = T.match_buffer(var_sum_red, (5,))
826+
for i, ax in T.grid(5, 5):
827+
with T.block("sum_red"):
828+
v_i, v_ax = T.axis.remap("SR", [i, ax])
829+
T.reads(b[v_i], a[v_i, v_ax])
830+
T.writes(sum_red[v_i])
831+
with T.init():
832+
sum_red[v_i] = b[v_i]
833+
sum_red[v_i] = sum_red[v_i] + a[v_i, v_ax]
834+
835+
def te_workload():
836+
data = te.placeholder((5, 5), "float32", "a")
837+
init = te.placeholder((5,), "float32", "b")
838+
ax = te.reduce_axis((0, 5), "ax")
839+
sum_red = te.compute(
840+
(5,),
841+
lambda i: te.comm_reducer(
842+
lambda x, y: x + y,
843+
lambda t: init[i],
844+
)(data[i, ax], axis=[ax]),
845+
name="sum_red",
846+
)
847+
return [data, init, sum_red]
848+
849+
_check_workload(te_workload, tir_workload)
850+
851+
852+
def test_loop_aware_reducer_combiner():
853+
"""Test combiner aware of spatial iter position"""
854+
855+
@T.prim_func
856+
def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle):
857+
T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
858+
a = T.match_buffer(var_a, (5, 5))
859+
b = T.match_buffer(var_b, (5,))
860+
sum_red = T.match_buffer(var_sum_red, (5,))
861+
for i, ax in T.grid(5, 5):
862+
with T.block("sum_red"):
863+
v_i = T.axis.spatial(5, i)
864+
v_ax = T.axis.reduce(5, ax)
865+
T.reads(a[v_i, 0:5])
866+
T.writes(sum_red[v_i])
867+
with T.init():
868+
sum_red[v_i] = T.float32(0.0)
869+
sum_red[v_i] = T.if_then_else(
870+
a[v_i, sum_red[v_i]] < a[v_i, v_ax], sum_red[v_i], T.Cast("float32", v_ax)
871+
)
872+
873+
def te_workload():
874+
data = te.placeholder((5, 5), "float32", "a")
875+
init = te.placeholder((5,), "float32", "b")
876+
ax = te.reduce_axis((0, 5), "ax")
877+
sum_red = te.compute(
878+
(5,),
879+
lambda i: te.comm_reducer(
880+
lambda x, y: te.if_then_else(data[i, x] < y, x, ax),
881+
lambda _: te.const(0, "float32"),
882+
)(data[i, ax], axis=[ax]),
883+
name="sum_red",
884+
)
885+
return [data, init, sum_red]
886+
887+
_check_workload(te_workload, tir_workload)
888+
889+
817890
if __name__ == "__main__":
818891
tvm.testing.main()

0 commit comments

Comments
 (0)