diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 917a16d478c0..0668dc83df55 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -47,6 +47,7 @@ from .tir.node import Slice, BufferSlice from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler from .tir.special_stmt import SpecialStmt +from .tir import ty class CallArgumentReader(object): @@ -447,7 +448,9 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: # add parameters of function for arg in node.params: # Note that this case is for T.match_buffer syntax sugar - if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)): + if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance( + self.transform(arg.ty.func_name), ty.GenericBufferType + ): result = self.handle_match_buffer_type(arg.ty, arg.name) if not isinstance(result, buffer.Buffer): self.report_error( @@ -1138,6 +1141,33 @@ def transform_TypeTuple(self, node): """ return [self.transform(value) for value in node.values] + def transform_TypeApply(self, node): + """Visitor for Type[Type] expressions. + + Mostly used for ``T.Ptr`` expressions. + """ + func = self.transform(node.func_name) + + if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"): + self.report_error( + f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), " + f"but found {type(func).__name__} instead.", + node.span, + ) + + param_types = [] + for param in node.params: + param_type = self.transform(param) + if not isinstance(param_type, ty.TypeGeneric): + self.report_error(f"Expected a type but found {type(param).__name__}", param.span) + + param_types.append(param_type) + + if len(param_types) == 1: + return func[param_types[0]] + else: + return func[param_types] + def handle_match_buffer_type(self, node, buffer_name): """special function to handle syntax sugar for match buffer. diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index d450ce554cb2..158302649a2c 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -38,13 +38,26 @@ def __call__(self): class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method - """TVM script typing class for uniform Type objects""" + """TVM script typing class for uniform Type objects + + Params + ------ + vtype: Union[str, tvm.ir.Type] + + The IR type represented by the type annotation. If a string + (e.g. "float32"), this represents a `ir.PrimType` generated + from that string. If a `ir.Type` is provided, this represents + the type provided. + """ def __init__(self, vtype): - self.type = vtype + if isinstance(vtype, tvm.ir.Type): + self.type = vtype + else: + self.type = tvm.ir.PrimType(vtype) def evaluate(self): - return tvm.ir.PrimType(self.type) + return self.type class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method @@ -54,6 +67,8 @@ class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method """ def __getitem__(self, vtype): + if not isinstance(vtype, TypeGeneric): + raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}") return ConcreteType(tvm.ir.PointerType(vtype.evaluate())) @@ -65,6 +80,8 @@ class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method """ def __getitem__(self, vtypes): + if isinstance(vtypes, TypeGeneric): + vtypes = [vtypes] return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes])) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 023bb0d3ef00..0d6c6e5deeba 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -68,6 +68,58 @@ enum class ExprPrecedence : int { kUnknown = 7, }; +/*! \brief Utility used for identifying usage of a buffer_var + * + * \details Find the Buffer object that corresponds to a variable or + * allocation, based on the BufferLoad/BufferStore instances that + * occur within the allocation's body. + */ +class BufferUsageFinder : public StmtExprVisitor { + public: + static Map> FindUsage(Map> usage, Stmt body) { + BufferUsageFinder visitor(std::move(usage)); + visitor.VisitStmt(body); + return std::move(visitor.usage_); + } + + void VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + if (!usage_.count(var)) { + usage_.Set(var, {}); + } + } + + void VisitExpr_(const BufferLoadNode* op) final { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitExpr_(op); + } + + void VisitStmt_(const BufferStoreNode* op) final { + VisitBuffer(op->buffer); + StmtExprVisitor::VisitStmt_(op); + } + + private: + explicit BufferUsageFinder(Map> usage) : usage_(usage) {} + + void VisitBuffer(const Buffer& buffer) { + if (buffers_visited_.count(buffer.get())) { + return; + } + buffers_visited_.insert(buffer.get()); + + Array arr = usage_.Get(buffer->data).value_or({}); + arr.push_back(buffer); + usage_.Set(buffer->data, arr); + } + + // The search result. + Map> usage_; + // The buffers that have been visited so far, to avoid duplicate + // entries in the search result. + std::unordered_set buffers_visited_; +}; + /*! * \brief The printer for TVMScript * \details The printer obtain the precedence of the top-level operation when printing each @@ -138,6 +190,14 @@ class TVMScriptPrinter : public StmtFunctor, * 3. The iter range is equal to loop range */ std::vector> block_var_remaps_; + /*! + * \brief Map from variables to the buffers they are used in. + * + * Used for identifying buffers that should be declared after the + * LetStmt or Allocate that generates their data pointer, rather + * than in the header. + */ + Map> buffer_var_usage_; Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override; @@ -201,6 +261,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintRange(const RangeNode* op); Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); + Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body); Doc AllocBufferDeclaration(const Buffer& buf); Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); Doc PrintBlockVarRemaps(); @@ -830,11 +891,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { Doc doc; if (current_num_ != num_child_ - 1) { doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body) + << PrintBody(op->body)); } else { if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get()); doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value) - << Doc::NewLine() << PrintBody(op->body); + << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(op->var, op->body) << PrintBody(op->body); } return doc; } @@ -923,33 +986,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { var_not_in_headers_.insert(op->buffer_var.get()); - Doc doc; + auto storage_scope = GetPtrStorageScope(op->buffer_var); + Doc func_call; + func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) + << ", " << Print(storage_scope); + if (!is_one(op->condition)) { + func_call << ", " << Print(op->condition); + } + if (!op->annotations.empty()) { + func_call << ", annotations={"; + func_call << PrintAnnotations(op->annotations); + func_call << "}"; + } + func_call << ")"; + + Doc doc; if (current_num_ != num_child_ - 1) { - doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " - << PrintDType(op->dtype) << ", " << Print(storage_scope); - if (!is_one(op->condition)) { - doc << ", " << Print(op->condition); - } - if (!op->annotations.empty()) { - doc << ", annotations={"; - doc << PrintAnnotations(op->annotations); - doc << "}"; - } - doc << ") as " << Print(op->buffer_var) << ":"; - doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + doc << "with " << func_call << " as " << Print(op->buffer_var) << ":"; + doc << Doc::Indent(4, Doc::NewLine() + << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) + << PrintBody(op->body)); } else { - doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents) - << ", " << PrintDType(op->dtype) << ", " << Print(storage_scope); - if (!is_one(op->condition)) { - doc << ", " << Print(op->condition); - } - if (!op->annotations.empty()) { - doc << ", annotations={"; - doc << PrintAnnotations(op->annotations); - doc << "}"; - } - doc << ")" << Doc::NewLine() << PrintBody(op->body); + doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine(); + doc << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body); } TryDeallocVar(op->buffer_var); return doc; @@ -1458,6 +1518,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); } +Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) { + if (!buffer_var_usage_.count(buffer_var)) { + buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), body); + } + Array buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({}); + Doc decls; + for (const auto& buf_usage : buffer_usage) { + decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl(" + << memo_buf_decl_[buf_usage] << ")" << Doc::NewLine(); + buf_not_in_headers_.insert(buf_usage.get()); + } + return decls; +} + Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) { Doc doc; if (op->region.size() == 0) { diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index bf1235a4dc42..51a4ce7960a8 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3255,5 +3255,52 @@ def test_root_attr(): tvm.ir.assert_structural_equal(func, rt_func, True) +@T.prim_func +def func_T_ptr_let_statement( + args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32 +) -> None: + # The T.Ptr declaration in the parameter list should parse + # correctly, and should be usable as the data pointer in a buffer. + arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle) + + arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle") + arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle") + + # Functions that return a "handle" can be assigned to a T.Ptr + # variable. A variable annotated with T.Ptr still has dtype of + # T.handle, but has type annotation as a pointer type. + A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle") + + # The buffer declaration has a data pointer defined earlier in + # this function. It should only be defined after the data pointer + # has been defined, and should not be hoisted into the header of + # the function as other buffer_decl statements can be. + A = T.buffer_decl([1024], dtype="float32", data=A_data) + B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle") + B = T.buffer_decl([1024], dtype="float32", data=B_data) + + B[0] = A[0] + + +def test_T_ptr_let_statement(): + func = func_T_ptr_let_statement + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +@T.prim_func +def func_T_ptr_allocate() -> None: + A_data: T.Ptr[T.float32] = T.allocate([1024], "float32", "global") + A = T.buffer_decl([1024], dtype="float32", data=A_data) + + A[0] = 0.0 + + +def test_T_ptr_allocate(): + func = func_T_ptr_allocate + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))