[TVMScript] Parser int64 support #10789
Conversation
|
Thanks @MasterJH5574. It would be great if this context could also be part of the commit message. Perhaps the merger can manually copy it into the commit message |
|
Oh I'm surprised that it was not supported! Thanks for doing that :-) |
|
Hi @shingjan, I noted that our TVMScript parser fails on scripts containing `int64` dtypes. For example, when parsing the following script,
```python
@T.prim_func
def elementwise_shape_int64(
    A: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
    C: T.Buffer[(T.int64(128), T.int64(128)), "float32"],
) -> None:
    B = T.alloc_buffer((T.int64(128), T.int64(128)), dtype="float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(T.int64(128), T.int64(128)):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
```
the parser fails with an error message. Could you help take a look into this case? If the parser cannot support … |
… non-int32 dtypes (#10795) _This PR is a follow-up effort of #10789, which enables the `int64` support for TIR schedule primitive Cache-Read and Cache-Write._ Prior to this PR, the IterVars of the generated cache stage block are always `int32`-typed, which might conflict with the dtypes of the domains of the IterVars. In this PR, the dtype of new IterVars are constructed according to the data types of their domains, and thereby the possible conflicts are resolved. Meanwhile the data types of the read/write regions of the cache stage blocks are also constructed according to correct data types.
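The dtype-propagation idea described in that follow-up (new IterVars taking the dtype of their domain rather than a hard-coded `int32`) can be illustrated with a small standalone sketch. `Range` and `ivar_dtype_for` here are hypothetical illustrations, not TVM's actual API:

```python
from dataclasses import dataclass


@dataclass
class Range:
    """A loop domain [min, min + extent) with an element dtype (hypothetical sketch)."""
    min: int
    extent: int
    dtype: str  # e.g. "int32" or "int64"


def ivar_dtype_for(domain: Range) -> str:
    # The generated cache-stage block iterator should use the dtype of its
    # domain; a hard-coded "int32" conflicts with int64-typed domains
    # (e.g. extents built from T.int64(128)).
    return domain.dtype


# An int64-typed loop domain yields an int64 block iterator.
assert ivar_dtype_for(Range(0, 128, "int64")) == "int64"
# An int32 domain keeps int32, preserving the old behavior.
assert ivar_dtype_for(Range(0, 128, "int32")) == "int32"
```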
|
Hi @shingjan, would you like to take a look at my comment (#10789 (comment)) above? It would be super helpful if you could help us fix the type annotation with … |
|
@MasterJH5574 I am taking a look. Should be able to send out a fix very soon. Thanks for the reminder! |
|
@MasterJH5574 it seems to me that it is a bug that the |
## Context
When dealing with end-to-end models, we note that some tensors may have large shapes. Thus, when designing graph-level IR, we sometimes use `int64` instead of `int32` for the shape. Below is a dense GeMM example with an `int64` input tensor shape:
```python
@tvm.script.ir_module
class Module:
@T.prim_func
def main(rxplaceholder: T.Buffer[(1, 512), "float32"], rxplaceholder_1: T.Buffer[(T.int64(1000), T.int64(512)), "float32"], T_matmul_NT: T.Buffer[(1, T.int64(1000)), "float32"]) -> None:
# function attr dict
T.func_attr({"global_symbol": "dense", "tir.noalias": True, "op_pattern": 3})
# body
# with T.block("root")
for i0_0, i1_0, i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(1, 4, 1, 25, 8, 1, 10, 64, 1, 1):
with T.block("T_matmul_NT"):
i = T.axis.spatial(1, 0)
j = T.axis.spatial(T.int64(1000), i1_0 * T.int64(250) + i1_1 * T.int64(10) + i1_2)
k = T.axis.reduce(512, i2_0 * 64 + i2_1)
T.reads(T_matmul_NT[i, j], rxplaceholder[i, k], rxplaceholder_1[j, k])
T.writes(T_matmul_NT[i, j])
T.block_attr({"layout_free_placeholders":[rxplaceholder_1], "meta_schedule.tiling_structure":"SSRSRS"})
with T.init():
T_matmul_NT[i, j] = T.float32(0)
T_matmul_NT[i, j] = T_matmul_NT[i, j] + rxplaceholder[i, k] * rxplaceholder_1[j, k]
```
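The reason `int64` shapes matter is that the element count (and hence the flattened index range) of a large tensor can exceed what `int32` represents; a quick standalone check, using hypothetical tensor sizes for illustration:

```python
# int32 can represent values up to 2**31 - 1; a sufficiently large tensor's
# element count overflows that, so shape/index arithmetic needs int64.
INT32_MAX = 2**31 - 1

# Hypothetical large embedding table: 3M rows x 1024 float32 columns.
rows, cols = 3_000_000, 1024
num_elements = rows * cols

assert num_elements > INT32_MAX  # flattened indexing requires int64
```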
## Problem
Though our TVMScript printer can easily print `int64` constants, the parser had poor support for `int64`. This PR therefore introduces parser support for `int64`, mainly covering the data types of loop variables, block iterators, and block read/write regions.
Besides the parser, most of the TIR schedule primitives did not take `int64` into account in their implementations. These schedule primitives will be fixed in the near future, in follow-up PRs.
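As a rough illustration of the dtype rules the parser needs (a standalone sketch, not TVM's actual parser code): each loop variable takes the dtype of its own extent, and mixed expressions such as `i1_0 * T.int64(250)` promote to the wider integer type.

```python
def loop_var_dtypes(extent_dtypes):
    """Each loop variable in T.grid(...) takes the dtype of its own extent."""
    return list(extent_dtypes)


def unify_int_dtype(a: str, b: str) -> str:
    """Promote a mixed int32/int64 expression to the wider integer type."""
    bits = {"int32": 32, "int64": 64}
    return a if bits[a] >= bits[b] else b


# T.grid(128, T.int64(128)): first loop var is int32, second is int64.
assert loop_var_dtypes(["int32", "int64"]) == ["int32", "int64"]
# i1_0 * T.int64(250): an int32 var times an int64 constant yields int64.
assert unify_int_dtype("int32", "int64") == "int64"
```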
cc @tqchen @spectrometerHBH @junrushao1994 @jinhongyii @Hzfengsy