Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from .tir.node import Slice, BufferSlice
from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
from .tir.special_stmt import SpecialStmt
from .tir import ty


class CallArgumentReader(object):
Expand Down Expand Up @@ -447,7 +448,9 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
# add parameters of function
for arg in node.params:
# Note that this case is for T.match_buffer syntax sugar
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance(
self.transform(arg.ty.func_name), ty.GenericBufferType
):
result = self.handle_match_buffer_type(arg.ty, arg.name)
if not isinstance(result, buffer.Buffer):
self.report_error(
Expand Down Expand Up @@ -1138,6 +1141,29 @@ def transform_TypeTuple(self, node):
"""
return [self.transform(value) for value in node.values]

def transform_TypeApply(self, node):
"""Visitor for Type[Type] expressions.

Mostly used for ``T.Ptr`` expressions.
"""
func = self.transform(node.func_name)

if not isinstance(func, ty.TypeGeneric):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use of GenericPtrType here maybe more accurate

Copy link
Copy Markdown
Contributor Author

@Lunderberg Lunderberg Feb 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understanding, the TypeApply ast node would also be used for T.Tuple[type1,type2] type annotations. While I couldn't find any usage of T.Tuple type annotations, GenericTupleType exists in tvm.script.tir.ty and would also use TypeApply in the synr ast. I've added a check for hasattr(func, __getitem__), since that expresses the intent that this is a type that can accept type arguments, and would work for both T.Ptr and T.Tuple.

self.report_error(f"Expected a type but found {type(func).__name__}", node.span)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expect a type -> Expect a GenericPtrType maybe be better

Copy link
Copy Markdown
Contributor Author

@Lunderberg Lunderberg Feb 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above, I'd like to avoid breaking any usage of T.Tuple, if it exists outside of the implementation above. I've updated the error message to read "Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), but found {type(func).__name__} instead." to clarify that the error would be in applying a type argument, not in the use of a type annotation.


param_types = []
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a transform_TypeTuple impl that may be useful here

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had looked at the transform_TypeTuple implementation, and I don't think it's directly applicable. It assumes that there is a synr.ast.TypeTuple node, with types in node.values, and doesn't have a way to accept a list of types directly or to access the node.params of a TypeApply node. The TVMScriptParser.parse_arg_list is the closest I found to the desired functionality, but it requires the node to be an intrinsic, scope handler, or special statement, and doesn't have a case for type annotations.

for param in node.params:
param_type = self.transform(param)
if not isinstance(param_type, ty.TypeGeneric):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this for-loop imply that all params of TypeApply call will be transformed into GenericPtrType? If so can we specify that here as well instead of ty.TypeGeneric

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't, no. The parameters of TypeApply typically aren't GenericPtrType. For example, T.Ptr[T.int32], the T.int32 parameter is a ConcreteType.

At some point, I may see if there's support for renaming ty.TypeGeneric to ty.Type. As it is, ty.ConcreteType is a subclass of ty.TypeGeneric, which doesn't make very much sense to me as there since it require any generic parameters in the user-supplied tvmscript.

self.report_error(f"Expected a type but found {type(param).__name__}", param.span)

param_types.append(param_type)

if len(param_types) == 1:
return func[param_types[0]]
else:
return func[param_types]

def handle_match_buffer_type(self, node, buffer_name):
"""special function to handle syntax sugar for match buffer.

Expand Down
7 changes: 7 additions & 0 deletions python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def __init__(self, vtype):
self.type = vtype

def evaluate(self):
if isinstance(self.type, tvm.ir.Type):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this change only impact GenericPtrType instead of ConcreteType here. Is that the case?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check gets hit when calling .evaluate() on the return type of GenericPtrType.__getitem__. The return type is a ConcreteType that has been passed a tvm.ir.PointerType, which should then be returned directly when the concrete type is evaluated, rather than the usual behavior of wrapping it in a tvm.ir.PrimType.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After thinking on it, I think that it would be cleaner if the type conversion happens during the call to __init__, rather than being delayed until the .evaluate() call. That way, the self.type object has a consistent type, regardless of whether it was initialized with a string to represent a primitive, or was initialized with a tvm.ir.Type to represent that type.

return self.type

return tvm.ir.PrimType(self.type)


Expand All @@ -54,6 +57,8 @@ class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method
"""

def __getitem__(self, vtype):
if not isinstance(vtype, TypeGeneric):
raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}")
return ConcreteType(tvm.ir.PointerType(vtype.evaluate()))


Expand All @@ -65,6 +70,8 @@ class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method
"""

def __getitem__(self, vtypes):
if isinstance(vtypes, TypeGeneric):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be a bug fix right?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's correct. If transform_TypeApply was to be able to handle both T.Ptr and T.Tuple cases, then I wanted to make sure that they both accepted the same types of arguments. The two options I considered were (a) always passing a tuple of parameter types even if there is only 1, or (b) passing a bare type when there is only 1 and otherwise passing a tuple of parameter types. Previously, T.Ptr implicitly followed convention (b), while T.Tuple followed convention (a). Option (b) matches python's subscripting syntax, so that's the one that I chose.

vtypes = [vtypes]
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))


Expand Down
126 changes: 100 additions & 26 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,58 @@ enum class ExprPrecedence : int {
kUnknown = 7,
};

/*! \brief Utility used for identifying usage of a buffer_var
*
* \details Find the Buffer object that corresponds to a variable or
* allocation, based on the BufferLoad/BufferStore instances that
* occur within the allocation's body.
*/
class BufferUsageFinder : public StmtExprVisitor {
public:
static Map<Var, Array<Buffer>> FindUsage(Map<Var, Array<Buffer>> usage, Stmt body) {
BufferUsageFinder visitor(std::move(usage));
visitor.VisitStmt(body);
return std::move(visitor.usage_);
}

void VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
if (!usage_.count(var)) {
usage_.Set(var, {});
}
}

void VisitExpr_(const BufferLoadNode* op) final {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BufferStoreNode* op) final {
VisitBuffer(op->buffer);
StmtExprVisitor::VisitStmt_(op);
}

private:
explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}

void VisitBuffer(const Buffer& buffer) {
if (buffers_visited_.count(buffer.get())) {
return;
}
buffers_visited_.insert(buffer.get());

Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
arr.push_back(buffer);
usage_.Set(buffer->data, arr);
}

// The search result.
Map<Var, Array<Buffer>> usage_;
// The buffers that have been visited so far, to avoid duplicate
// entries in the search result.
std::unordered_set<const BufferNode*> buffers_visited_;
};

/*!
* \brief The printer for TVMScript
* \details The printer obtain the precedence of the top-level operation when printing each
Expand Down Expand Up @@ -138,6 +190,14 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
* 3. The iter range is equal to loop range
*/
std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
/*!
* \brief Map from variables to the buffers they are used in.
*
* Used for identifying buffers that should be declared after the
* LetStmt or Allocate that generates their data pointer, rather
* than in the header.
*/
Map<Var, Array<Buffer>> buffer_var_usage_;

Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
Expand Down Expand Up @@ -201,6 +261,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintRange(const RangeNode* op);
Doc PrintArray(const ArrayNode* op);
Doc PrintBuffer(const BufferNode* op);
Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body);
Doc AllocBufferDeclaration(const Buffer& buf);
Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
Doc PrintBlockVarRemaps();
Expand Down Expand Up @@ -830,11 +891,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
<< PrintBody(op->body));
} else {
if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get());
doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
<< Doc::NewLine() << PrintBody(op->body);
<< Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(op->var, op->body) << PrintBody(op->body);
}
return doc;
}
Expand Down Expand Up @@ -923,33 +986,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {

Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
var_not_in_headers_.insert(op->buffer_var.get());
Doc doc;

auto storage_scope = GetPtrStorageScope(op->buffer_var);
Doc func_call;
func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype)
<< ", " << Print(storage_scope);
if (!is_one(op->condition)) {
func_call << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
func_call << ", annotations={";
func_call << PrintAnnotations(op->annotations);
func_call << "}";
}
func_call << ")";

Doc doc;
if (current_num_ != num_child_ - 1) {
doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", "
<< PrintDType(op->dtype) << ", " << Print(storage_scope);
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
doc << ", annotations={";
doc << PrintAnnotations(op->annotations);
doc << "}";
}
doc << ") as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
doc << Doc::Indent(4, Doc::NewLine()
<< PrintNonHeaderBufferDeclarations(op->buffer_var, op->body)
<< PrintBody(op->body));
} else {
doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents)
<< ", " << PrintDType(op->dtype) << ", " << Print(storage_scope);
if (!is_one(op->condition)) {
doc << ", " << Print(op->condition);
}
if (!op->annotations.empty()) {
doc << ", annotations={";
doc << PrintAnnotations(op->annotations);
doc << "}";
}
doc << ")" << Doc::NewLine() << PrintBody(op->body);
doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine();
doc << PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body);
}
TryDeallocVar(op->buffer_var);
return doc;
Expand Down Expand Up @@ -1458,6 +1518,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
}

Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) {
if (!buffer_var_usage_.count(buffer_var)) {
buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), body);
}
Array<Buffer> buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({});
Doc decls;
for (const auto& buf_usage : buffer_usage) {
decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
<< memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
buf_not_in_headers_.insert(buf_usage.get());
}
return decls;
}

Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
Doc doc;
if (op->region.size() == 0) {
Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3255,5 +3255,44 @@ def test_root_attr():
tvm.ir.assert_structural_equal(func, rt_func, True)


@T.prim_func
def func_T_ptr_let_statement(
args: T.handle, arg_type_ids_handle: T.Ptr[T.int32], num_args: T.int32
) -> None:
# buffer definition
arg_type_ids = T.buffer_decl([2], dtype="int32", data=arg_type_ids_handle)
# body
assert num_args == 2, "main: num_args should be 3"
Comment thread
Lunderberg marked this conversation as resolved.
Outdated
arg0: T.handle = T.tvm_struct_get(args, 0, 12, dtype="handle")
arg1: T.handle = T.tvm_struct_get(args, 1, 12, dtype="handle")

A_data: T.Ptr[T.float32] = T.tvm_struct_get(arg0, 0, 1, dtype="handle")
A = T.buffer_decl([1024], dtype="float32", data=A_data)
B_data: T.Ptr[T.float32] = T.tvm_struct_get(arg1, 0, 1, dtype="handle")
B = T.buffer_decl([1024], dtype="float32", data=B_data)

B[0] = A[0]


def test_T_ptr_let_statement():
func = func_T_ptr_let_statement
rt_func = tvm.script.from_source(func.script(show_meta=True))
tvm.ir.assert_structural_equal(func, rt_func, True)


@T.prim_func
def func_T_ptr_allocate() -> None:
A_data: T.Ptr[T.float32] = T.allocate([1024], "float32", "global")
A = T.buffer_decl([1024], dtype="float32", data=A_data)

A[0] = 0.0


def test_T_ptr_allocate():
func = func_T_ptr_allocate
rt_func = tvm.script.from_source(func.script(show_meta=True))
tvm.ir.assert_structural_equal(func, rt_func, True)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))