Skip to content

Commit 8813d0a

Browse files
authored
[TVMScript] Parser int64 support (#10789)
## Context When dealing with end-to-end models, we note that some tensors may have large shapes. Thus, when designing graph-level IR, we sometimes use `int64` instead of `int32` for the shape. Below is an dense GeMM example which has `int64` input tensor shape: ```python @tvm.script.ir_module class Module: @T.prim_func def main(rxplaceholder: T.Buffer[(1, 512), "float32"], rxplaceholder_1: T.Buffer[(T.int64(1000), T.int64(512)), "float32"], T_matmul_NT: T.Buffer[(1, T.int64(1000)), "float32"]) -> None: # function attr dict T.func_attr({"global_symbol": "dense", "tir.noalias": True, "op_pattern": 3}) # body # with T.block("root") for i0_0, i1_0, i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(1, 4, 1, 25, 8, 1, 10, 64, 1, 1): with T.block("T_matmul_NT"): i = T.axis.spatial(1, 0) j = T.axis.spatial(T.int64(1000), i1_0 * T.int64(250) + i1_1 * T.int64(10) + i1_2) k = T.axis.reduce(512, i2_0 * 64 + i2_1) T.reads(T_matmul_NT[i, j], rxplaceholder[i, k], rxplaceholder_1[j, k]) T.writes(T_matmul_NT[i, j]) T.block_attr({"layout_free_placeholders":[rxplaceholder_1], "meta_schedule.tiling_structure":"SSRSRS"}) with T.init(): T_matmul_NT[i, j] = T.float32(0) T_matmul_NT[i, j] = T_matmul_NT[i, j] + rxplaceholder[i, k] * rxplaceholder_1[j, k] ``` ## Problem Though our TVMScript printer can easily print `int64` constants, the parser had poor support for `int64`. So this PR introduces some parser support for `int64`, basically about the data type of loop variables, block iterators and block read/write regions. Besides the parser, most of the TIR schedule primitives didn't take `int64` into account in their implementations. These schedule primitives will be fixed and updated in recent future, in followup PRs.
1 parent e956eb3 commit 8813d0a

File tree

6 files changed

+39
-8
lines changed

6 files changed

+39
-8
lines changed

python/tvm/script/tir/node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ def check_index(index: Union[int, PrimExpr]):
9797
report_error("Negative index is not allowed during buffer access", span)
9898
elif isinstance(index, PrimExpr):
9999
element_dtype = index.dtype.split("x", maxsplit=1)[0]
100-
if element_dtype != "int32":
100+
if element_dtype[:3] != "int":
101101
report_error(
102-
"index expected an int32 type PrimExpr but got " + str(index.dtype),
102+
"index expected an integer type PrimExpr but got " + str(index.dtype),
103103
index.span,
104104
)
105105
else:

python/tvm/script/tir/special_stmt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,6 @@ def axis(
486486
if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]:
487487
self.context.report_error("Duplicate block axis " + var_name, self.node.span)
488488

489-
block_var = tvm.tir.Var(var_name, dtype="int32")
490489
dom = tvm.runtime.convert(dom)
491490
if isinstance(dom, PrimExpr):
492491
dom = tvm.ir.Range(dom)
@@ -497,6 +496,7 @@ def axis(
497496
f"Block axis domain expected PrimExpr or Range, but got {type(dom)}",
498497
self.node.span,
499498
)
499+
block_var = tvm.tir.Var(var_name, dtype=dom.extent.dtype)
500500
value = tvm.runtime.convert(value)
501501
if not isinstance(value, PrimExpr):
502502
self.context.report_error(

python/tvm/script/tir/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
# under the License.
1717
"""Helper functions in TVM Script Parser"""
1818

19-
from typing import List, Optional, Union
19+
from typing import List, Optional
2020

2121
from tvm.arith import Analyzer
2222
from tvm.ir import Range
2323
from tvm.tir import PrimExpr, BufferRegion
24+
from tvm.tir.expr import IntImm
2425
from .node import BufferSlice
2526

2627

@@ -44,8 +45,8 @@ def buffer_slice_to_region(
4445
"""
4546
region: List[Range] = []
4647
for s in buffer_slice.slices:
47-
start: Union[PrimExpr, int] = s.start
48-
extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start
48+
start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start)
49+
extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start
4950
if not analyzer:
5051
analyzer = Analyzer()
5152
if isinstance(extent, PrimExpr):

src/tir/ir/expr.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
147147
IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag, Span span) {
148148
ObjectPtr<IterVarNode> n = make_object<IterVarNode>();
149149
if (dom.defined() && dom->extent.defined()) {
150+
CHECK(dom->extent.dtype().is_int())
151+
<< "The dtype of the domain of an IterVar must be an integer type. However, the domain's "
152+
"dtype is "
153+
<< dom->extent.dtype();
150154
CHECK_EQ(dom->extent.dtype(), var.dtype())
151155
<< "The dtype of the extent of an IterVar (" << dom->extent.dtype()
152156
<< ") must match its associated Var's dtype (" << var.dtype() << ")";

tests/python/unittest/test_tvmscript_error_report.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
from tvm import tir
2222
from tvm.testing import check_error
2323
from tvm.script import tir as T
24-
from tvm.ir.diagnostics import override_renderer
25-
import inspect
2624

2725

2826
def buffer_bind_missing_args(a: T.handle) -> None:
@@ -629,5 +627,14 @@ def test_floor_dtype():
629627
check_error(floor_dtype, 3)
630628

631629

630+
def non_integer_typed_block_iter():
631+
with T.block():
632+
i = T.axis.S(0.1, 0.1) # error IterVar requires an integer dtype
633+
634+
635+
def test_non_integer_typed_block_iter():
636+
check_error(non_integer_typed_block_iter, 3)
637+
638+
632639
if __name__ == "__main__":
633640
sys.exit(pytest.main([__file__] + sys.argv[1:]))

tests/python/unittest/test_tvmscript_roundtrip.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3205,6 +3205,24 @@ def segment_sum(
32053205
return segment_sum
32063206

32073207

3208+
def int64_support():
3209+
@T.prim_func
3210+
def elementwise_shape_int64(a: T.handle, c: T.handle) -> None:
3211+
A = T.match_buffer(a, (T.int64(128), T.int64(128)), dtype="float32")
3212+
B = T.alloc_buffer((T.int64(128), T.int64(128)), dtype="float32")
3213+
C = T.match_buffer(c, (T.int64(128), T.int64(128)), dtype="float32")
3214+
for i, j in T.grid(128, 128):
3215+
with T.block("B"):
3216+
vi, vj = T.axis.remap("SS", [i, j])
3217+
B[vi, vj] = A[vi, vj] * 2.0
3218+
for i, j in T.grid(T.int64(128), T.int64(128)):
3219+
with T.block("C"):
3220+
vi, vj = T.axis.remap("SS", [i, j])
3221+
C[vi, vj] = B[vi, vj] + 1.0
3222+
3223+
return elementwise_shape_int64
3224+
3225+
32083226
ir_generator = tvm.testing.parameter(
32093227
opt_gemm_normalize,
32103228
opt_gemm_lower,
@@ -3237,6 +3255,7 @@ def segment_sum(
32373255
func_T_ptr_allocate,
32383256
llvm_intrin_call,
32393257
parse_bufferslice_as_range_bound,
3258+
int64_support,
32403259
)
32413260

32423261

0 commit comments

Comments
 (0)