[Fix][TIR]fix symbolic strides lower#15986
[Fix][TIR]fix symbolic strides lower#15986JackWeiw wants to merge 2 commits intoapache:mainfrom JackWeiw:main
Conversation
| def transformed_symbolic_strided_buffer_func(a: T.handle): | ||
| n = T.int64() | ||
| A = T.match_buffer(a, (1, n, 10240)) | ||
| for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160): |
There was a problem hiding this comment.
Does the test case depend on T.int64 datatypes? If not, this would be much more readable by using T.int32. Because it is the default integer type in TVMScript, it wouldn't require the explicit type conversions. (e.g. (n + 63) instead of (n + T.int64(63)).
There was a problem hiding this comment.
Thank u for the adivice, i've modified it using T.int32 and pulled out into padded_size = T.meta_var(T.min((n + T.int64(63)) // T.int64(64) * T.int64(64)) in the test case.
| A = T.match_buffer(a, (1, n, 10240)) | ||
| for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160): | ||
| A_pad_shared_dyn = T.allocate([1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72], "float32", "shared.dyn") | ||
| A_pad_shared_dyn_1 = T.decl_buffer((1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 64), data=A_pad_shared_dyn, strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72, 1), scope="shared.dyn") |
There was a problem hiding this comment.
The expression T.min((n + T.int64(63)) // T.int64(64) * T.int64(64) occurs frequently, and makes it difficult to read. Can this be pulled out into padded_size = T.meta_var(T.min((n + T.int64(63)) // T.int64(64) * T.int64(64))? The generated TIR will still contain the full expression, but the test case can be easier to read.
| A_pad_shared_dyn = T.alloc_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), "float32", strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn") | ||
| for ax0, ax1 in T.grid(96, 64): | ||
| with T.block("A_pad_shared.dyn"): | ||
| T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64) |
There was a problem hiding this comment.
This looks like the same T.reads and T.writes annotations as would be automatically inferred from the block's body. Unless the test depends on a specific override to use non-default read/write annotations, it should be removed for readability.
| def compacted_symbolic_strided_buffer_func(a: T.handle) -> None: | ||
| n = T.int64() | ||
| A = T.match_buffer(a, (1, n, 10240), "float32") | ||
| for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160): |
There was a problem hiding this comment.
Unrelated, the presence of this expression is kind of odd to me. Assuming this example came from a TIR printout, I would have expected ((n + 63) // 64 * 4 + 7) // 8 to be simplified to the equivalent (n + 127) // 128. The fact that it didn't simplify may indicate that I should take a look at the CanonicalSimplifier.
| auto src_offset = load->indices[0]; | ||
| auto dst_offset = store->indices[0]; | ||
| Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), | ||
| Array<PrimExpr> args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), |
There was a problem hiding this comment.
Since your description mentions this as a separate bug, can it either be split out into a separate PR, or (since it is a relatively small change), have a test case added for it?
There was a problem hiding this comment.
I open a new PR here . Please have a check.
I will open a new PR to fix dtype mismatch bug in PASS InjectPTXAsyncCopy after symbolic strides PR is merged
compact_buffer_regionPASS modify shared buffer stride[0] toT.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96))and stride[1] isT.int64(72)but in LowerOpaqueBlock PASS it report error:
InternalError: Check failed: (is_zero(floormod(buffer->strides[i - 1], buffer->strides[i]))) is false:
For more detaied discuss, see here
Another bug occurs in PASS InjectPTXAsyncCopy .
that is dst_offset.dtype could be int64, the dtype of PrimExpr(index_factor) would be set to default to int32.
cause dtype inconsistent when calling tir::Mul.
To reproduce the problem in InjectPTXAsyncCopy, see script here