Skip to content

Commit 57597f6

Browse files
authored
[Fix][TIR] Fix lowering of symbolic strides (#16000)
* [Fix][TIR] Fix lowering of symbolic strides * [Fix][TIR] Run the black formatter
1 parent 7eedea5 commit 57597f6

2 files changed

Lines changed: 50 additions & 1 deletion

File tree

src/tir/transforms/ir_utils.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,8 @@ Array<PrimExpr> GetBufferAllocationShape(const Buffer& buffer) {
417417
if (buffer->strides.size()) {
418418
ICHECK_EQ(buffer->shape.size(), buffer->strides.size());
419419
for (size_t i = buffer->strides.size() - 1; i > 0; --i) {
420-
ICHECK(is_zero(floormod(buffer->strides[i - 1], buffer->strides[i])));
420+
ICHECK(
421+
arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0));
421422
alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
422423
}
423424
}

tests/python/unittest/test_tir_transform_lower_opaque_block.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,50 @@ def transformed_strided_buffer_func(
250250
C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2)
251251

252252

253+
# Input function for test_symbolic_strided_buffer: a block-level buffer whose
# shape and outermost stride both depend on the symbolic extent `n`.
@T.prim_func
254+
def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
255+
# `n` is a symbolic int32; it is the middle extent of the matched buffer A.
n = T.int32()
256+
A = T.match_buffer(a, (1, n, 10240))
257+
# n rounded up to a multiple of 64, capped at 96. The expected lowered
# function spells this expression out inline wherever padded_size is used
# here (presumably T.meta_var is substituted at parse time -- confirm).
padded_size = T.meta_var(T.min((n + 63) // 64 * 64, 96))
258+
# with T.block("root"):
259+
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
260+
with T.block(""):
261+
# Row stride is 72 while the innermost extent is 64, so the strides are
# not the compact ones for this shape; the outermost stride
# 72 * padded_size is a symbolic expression.
A_pad_shared_dyn = T.alloc_buffer(
262+
(1, padded_size, 64), strides=(72 * padded_size, 72, 1), scope="shared.dyn"
263+
)
264+
for ax0, ax1 in T.grid(96, 64):
265+
with T.block("A_pad_shared.dyn"):
266+
# Block predicate; the expected lowered function expresses the same
# condition as a plain `if`.
T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)
267+
# Rows with index >= n store 0 instead of reading from A.
A_pad_shared_dyn[0, ax0, ax1] = T.if_then_else(
268+
i * 128 + j * 32 + ax0 < n,
269+
A[0, i * 128 + j * 32 + ax0, k * 64 + ax1],
270+
T.float32(0),
271+
)
272+
273+
274+
# Expected counterpart of compacted_symbolic_strided_buffer_func in
# test_symbolic_strided_buffer: no T.block remains, and the buffer is written
# as a T.allocate plus a T.decl_buffer over that allocation.
@T.prim_func
275+
def transformed_symbolic_strided_buffer_func(a: T.handle):
276+
n = T.int32()
277+
A = T.match_buffer(a, (1, n, 10240))
278+
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
279+
# Allocation extents differ from the buffer shape in the last dimension
# (72 vs 64). They equal the quotients of adjacent declared strides:
# (72 * min(...)) / 72 = min(...) and 72 / 1 = 72.
A_pad_shared_dyn = T.allocate(
280+
[1, T.min((n + 63) // 64 * 64, 96), 72], "float32", "shared.dyn"
281+
)
282+
# Buffer declared on top of the allocation; it keeps the logical shape
# (..., 64) and the symbolic strides of the original alloc_buffer.
A_pad_shared_dyn_1 = T.decl_buffer(
283+
(1, T.min((n + 63) // 64 * 64, 96), 64),
284+
data=A_pad_shared_dyn,
285+
strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1),
286+
scope="shared.dyn",
287+
)
288+
for ax0, ax1 in T.grid(96, 64):
289+
# Same condition as the T.where predicate in
# compacted_symbolic_strided_buffer_func.
if i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64:
290+
# Rows with index >= n store 0 instead of reading from A.
A_pad_shared_dyn_1[0, ax0, ax1] = T.if_then_else(
291+
i * 128 + j * 32 + ax0 < n,
292+
A[0, i * 128 + j * 32 + ax0, k * 64 + ax1],
293+
T.float32(0),
294+
)
295+
296+
253297
@T.prim_func
254298
def annotated_loops(a: T.handle) -> None:
255299
A = T.match_buffer(a, (16,), "float32")
@@ -301,6 +345,10 @@ def test_strided_buffer():
301345
_check(compacted_strided_buffer_func, transformed_strided_buffer_func)
302346

303347

348+
def test_symbolic_strided_buffer():
    """Run `_check` on the symbolic-strided input function and its expected result."""
    source_func = compacted_symbolic_strided_buffer_func
    expected_func = transformed_symbolic_strided_buffer_func
    _check(source_func, expected_func)
350+
351+
304352
def test_lower_te():
305353
x = te.placeholder((1,))
306354
y = te.compute((1,), lambda i: x[i] + 2)

0 commit comments

Comments
 (0)