Skip to content

Commit f9041dc

Browse files
committed
[TVMScript] Updated buffer_var printing
LetStmt and AllocateNode can both be used to generate handles that are used in Buffer objects. In these cases, the Buffer declarations must go after the handle declaration, not in the function header.
1 parent 66dac85 commit f9041dc

File tree

1 file changed

+99
-26
lines changed

1 file changed

+99
-26
lines changed

src/printer/tvmscript_printer.cc

Lines changed: 99 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,57 @@ enum class ExprPrecedence : int {
6868
kUnknown = 7,
6969
};
7070

71+
/*! \brief Utility used for identifying usage of a buffer_var
72+
*
73+
* \details Find the Buffer object that corresponds to a variable or
74+
* allocation, based on the BufferLoad/BufferStore instances that
75+
* occur within the allocation's body.
76+
*/
77+
class BufferUsageFinder : public StmtExprVisitor {
78+
public:
79+
static void FindUsage(Map<Var, Array<Buffer>>& usage, Stmt body) {
80+
BufferUsageFinder visitor(usage);
81+
visitor.VisitStmt(body);
82+
}
83+
84+
void VisitExpr_(const VarNode* op) final {
85+
Var var = GetRef<Var>(op);
86+
if (!usage_.count(var)) {
87+
usage_.Set(var, {});
88+
}
89+
}
90+
91+
void VisitExpr_(const BufferLoadNode* op) final {
92+
VisitBuffer(op->buffer);
93+
StmtExprVisitor::VisitExpr_(op);
94+
}
95+
96+
void VisitStmt_(const BufferStoreNode* op) final {
97+
VisitBuffer(op->buffer);
98+
StmtExprVisitor::VisitStmt_(op);
99+
}
100+
101+
private:
102+
BufferUsageFinder(Map<Var, Array<Buffer>>& usage) : usage_(usage) {}
103+
104+
void VisitBuffer(const Buffer& buffer) {
105+
if (buffers_visited_.count(buffer.get())) {
106+
return;
107+
}
108+
buffers_visited_.insert(buffer.get());
109+
110+
Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
111+
arr.push_back(buffer);
112+
usage_.Set(buffer->data, arr);
113+
}
114+
115+
// The search result.
116+
Map<Var, Array<Buffer>>& usage_;
117+
// The buffers that have been visited so far, to avoid duplicate
118+
// entries in the search result.
119+
std::unordered_set<const BufferNode*> buffers_visited_;
120+
};
121+
71122
/*!
72123
* \brief The printer for TVMScript
73124
* \details The printer obtain the precedence of the top-level operation when printing each
@@ -138,6 +189,14 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
138189
* 3. The iter range is equal to loop range
139190
*/
140191
std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
192+
/*!
193+
* \brief Map from variables to the buffers they are used in.
194+
*
195+
* Used for identifying buffers that should be declared after the
196+
* LetStmt or Allocate that generates their data pointer, rather
197+
* than in the header.
198+
*/
199+
Map<Var, Array<Buffer>> buffer_var_usage_;
141200

142201
Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
143202
Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
@@ -201,6 +260,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
201260
Doc PrintRange(const RangeNode* op);
202261
Doc PrintArray(const ArrayNode* op);
203262
Doc PrintBuffer(const BufferNode* op);
263+
Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body);
204264
Doc AllocBufferDeclaration(const Buffer& buf);
205265
Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
206266
Doc PrintBlockVarRemaps();
@@ -830,11 +890,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
830890
Doc doc;
831891
if (current_num_ != num_child_ - 1) {
832892
doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):";
833-
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
893+
doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
894+
<< PrintBody(op->body));
834895
} else {
835896
if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get());
836897
doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
837-
<< Doc::NewLine() << PrintBody(op->body);
898+
<< Doc::NewLine() << PrintNonHeaderBufferDeclarations(op->var, op->body)
899+
<< PrintBody(op->body);
838900
}
839901
return doc;
840902
}
@@ -923,33 +985,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
923985

924986
Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
925987
var_not_in_headers_.insert(op->buffer_var.get());
926-
Doc doc;
988+
927989
auto storage_scope = GetPtrStorageScope(op->buffer_var);
990+
Doc func_call;
991+
func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype)
992+
<< ", " << Print(storage_scope);
993+
if (!is_one(op->condition)) {
994+
func_call << ", " << Print(op->condition);
995+
}
996+
if (!op->annotations.empty()) {
997+
func_call << ", annotations={";
998+
func_call << PrintAnnotations(op->annotations);
999+
func_call << "}";
1000+
}
1001+
func_call << ")";
1002+
1003+
Doc doc;
9281004
if (current_num_ != num_child_ - 1) {
929-
doc << "with " << tir_prefix_ << ".allocate(" << Print(op->extents) << ", "
930-
<< PrintDType(op->dtype) << ", " << Print(storage_scope);
931-
if (!is_one(op->condition)) {
932-
doc << ", " << Print(op->condition);
933-
}
934-
if (!op->annotations.empty()) {
935-
doc << ", annotations={";
936-
doc << PrintAnnotations(op->annotations);
937-
doc << "}";
938-
}
939-
doc << ") as " << Print(op->buffer_var) << ":";
940-
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1005+
doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
1006+
doc << Doc::Indent(4, Doc::NewLine()
1007+
<< PrintNonHeaderBufferDeclarations(op->buffer_var, op->body)
1008+
<< PrintBody(op->body));
9411009
} else {
942-
doc << Print(op->buffer_var) << " = " << tir_prefix_ << ".allocate(" << Print(op->extents)
943-
<< ", " << PrintDType(op->dtype) << ", " << Print(storage_scope);
944-
if (!is_one(op->condition)) {
945-
doc << ", " << Print(op->condition);
946-
}
947-
if (!op->annotations.empty()) {
948-
doc << ", annotations={";
949-
doc << PrintAnnotations(op->annotations);
950-
doc << "}";
951-
}
952-
doc << ")" << Doc::NewLine() << PrintBody(op->body);
1010+
doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine()
1011+
<< PrintNonHeaderBufferDeclarations(op->buffer_var, op->body) << PrintBody(op->body);
9531012
}
9541013
TryDeallocVar(op->buffer_var);
9551014
return doc;
@@ -1458,6 +1517,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
14581517
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
14591518
}
14601519

1520+
Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body) {
1521+
if (!buffer_var_usage_.count(buffer_var)) {
1522+
BufferUsageFinder::FindUsage(buffer_var_usage_, body);
1523+
}
1524+
Array<Buffer> buffer_usage = buffer_var_usage_.Get(buffer_var).value_or({});
1525+
Doc decls;
1526+
for (const auto& buf_usage : buffer_usage) {
1527+
decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
1528+
<< memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
1529+
buf_not_in_headers_.insert(buf_usage.get());
1530+
}
1531+
return decls;
1532+
}
1533+
14611534
Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
14621535
Doc doc;
14631536
if (op->region.size() == 0) {

0 commit comments

Comments
 (0)