Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class BufferFlattener : public StmtExprMutator {
if (!is_one(predicate)) {
body = IfThenElse(predicate, std::move(body));
}
// If the block has bound predicates, transform it to if-then-else
const Optional<ObjectRef>& bound_predicate =
new_block->annotations.Get(tir::attr::require_block_var_bound_predicate);
if (bound_predicate.defined()) {
body = IfThenElse(Downcast<PrimExpr>(bound_predicate.value()), std::move(body));
}
// Step 3. Handle allocations in reverse order
for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
const Buffer& buffer = new_block->alloc_buffers[i - 1];
Expand Down
62 changes: 62 additions & 0 deletions tests/python/unittest/test_tir_transform_flatten_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,64 @@ def annotated_loops(a: T.handle) -> None:
A[i] = 0.0


@T.prim_func
def tiled_pooling_cache_after_compute_at(a: T.handle, b: T.handle) -> None:
    # Input fixture for test_bound_predicate: a tiled pooling-style workload
    # whose cache-read blocks carry the "require_bound_predicate" block
    # annotation. FlattenBuffer is expected to lower that annotation into an
    # if-then-else guard (see flattened_tiled_pooling_cache_after_compute_at).
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    # body
    # with T.block("root")
    # Two tile-local staging buffers. Each is 10x10: an 8x8 tile plus a
    # one-element halo on every side (indices run from tile*8 - 1).
    cache = T.alloc_buffer([10, 10], dtype="float32")
    dache = T.alloc_buffer([10, 10], dtype="float32")
    for hh_0, ww_0 in T.grid(28, 28):
        # Stage the (haloed) input tile into `cache`. The block annotation
        # restricts execution to indices that fall inside X's 224x224 extent;
        # the halo reads at tile boundaries would otherwise be out of bounds.
        for ax0, ax1 in T.grid(10, 10):
            with T.block("cache"):
                T.reads(X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1])
                T.writes(cache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1])
                T.block_attr({"require_bound_predicate":hh_0 * 8 - 1 + ax0 >= 0 and hh_0 * 8 - 1 + ax0 < 224 and ww_0 * 8 - 1 + ax1 >= 0 and ww_0 * 8 - 1 + ax1 < 224})
                cache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] = X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]
        # Second, identical staging block into `dache` — exercises the same
        # bound-predicate lowering on a second buffer in the same scope.
        for ax0, ax1 in T.grid(10, 10):
            with T.block("dache"):
                T.reads(X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1])
                T.writes(dache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1])
                T.block_attr({"require_bound_predicate":hh_0 * 8 - 1 + ax0 >= 0 and hh_0 * 8 - 1 + ax0 < 224 and ww_0 * 8 - 1 + ax1 >= 0 and ww_0 * 8 - 1 + ax1 < 224})
                dache[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1] = X[hh_0 * 8 - 1 + ax0, ww_0 * 8 - 1 + ax1]
        # Pooling-style compute over a 3x3 window: max-accumulate into Y,
        # summing the two staged buffers where the window is in bounds and
        # substituting 0.0 otherwise (via T.if_then_else, not a predicate).
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                T.reads(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1], cache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1], dache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1])
                T.writes(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1])
                Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1] = T.max(Y[hh_0 * 8 + hh_1, ww_0 * 8 + ww_1],
                    T.if_then_else(T.likely(1 <= hh_0 * 8 + hh_1 + khh, dtype="bool")
                                   and T.likely(hh_0 * 8 + hh_1 + khh < 225, dtype="bool")
                                   and T.likely(1 <= ww_0 * 8 + ww_1 + kww, dtype="bool")
                                   and T.likely(ww_0 * 8 + ww_1 + kww < 225, dtype="bool"),
                                   cache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1]
                                   + dache[hh_0 * 8 + hh_1 + khh - 1, ww_0 * 8 + ww_1 + kww - 1],
                                   T.float32(0), dtype="float32"))


@T.prim_func
def flattened_tiled_pooling_cache_after_compute_at(X: T.Buffer[(224, 224), "float32"], Y: T.Buffer[(224, 224), "float32"]) -> None:
    # Expected output of FlattenBuffer for
    # tiled_pooling_cache_after_compute_at: buffers are flattened to 1-D
    # allocations and each "require_bound_predicate" block annotation has
    # been lowered into an explicit `if` guard around the block body.
    cache = T.allocate([100], "float32", "global")  # flattened 10x10 tile buffer
    dache = T.allocate([100], "float32", "global")  # flattened 10x10 tile buffer
    for hh_0, ww_0 in T.grid(28, 28):
        # The bound predicate (hh_0*8 - 1 + ax0 in [0, 224), likewise for ww)
        # now appears as an if-guard; note `1 <= hh_0*8 + ax0` is the
        # rewritten form of `hh_0*8 - 1 + ax0 >= 0`.
        for ax0, ax1 in T.grid(10, 10):
            if 1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225:
                T.store(cache, hh_0 * 80 + ax0 * 10 + ww_0 * 8 + ax1 - 11, T.load("float32", X.data, hh_0 * 1792 + ax0 * 224 + ww_0 * 8 + ax1 - 225), True)
        for ax0, ax1 in T.grid(10, 10):
            if 1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225:
                T.store(dache, hh_0 * 80 + ax0 * 10 + ww_0 * 8 + ax1 - 11, T.load("float32", X.data, hh_0 * 1792 + ax0 * 224 + ww_0 * 8 + ax1 - 225), True)
        # The compute block had no bound predicate, so its body is lowered
        # directly: flattened loads/stores with the original T.if_then_else
        # (T.likely guards) preserved inside the value expression.
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            T.store(Y.data, hh_0 * 1792 + hh_1 * 224 + ww_0 * 8 + ww_1,
                    T.max(T.load("float32", Y.data, hh_0 * 1792 + hh_1 * 224 + ww_0 * 8 + ww_1),
                          T.if_then_else(T.likely(1 <= hh_0 * 8 + hh_1 + khh, dtype="bool")
                                         and T.likely(hh_0 * 8 + hh_1 + khh < 225, dtype="bool")
                                         and T.likely(1 <= ww_0 * 8 + ww_1 + kww, dtype="bool")
                                         and T.likely(ww_0 * 8 + ww_1 + kww < 225, dtype="bool"),
                                         T.load("float32", cache, hh_0 * 80 + hh_1 * 10 + khh * 10 + ww_0 * 8 + ww_1 + kww - 11)
                                         + T.load("float32", dache, hh_0 * 80 + hh_1 * 10 + khh * 10 + ww_0 * 8 + ww_1 + kww - 11),
                                         T.float32(0), dtype="float32")), True)


def test_elementwise():
    """FlattenBuffer turns the compacted elementwise func into its flat form."""
    before = compacted_elementwise_func
    expected = flattened_elementwise_func
    _check(before, expected)

Expand Down Expand Up @@ -305,6 +363,10 @@ def test_annotated_loops():
tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0))


def test_bound_predicate():
    """Blocks annotated with `require_bound_predicate` lower to if-then-else."""
    before = tiled_pooling_cache_after_compute_at
    expected = flattened_tiled_pooling_cache_after_compute_at
    _check(before, expected)


if __name__ == "__main__":
test_elementwise()
test_gpu_workload()
Expand Down