Skip to content

Commit 9cd3b58

Browse files
Lunderbergylc
authored andcommitted
[TVMScript] Support T.buffer_decl using data pointer from Let/Allocate (apache#10099)
* [TVMScript] Added unit tests demonstrating desired functionality * [TVMScript] Implemented parsing of T.Ptr[...] These can be generated when exporting to TVMscript, but were not parsable after being generated. * [TVMScript] Updated buffer_var printing LetStmt and AllocateNode can both be used to generate handles that are used in Buffer objects. In these cases, the Buffer declarations must go after the handle declaration, not in the function header. * Moved printing of var and buffer_decl into separate statements. * Updated following @shingjan's review comments.
1 parent 24c202b commit 9cd3b58

4 files changed

Lines changed: 198 additions & 30 deletions

File tree

python/tvm/script/parser.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from .tir.node import Slice, BufferSlice
4848
from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
4949
from .tir.special_stmt import SpecialStmt
50+
from .tir import ty
5051

5152

5253
class CallArgumentReader(object):
@@ -447,7 +448,9 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
447448
# add parameters of function
448449
for arg in node.params:
449450
# Note that this case is for T.match_buffer syntax sugar
450-
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
451+
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance(
452+
self.transform(arg.ty.func_name), ty.GenericBufferType
453+
):
451454
result = self.handle_match_buffer_type(arg.ty, arg.name)
452455
if not isinstance(result, buffer.Buffer):
453456
self.report_error(
@@ -1138,6 +1141,33 @@ def transform_TypeTuple(self, node):
11381141
"""
11391142
return [self.transform(value) for value in node.values]
11401143

1144+
def transform_TypeApply(self, node):
1145+
"""Visitor for Type[Type] expressions.
1146+
1147+
Mostly used for ``T.Ptr`` expressions.
1148+
"""
1149+
func = self.transform(node.func_name)
1150+
1151+
if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"):
1152+
self.report_error(
1153+
f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), "
1154+
f"but found {type(func).__name__} instead.",
1155+
node.span,
1156+
)
1157+
1158+
param_types = []
1159+
for param in node.params:
1160+
param_type = self.transform(param)
1161+
if not isinstance(param_type, ty.TypeGeneric):
1162+
self.report_error(f"Expected a type but found {type(param).__name__}", param.span)
1163+
1164+
param_types.append(param_type)
1165+
1166+
if len(param_types) == 1:
1167+
return func[param_types[0]]
1168+
else:
1169+
return func[param_types]
1170+
11411171
def handle_match_buffer_type(self, node, buffer_name):
11421172
"""special function to handle syntax sugar for match buffer.
11431173

python/tvm/script/tir/ty.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,26 @@ def __call__(self):
3838

3939

4040
class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method
41-
"""TVM script typing class for uniform Type objects"""
41+
"""TVM script typing class for uniform Type objects
42+
43+
Params
44+
------
45+
vtype: Union[str, tvm.ir.Type]
46+
47+
The IR type represented by the type annotation. If a string
48+
(e.g. "float32"), this represents a `ir.PrimType` generated
49+
from that string. If a `ir.Type` is provided, this represents
50+
the type provided.
51+
"""
4252

4353
def __init__(self, vtype):
44-
self.type = vtype
54+
if isinstance(vtype, tvm.ir.Type):
55+
self.type = vtype
56+
else:
57+
self.type = tvm.ir.PrimType(vtype)
4558

4659
def evaluate(self):
47-
return tvm.ir.PrimType(self.type)
60+
return self.type
4861

4962

5063
class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
@@ -54,6 +67,8 @@ class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
5467
"""
5568

5669
def __getitem__(self, vtype):
70+
if not isinstance(vtype, TypeGeneric):
71+
raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}")
5772
return ConcreteType(tvm.ir.PointerType(vtype.evaluate()))
5873

5974

@@ -65,6 +80,8 @@ class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method
6580
"""
6681

6782
def __getitem__(self, vtypes):
83+
if isinstance(vtypes, TypeGeneric):
84+
vtypes = [vtypes]
6885
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))
6986

7087

src/printer/tvmscript_printer.cc

Lines changed: 100 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,58 @@ enum class ExprPrecedence : int {
6868
kUnknown = 7,
6969
};
7070

71+
/*! \brief Utility used for identifying usage of a buffer_var
72+
*
73+
* \details Find the Buffer object that corresponds to a variable or
74+
* allocation, based on the BufferLoad/BufferStore instances that
75+
* occur within the allocation's body.
76+
*/
77+
class BufferUsageFinder : public StmtExprVisitor {
78+
public:
79+
static Map<Var, Array<Buffer>> FindUsage(Map<Var, Array<Buffer>> usage, Stmt body) {
80+
BufferUsageFinder visitor(std::move(usage));
81+
visitor.VisitStmt(body);
82+
return std::move(visitor.usage_);
83+
}
84+
85+
void VisitExpr_(const VarNode* op) final {
86+
Var var = GetRef<Var>(op);
87+
if (!usage_.count(var)) {
88+
usage_.Set(var, {});
89+
}
90+
}
91+
92+
void VisitExpr_(const BufferLoadNode* op) final {
93+
VisitBuffer(op->buffer);
94+
StmtExprVisitor::VisitExpr_(op);
95+
}
96+
97+
void VisitStmt_(const BufferStoreNode* op) final {
98+
VisitBuffer(op->buffer);
99+
StmtExprVisitor::VisitStmt_(op);
100+
}
101+
102+
private:
103+
explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}
104+
105+
void VisitBuffer(const Buffer& buffer) {
106+
if (buffers_visited_.count(buffer.get())) {
107+
return;
108+
}
109+
buffers_visited_.insert(buffer.get());
110+
111+
Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
112+
arr.push_back(buffer);
113+
usage_.Set(buffer->data, arr);
114+
}
115+
116+
// The search result.
117+
Map<Var, Array<Buffer>> usage_;
118+
// The buffers that have been visited so far, to avoid duplicate
119+
// entries in the search result.
120+
std::unordered_set<const BufferNode*> buffers_visited_;
121+
};
122+
71123
/*!
72124
* \brief The printer for TVMScript
73125
* \details The printer obtain the precedence of the top-level operation when printing each
@@ -138,6 +190,14 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
138190
* 3. The iter range is equal to loop range
139191
*/
140192
std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
193+
/*!
194+
* \brief Map from variables to the buffers they are used in.
195+
*
196+
* Used for identifying buffers that should be declared after the
197+
* LetStmt or Allocate that generates their data pointer, rather
198+
* than in the header.
199+
*/
200+
Map<Var, Array<Buffer>> buffer_var_usage_;
141201

142202
Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
143203
Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
@@ -201,6 +261,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
201261
Doc PrintRange(const RangeNode* op);
202262
Doc PrintArray(const ArrayNode* op);
203263
Doc PrintBuffer(const BufferNode* op);
264+
Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body);
204265
Doc AllocBufferDeclaration(const Buffer& buf);
205266
Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
206267
Doc PrintBlockVarRemaps();
@@ -830,11 +891,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
830891
Doc doc;
831892
if (current_num_ != num_child_ - 1) {
832893
doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):";
833-
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
894+
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
895+
<< PrintBody(op->body));
834896
} else {
835897
if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get());
836898
doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
837-
<< Doc::NewLine() << PrintBody(op->body);
899+
<< Doc::NewLine();
900+
doc << PrintNonHeaderBufferDeclarations(op->var, op->body) << PrintBody(op->body);
838901
}
839902
return doc;
840903
}
@@ -923,33 +986,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
923986

924987
Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
925988
var_not_in_headers_.insert(op->buffer_var.get());
926-
Doc doc;
989+
927990
auto storage_scope = GetPtrStorageScope(op->buffer_var);
991+
Doc func_call;
992+
func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype)
993+
<< ", " << Print(storage_scope);
994+
if (!is_one(op->condition)) {
995+
func_call << ", " << Print(op->condition);
996+
}
997+
if (!op->annotations.empty()) {
998+
func_call << ", annotations={";
999+
func_call << PrintAnnotations(op->annotations);
1000+
func_call << "}";
1001+
}
1002+
func_call << ")";
1003+
1004+
Doc doc;
9281005
if (current_num_ != num_child_ - 1) {
929-
doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", "
930-
<< PrintDType(op->dtype) << ", " << Print(storage_scope);
931-
if (!is_one(op->condition)) {
932-
doc << ", " << Print(op->condition);
933-
}
934-
if (!op->annotations.empty()) {
935-
doc << ", annotations={";
936-
doc << PrintAnnotations(op->annotations);
937-
doc << "}";
938-
}
939-
doc << ") as " << Print(op->buffer_var) << ":";
940-
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1006+
doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
1007+
doc << Doc::Indent(4, Doc::NewLine()
1008+
<< PrintNonHeaderBufferDeclarations(op->buffer_var, op->body)
1009+
<< PrintBody(op->body));
9411010
} else {
942-
doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents)
943-
<< ", " << PrintDType(op->dtype) << ", " << Print(storage_scope);
944-
if (!is_one(op->condition)) {
945-
doc << ", " << Print(op->condition);
946-
}
947-
if (!op->annotations.empty()) {
948-
doc << ", annotations={";
949-
doc << PrintAnnotations(op->annotations);
950-
doc << "}";
951-
}
952-
doc << ")" << Doc::NewLine() << PrintBody(op->body);
1011+
doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine();
1012+
doc << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body);
9531013
}
9541014
TryDeallocVar(op->buffer_var);
9551015
return doc;
@@ -1458,6 +1518,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
14581518
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
14591519
}
14601520

1521+
Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) {
1522+
if (!buffer_var_usage_.count(buffer_var)) {
1523+
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), body);
1524+
}
1525+
Array<Buffer> buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({});
1526+
Doc decls;
1527+
for (const auto& buf_usage : buffer_usage) {
1528+
decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
1529+
<< memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
1530+
buf_not_in_headers_.insert(buf_usage.get());
1531+
}
1532+
return decls;
1533+
}
1534+
14611535
Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
14621536
Doc doc;
14631537
if (op->region.size() == 0) {

tests/python/unittest/test_tvmscript_roundtrip.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3255,5 +3255,52 @@ def test_root_attr():
32553255
tvm.ir.assert_structural_equal(func, rt_func, True)
32563256

32573257

3258+
@T.prim_func
3259+
def func_T_ptr_let_statement(
3260+
args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32
3261+
) -> None:
3262+
# The T.Ptr declaration in the parameter list should parse
3263+
# correctly, and should be usable as the data pointer in a buffer.
3264+
arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle)
3265+
3266+
arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle")
3267+
arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle")
3268+
3269+
# Functions that return a "handle" can be assigned to a T.Ptr
3270+
# variable. A variable annotated with T.Ptr still has dtype of
3271+
# T.handle, but has type annotation as a pointer type.
3272+
A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle")
3273+
3274+
# The buffer declaration has a data pointer defined earlier in
3275+
# this function. It should only be defined after the data pointer
3276+
# has been defined, and should not be hoisted into the header of
3277+
# the function as other buffer_decl statements can be.
3278+
A = T.buffer_decl([1024], dtype="float32", data=A_data)
3279+
B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle")
3280+
B = T.buffer_decl([1024], dtype="float32", data=B_data)
3281+
3282+
B[0] = A[0]
3283+
3284+
3285+
def test_T_ptr_let_statement():
3286+
func = func_T_ptr_let_statement
3287+
rt_func = tvm.script.from_source(func.script(show_meta=True))
3288+
tvm.ir.assert_structural_equal(func, rt_func, True)
3289+
3290+
3291+
@T.prim_func
3292+
def func_T_ptr_allocate() -> None:
3293+
A_data: T.Ptr[T.float32] = T.allocate([1024], "float32", "global")
3294+
A = T.buffer_decl([1024], dtype="float32", data=A_data)
3295+
3296+
A[0] = 0.0
3297+
3298+
3299+
def test_T_ptr_allocate():
3300+
func = func_T_ptr_allocate
3301+
rt_func = tvm.script.from_source(func.script(show_meta=True))
3302+
tvm.ir.assert_structural_equal(func, rt_func, True)
3303+
3304+
32583305
if __name__ == "__main__":
32593306
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)