From fcab55e4ceda2366ec2afb02960b29e282a909e0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 28 Jan 2022 11:43:14 -0600 Subject: [PATCH 1/5] [TVMScript] Added unit tests demonstrating desired functionality --- .../unittest/test_tvmscript_roundtrip.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index bf1235a4dc42..f1d62faab4e0 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3255,5 +3255,44 @@ def test_root_attr(): tvm.ir.assert_structural_equal(func, rt_func, True) +@T.prim_func +def func_T_ptr_let_statement( + args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32 +) -> None: + # buffer definition + arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle) + # body + assert num_args == 2, "main: num_args should be 3" + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") + arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + + A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + A = T.buffer_decl([1024], dtype="float32", data=A_data) + B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + B = T.buffer_decl([1024], dtype="float32", data=B_data) + + B[0] = A[0] + + +def test_T_ptr_let_statement(): + func = func_T_ptr_let_statement + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +@T.prim_func +def func_T_ptr_allocate() -> None: + A_data: T.Ptr[T.float32] = T.allocate([1024], "float32", "global") + A = T.buffer_decl([1024], dtype="float32", data=A_data) + + A[0] = 0.0 + + +def test_T_ptr_allocate(): + func = func_T_ptr_allocate + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 0c589aa7775e49e8bd5ae8b8aac0055c6b8b3f3d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 27 Jan 2022 15:01:44 -0600 Subject: [PATCH 2/5] [TVMScript] Implemented parsing of T.Ptr[...] These can be generated when exporting to TVMscript, but were not parsable after being generated. --- python/tvm/script/parser.py | 28 +++++++++++++++++++++++++++- python/tvm/script/tir/ty.py | 7 +++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 917a16d478c0..b6dbeae4ff77 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -47,6 +47,7 @@ from .tir.node import Slice, BufferSlice from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler from .tir.special_stmt import SpecialStmt +from .tir import ty class CallArgumentReader(object): @@ -447,7 +448,9 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: # add parameters of function for arg in node.params: # Note that this case is for T.match_buffer syntax sugar - if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)): + if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance( + self.transform(arg.ty.func_name), ty.GenericBufferType + ): result = self.handle_match_buffer_type(arg.ty, arg.name) if not isinstance(result, buffer.Buffer): self.report_error( @@ -1138,6 +1141,29 @@ def transform_TypeTuple(self, node): """ return [self.transform(value) for value in node.values] + def transform_TypeApply(self, node): + """Visitor for Type[Type] expressions. + + Mostly used for ``T.Ptr`` expressions. + """ + func = self.transform(node.func_name) + + if not isinstance(func, ty.TypeGeneric): + self.report_error(f"Expected a type but found {type(func).__name__}", node.span) + + param_types = [] + for param in node.params: + param_type = self.transform(param) + if not isinstance(param_type, ty.TypeGeneric): + self.report_error(f"Expected a type but found {type(param).__name__}", param.span) + + param_types.append(param_type) + + if len(param_types) == 1: + return func[param_types[0]] + else: + return func[param_types] + def handle_match_buffer_type(self, node, buffer_name): """special function to handle syntax sugar for match buffer. diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index d450ce554cb2..757d18575189 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -44,6 +44,9 @@ def __init__(self, vtype): self.type = vtype def evaluate(self): + if isinstance(self.type, tvm.ir.Type): + return self.type + return tvm.ir.PrimType(self.type) @@ -54,6 +57,8 @@ class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method """ def __getitem__(self, vtype): + if not isinstance(vtype, TypeGeneric): + raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}") return ConcreteType(tvm.ir.PointerType(vtype.evaluate())) @@ -65,6 +70,8 @@ class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method """ def __getitem__(self, vtypes): + if isinstance(vtypes, TypeGeneric): + vtypes = [vtypes] return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes])) From 1b70fa278f3336967fac123dfe997fde52252ee0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 28 Jan 2022 11:44:30 -0600 Subject: [PATCH 3/5] [TVMScript] Updated buffer_var printing LetStmt and AllocateNode can both be used to generate handles that are used in Buffer objects. In these cases, the Buffer declarations must go after the handle declaration, not in the function header. --- src/printer/tvmscript_printer.cc | 126 ++++++++++++++++++++++++------- 1 file changed, 100 insertions(+), 26 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 023bb0d3ef00..eb513f503343 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -68,6 +68,58 @@ enum class ExprPrecedence : int { kUnknown = 7, }; +/*! \brief Utility used for identifying usage of a buffer_var + * + * \details Find the Buffer object that corresponds to a variable or + * allocation, based on the BufferLoad/BufferStore instances that + * occur within the allocation's body. + */ +class BufferUsageFinder : public StmtExprVisitor { + public: + static Map> FindUsage(Map> usage, Stmt body) { + BufferUsageFinder visitor(std::move(usage)); + visitor.VisitStmt(body); + return std::move(visitor.usage_); + } + + void VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + if (!usage_.count(var)) { + usage_.Set(var, {}); + } + } + + void VisitExpr_(const BufferLoadNode* op) final { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode* op) final { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + + private: + explicit BufferUsageFinder(Map> usage) : usage_(usage) {} + + void VisitBuffer(const Buffer& buffer) { + if (buffers_visited_.count(buffer.get())) { + return; + } + buffers_visited_.insert(buffer.get()); + + Array arr = usage_.Get(buffer->data).value_or({}); + arr.push_back(buffer); + usage_.Set(buffer->data, arr); + } + + // The search result. + Map> usage_; + // The buffers that have been visited so far, to avoid duplicate + // entries in the search result. + std::unordered_set buffers_visited_; +}; + /*! * \brief The printer for TVMScript * \details The printer obtain the precedence of the top-level operation when printing each @@ -138,6 +190,14 @@ class TVMScriptPrinter : public StmtFunctor, * 3. The iter range is equal to loop range */ std::vector> block_var_remaps_; + /*! + * \brief Map from variables to the buffers they are used in. + * + * Used for identifying buffers that should be declared after the + * LetStmt or Allocate that generates their data pointer, rather + * than in the header. + */ + Map> buffer_var_usage_; Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override; @@ -201,6 +261,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintRange(const RangeNode* op); Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); + Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body); Doc AllocBufferDeclaration(const Buffer& buf); Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); Doc PrintBlockVarRemaps(); @@ -830,11 +891,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { Doc doc; if (current_num_ != num_child_ - 1) { doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body) + << PrintBody(op->body)); } else { if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get()); doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value) - << Doc::NewLine() << PrintBody(op->body); + << Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body) + << PrintBody(op->body); } return doc; } @@ -923,33 +986,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { var_not_in_headers_.insert(op->buffer_var.get()); - Doc doc; + auto storage_scope = GetPtrStorageScope(op->buffer_var); + Doc func_call; + func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) + << ", " << Print(storage_scope); + if (!is_one(op->condition)) { + func_call << ", " << Print(op->condition); + } + if (!op->annotations.empty()) { + func_call << ", annotations={"; + func_call << PrintAnnotations(op->annotations); + func_call << "}"; + } + func_call << ")"; + + Doc doc; if (current_num_ != num_child_ - 1) { - doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " - << PrintDType(op->dtype) << ", " << Print(storage_scope); - if (!is_one(op->condition)) { - doc << ", " << Print(op->condition); - } - if (!op->annotations.empty()) { - doc << ", annotations={"; - doc << PrintAnnotations(op->annotations); - doc << "}"; - } - doc << ") as " << Print(op->buffer_var) << ":"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + doc << "with " << func_call << " as " << Print(op->buffer_var) << ":"; + doc << Doc::Indent(4, Doc::NewLine() + << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) + << PrintBody(op->body)); } else { - doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents) - << ", " << PrintDType(op->dtype) << ", " << Print(storage_scope); - if (!is_one(op->condition)) { - doc << ", " << Print(op->condition); - } - if (!op->annotations.empty()) { - doc << ", annotations={"; - doc << PrintAnnotations(op->annotations); - doc << "}"; - } - doc << ")" << Doc::NewLine() << PrintBody(op->body); + doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine() + << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body); } TryDeallocVar(op->buffer_var); return doc; @@ -1458,6 +1518,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); } +Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) { + if (!buffer_var_usage_.count(buffer_var)) { + buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), body); + } + Array buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({}); + Doc decls; + for (const auto& buf_usage : buffer_usage) { + decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl(" + << memo_buf_decl_[buf_usage] << ")" << Doc::NewLine(); + buf_not_in_headers_.insert(buf_usage.get()); + } + return decls; +} + Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) { Doc doc; if (op->region.size() == 0) { From a18667c749878ffb6fb7d68c991e99dba933db01 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 28 Jan 2022 21:09:14 -0600 Subject: [PATCH 4/5] Moved printing of var and buffer_decl into separate statements. --- src/printer/tvmscript_printer.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index eb513f503343..0d6c6e5deeba 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -896,8 +896,8 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { } else { if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get()); doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value) - << Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body) - << PrintBody(op->body); + << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(op->var, op->body) << PrintBody(op->body); } return doc; } @@ -1008,8 +1008,8 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body)); } else { - doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine() - << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body); + doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body); } TryDeallocVar(op->buffer_var); return doc; From 2454682baf56acf973114720fe2d4ee305a9ba61 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 3 Feb 2022 09:23:31 -0600 Subject: [PATCH 5/5] Updated following @shingjan's review comments. --- python/tvm/script/parser.py | 8 +++++-- python/tvm/script/tir/ty.py | 22 ++++++++++++++----- .../unittest/test_tvmscript_roundtrip.py | 14 +++++++++--- 3 files changed, 33 insertions(+), 11 deletions(-) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index b6dbeae4ff77..0668dc83df55 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -1148,8 +1148,12 @@ def transform_TypeApply(self, node): """ func = self.transform(node.func_name) - if not isinstance(func, ty.TypeGeneric): - self.report_error(f"Expected a type but found {type(func).__name__}", node.span) + if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"): + self.report_error( + f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), " + f"but found {type(func).__name__} instead.", + node.span, + ) param_types = [] for param in node.params: diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index 757d18575189..158302649a2c 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -38,16 +38,26 @@ def __call__(self): class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method - """TVM script typing class for uniform Type objects""" + """TVM script typing class for uniform Type objects + + Params + ------ + vtype: Union[str, tvm.ir.Type] + + The IR type represented by the type annotation. If a string + (e.g. "float32"), this represents a `ir.PrimType` generated + from that string. If a `ir.Type` is provided, this represents + the type provided. + """ def __init__(self, vtype): - self.type = vtype + if isinstance(vtype, tvm.ir.Type): + self.type = vtype + else: + self.type = tvm.ir.PrimType(vtype) def evaluate(self): - if isinstance(self.type, tvm.ir.Type): - return self.type - - return tvm.ir.PrimType(self.type) + return self.type class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index f1d62faab4e0..51a4ce7960a8 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3259,14 +3259,22 @@ def test_root_attr(): def func_T_ptr_let_statement( args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32 ) -> None: - # buffer definition + # The T.Ptr declaration in the parameter list should parse + # correctly, and should be usable as the data pointer in a buffer. arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle) - # body - assert num_args == 2, "main: num_args should be 3" + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + # Functions that return a "handle" can be assigned to a T.Ptr + # variable. A variable annotated with T.Ptr still has dtype of + # T.handle, but has type annotation as a pointer type. A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + + # The buffer declaration has a data pointer defined earlier in + # this function. It should only be defined after the data pointer + # has been defined, and should not be hoisted into the header of + # the function as other buffer_decl statements can be. A = T.buffer_decl([1024], dtype="float32", data=A_data) B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle") B = T.buffer_decl([1024], dtype="float32", data=B_data)