@@ -68,6 +68,57 @@ enum class ExprPrecedence : int {
6868 kUnknown = 7 ,
6969};
7070
71+ /* ! \brief Utility used for identifying usage of a buffer_var
72+ *
73+ * \details Find the Buffer object that corresponds to a variable or
74+ * allocation, based on the BufferLoad/BufferStore instances that
75+ * occur within the allocation's body.
76+ */
77+ class BufferUsageFinder : public StmtExprVisitor {
78+ public:
79+ static void FindUsage (Map<Var, Array<Buffer>>& usage, Stmt body) {
80+ BufferUsageFinder visitor (usage);
81+ visitor.VisitStmt (body);
82+ }
83+
84+ void VisitExpr_ (const VarNode* op) final {
85+ Var var = GetRef<Var>(op);
86+ if (!usage_.count (var)) {
87+ usage_.Set (var, {});
88+ }
89+ }
90+
91+ void VisitExpr_ (const BufferLoadNode* op) final {
92+ VisitBuffer (op->buffer );
93+ StmtExprVisitor::VisitExpr_ (op);
94+ }
95+
96+ void VisitStmt_ (const BufferStoreNode* op) final {
97+ VisitBuffer (op->buffer );
98+ StmtExprVisitor::VisitStmt_ (op);
99+ }
100+
101+ private:
102+ BufferUsageFinder (Map<Var, Array<Buffer>>& usage) : usage_(usage) {}
103+
104+ void VisitBuffer (const Buffer& buffer) {
105+ if (buffers_visited_.count (buffer.get ())) {
106+ return ;
107+ }
108+ buffers_visited_.insert (buffer.get ());
109+
110+ Array<Buffer> arr = usage_.Get (buffer->data ).value_or ({});
111+ arr.push_back (buffer);
112+ usage_.Set (buffer->data , arr);
113+ }
114+
115+ // The search result.
116+ Map<Var, Array<Buffer>>& usage_;
117+ // The buffers that have been visited so far, to avoid duplicate
118+ // entries in the search result.
119+ std::unordered_set<const BufferNode*> buffers_visited_;
120+ };
121+
71122/* !
72123 * \brief The printer for TVMScript
73124 * \details The printer obtain the precedence of the top-level operation when printing each
@@ -138,6 +189,14 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
138189 * 3. The iter range is equal to loop range
139190 */
140191 std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
192+ /* !
193+ * \brief Map from variables to the buffers they are used in.
194+ *
195+ * Used for identifying buffers that should be declared after the
196+ * LetStmt or Allocate that generates their data pointer, rather
197+ * than in the header.
198+ */
199+ Map<Var, Array<Buffer>> buffer_var_usage_;
141200
142201 Doc VisitExpr_ (const CastNode* op, ExprPrecedence* out_precedence) override ;
143202 Doc VisitExpr_ (const VarNode* op, ExprPrecedence* out_precedence) override ;
@@ -201,6 +260,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
201260 Doc PrintRange (const RangeNode* op);
202261 Doc PrintArray (const ArrayNode* op);
203262 Doc PrintBuffer (const BufferNode* op);
263+ Doc PrintNonHeaderBufferDeclarations (Var buffer_var, Stmt body);
204264 Doc AllocBufferDeclaration (const Buffer& buf);
205265 Doc PrintBlockVar (const IterVar& iter_var, const PrimExpr& value);
206266 Doc PrintBlockVarRemaps ();
@@ -830,11 +890,13 @@ Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
830890 Doc doc;
831891 if (current_num_ != num_child_ - 1 ) {
832892 doc << " with " << tir_prefix_ << " .let(" << Print (op->var ) << " , " << Print (op->value ) << " ):" ;
833- doc << Doc::Indent (4 , Doc::NewLine () << PrintBody (op->body ));
893+ doc << Doc::Indent (4 , Doc::NewLine () << PrintNonHeaderBufferDeclarations (op->var , op->body )
894+ << PrintBody (op->body ));
834895 } else {
835896 if (memo_var_.find (op->var ) == memo_var_.end ()) var_not_in_headers_.insert (op->var .get ());
836897 doc << Print (op->var ) << " : " << Print (GetType (op->var )) << " = " << Print (op->value )
837- << Doc::NewLine () << PrintBody (op->body );
898+ << Doc::NewLine () << PrintNonHeaderBufferDeclarations (op->var , op->body )
899+ << PrintBody (op->body );
838900 }
839901 return doc;
840902}
@@ -923,33 +985,30 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
923985
924986Doc TVMScriptPrinter::VisitStmt_ (const AllocateNode* op) {
925987 var_not_in_headers_.insert (op->buffer_var .get ());
926- Doc doc;
988+
927989 auto storage_scope = GetPtrStorageScope (op->buffer_var );
990+ Doc func_call;
991+ func_call << tir_prefix_ << " .allocate(" << Print (op->extents ) << " , " << PrintDType (op->dtype )
992+ << " , " << Print (storage_scope);
993+ if (!is_one (op->condition )) {
994+ func_call << " , " << Print (op->condition );
995+ }
996+ if (!op->annotations .empty ()) {
997+ func_call << " , annotations={" ;
998+ func_call << PrintAnnotations (op->annotations );
999+ func_call << " }" ;
1000+ }
1001+ func_call << " )" ;
1002+
1003+ Doc doc;
9281004 if (current_num_ != num_child_ - 1 ) {
929- doc << " with " << tir_prefix_ << " .allocate(" << Print (op->extents ) << " , "
930- << PrintDType (op->dtype ) << " , " << Print (storage_scope);
931- if (!is_one (op->condition )) {
932- doc << " , " << Print (op->condition );
933- }
934- if (!op->annotations .empty ()) {
935- doc << " , annotations={" ;
936- doc << PrintAnnotations (op->annotations );
937- doc << " }" ;
938- }
939- doc << " ) as " << Print (op->buffer_var ) << " :" ;
940- doc << Doc::Indent (4 , Doc::NewLine () << PrintBody (op->body ));
1005+ doc << " with " << func_call << " as " << Print (op->buffer_var ) << " :" ;
1006+ doc << Doc::Indent (4 , Doc::NewLine ()
1007+ << PrintNonHeaderBufferDeclarations (op->buffer_var , op->body )
1008+ << PrintBody (op->body ));
9411009 } else {
942- doc << Print (op->buffer_var ) << " = " << tir_prefix_ << " .allocate(" << Print (op->extents )
943- << " , " << PrintDType (op->dtype ) << " , " << Print (storage_scope);
944- if (!is_one (op->condition )) {
945- doc << " , " << Print (op->condition );
946- }
947- if (!op->annotations .empty ()) {
948- doc << " , annotations={" ;
949- doc << PrintAnnotations (op->annotations );
950- doc << " }" ;
951- }
952- doc << " )" << Doc::NewLine () << PrintBody (op->body );
1010+ doc << Print (op->buffer_var ) << " = " << func_call << Doc::NewLine ()
1011+ << PrintNonHeaderBufferDeclarations (op->buffer_var , op->body ) << PrintBody (op->body );
9531012 }
9541013 TryDeallocVar (op->buffer_var );
9551014 return doc;
@@ -1458,6 +1517,20 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
14581517 return meta_.InMeta (buffer) ? meta_.GetMetaNode (buffer) : AllocBuf (buffer);
14591518}
14601519
1520+ Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations (Var buffer_var, Stmt body) {
1521+ if (!buffer_var_usage_.count (buffer_var)) {
1522+ BufferUsageFinder::FindUsage (buffer_var_usage_, body);
1523+ }
1524+ Array<Buffer> buffer_usage = buffer_var_usage_.Get (buffer_var).value_or ({});
1525+ Doc decls;
1526+ for (const auto & buf_usage : buffer_usage) {
1527+ decls << Print (buf_usage) << " = " << tir_prefix_ << " .buffer_decl("
1528+ << memo_buf_decl_[buf_usage] << " )" << Doc::NewLine ();
1529+ buf_not_in_headers_.insert (buf_usage.get ());
1530+ }
1531+ return decls;
1532+ }
1533+
14611534Doc TVMScriptPrinter::PrintBufferRegion (const BufferRegionNode* op) {
14621535 Doc doc;
14631536 if (op->region .size () == 0 ) {
0 commit comments