Skip to content

Commit 66dac85

Browse files
committed
[TVMScript] Implemented parsing of T.Ptr[...]
These can be generated when exporting to TVMscript, but were not parsable after being generated.
1 parent fcab55e commit 66dac85

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

python/tvm/script/parser.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from .tir.node import Slice, BufferSlice
4848
from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
4949
from .tir.special_stmt import SpecialStmt
50+
from .tir import ty
5051

5152

5253
class CallArgumentReader(object):
@@ -447,7 +448,9 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
447448
# add parameters of function
448449
for arg in node.params:
449450
# Note that this case is for T.match_buffer syntax sugar
450-
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
451+
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance(
452+
self.transform(arg.ty.func_name), ty.GenericBufferType
453+
):
451454
result = self.handle_match_buffer_type(arg.ty, arg.name)
452455
if not isinstance(result, buffer.Buffer):
453456
self.report_error(
@@ -1138,6 +1141,25 @@ def transform_TypeTuple(self, node):
11381141
"""
11391142
return [self.transform(value) for value in node.values]
11401143

1144+
def transform_TypeApply(self, node):
1145+
func = self.transform(node.func_name)
1146+
1147+
if not isinstance(func, ty.TypeGeneric):
1148+
self.report_error(f"Expected a type but found {type(func).__name__}", node.span)
1149+
1150+
param_types = []
1151+
for param in node.params:
1152+
param_type = self.transform(param)
1153+
if not isinstance(param_type, ty.TypeGeneric):
1154+
self.report_error(f"Expected a type but found {type(param).__name__}", param.span)
1155+
1156+
param_types.append(param_type)
1157+
1158+
if len(param_types) == 1:
1159+
return func[param_types[0]]
1160+
else:
1161+
return func[param_types]
1162+
11411163
def handle_match_buffer_type(self, node, buffer_name):
11421164
"""special function to handle syntax sugar for match buffer.
11431165

python/tvm/script/tir/ty.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ def __init__(self, vtype):
4444
self.type = vtype
4545

4646
def evaluate(self):
47-
return tvm.ir.PrimType(self.type)
47+
if isinstance(self.type, tvm.ir.Type):
48+
return self.type
49+
else:
50+
return tvm.ir.PrimType(self.type)
4851

4952

5053
class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
@@ -54,6 +57,8 @@ class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
5457
"""
5558

5659
def __getitem__(self, vtype):
60+
if not isinstance(vtype, TypeGeneric):
61+
raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}")
5762
return ConcreteType(tvm.ir.PointerType(vtype.evaluate()))
5863

5964

@@ -65,6 +70,8 @@ class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method
6570
"""
6671

6772
def __getitem__(self, vtypes):
73+
if isinstance(vtypes, TypeGeneric):
74+
vtypes = [vtypes]
6875
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
6976

7077

0 commit comments

Comments
 (0)