Skip to content

Commit 3e8a78e

Browse files
committed
[TIR][USMP] Added buffer info extraction pass
* Change the class data members to have a trailing underscore Change-Id: I71809b3c73b0bc0cd133fad1392ae8c17c895ee4
1 parent 01111bd commit 3e8a78e

File tree

1 file changed

+30
-30
lines changed

1 file changed

+30
-30
lines changed

src/tir/usmp/analysis/extract_buffer_info.cc

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ class BufferInfoExtractor : public StmtExprVisitor {
3838
public:
3939
explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
4040
for (const auto& gv_func : module_->functions) {
41-
functions.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
41+
functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
4242
}
4343
// Pushing a scope info for the initial body of the main function
44-
scope_stack.push(ScopeInfo());
44+
scope_stack_.push(ScopeInfo());
4545
}
4646
Map<BufferInfo, tir::Stmt> operator()(const PrimFunc& func);
4747

@@ -56,24 +56,24 @@ class BufferInfoExtractor : public StmtExprVisitor {
5656

5757
void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
5858

59-
Map<BufferInfo, tir::Stmt> buffer_info_map;
60-
Map<tir::Stmt, Integer> buffer_info_start_stmt_idx;
61-
Map<tir::Stmt, Integer> buffer_info_end_stmt_idx;
62-
Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map;
59+
Map<BufferInfo, tir::Stmt> buffer_info_map_;
60+
Map<tir::Stmt, Integer> buffer_info_start_stmt_idx_;
61+
Map<tir::Stmt, Integer> buffer_info_end_stmt_idx_;
62+
Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map_;
6363

6464
std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> currently_live_allocates;
65-
int current_stmt_idx = 0;
65+
int current_stmt_idx_ = 0;
6666
struct ScopeInfo {
6767
For for_loop;
6868
};
69-
std::stack<ScopeInfo> scope_stack;
69+
std::stack<ScopeInfo> scope_stack_;
7070

71-
Map<String, PrimFunc> functions;
71+
Map<String, PrimFunc> functions_;
7272
IRModule module_;
7373
};
7474

7575
void BufferInfoExtractor::VisitStmt(const Stmt& n) {
76-
current_stmt_idx += 1;
76+
current_stmt_idx_ += 1;
7777
StmtExprVisitor::VisitStmt(n);
7878
}
7979

@@ -92,7 +92,7 @@ static size_t CalculateExtentsSize(const AllocateNode* op) {
9292
}
9393

9494
void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
95-
const auto& currect_scope_info = scope_stack.top();
95+
const auto& currect_scope_info = scope_stack_.top();
9696
const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
9797
const auto& storage_scope = type->storage_scope;
9898

@@ -116,8 +116,8 @@ void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
116116
"un-restricted pool is assigned";
117117
auto buffer_info = BufferInfo(op->buffer_var->name_hint, size_bytes, pool_candidates);
118118
auto allocate = GetRef<Allocate>(op);
119-
allocate_var_to_stmt_map.Set(op->buffer_var, allocate);
120-
buffer_info_map.Set(buffer_info, allocate);
119+
allocate_var_to_stmt_map_.Set(op->buffer_var, allocate);
120+
buffer_info_map_.Set(buffer_info, allocate);
121121
}
122122
}
123123
StmtExprVisitor::VisitStmt(op->body);
@@ -127,9 +127,9 @@ void BufferInfoExtractor::VisitStmt_(const ForNode* op) {
127127
ScopeInfo si{
128128
GetRef<For>(op),
129129
};
130-
scope_stack.push(si);
130+
scope_stack_.push(si);
131131
StmtExprVisitor::VisitStmt_(op);
132-
scope_stack.pop();
132+
scope_stack_.pop();
133133
}
134134

135135
void BufferInfoExtractor::VisitExpr_(const LoadNode* op) {
@@ -144,12 +144,12 @@ void BufferInfoExtractor::VisitStmt_(const StoreNode* op) {
144144

145145
void BufferInfoExtractor::VisitExpr_(const VarNode* op) {
146146
auto var = GetRef<Var>(op);
147-
if (allocate_var_to_stmt_map.count(var)) {
148-
auto allocate = allocate_var_to_stmt_map[var];
149-
if (buffer_info_start_stmt_idx.count(allocate) == 0) {
150-
buffer_info_start_stmt_idx.Set(allocate, current_stmt_idx);
147+
if (allocate_var_to_stmt_map_.count(var)) {
148+
auto allocate = allocate_var_to_stmt_map_[var];
149+
if (buffer_info_start_stmt_idx_.count(allocate) == 0) {
150+
buffer_info_start_stmt_idx_.Set(allocate, current_stmt_idx_);
151151
}
152-
buffer_info_end_stmt_idx.Set(allocate, current_stmt_idx);
152+
buffer_info_end_stmt_idx_.Set(allocate, current_stmt_idx_);
153153
}
154154
StmtExprVisitor::VisitExpr_(op);
155155
}
@@ -173,21 +173,21 @@ void BufferInfoExtractor::UpdateAliases(const Array<PrimExpr>& args, const PrimF
173173
// to the original allocate
174174
if (arg->IsInstance<LoadNode>()) {
175175
auto load = Downcast<Load>(arg);
176-
if (allocate_var_to_stmt_map.count(load->buffer_var)) {
177-
allocate_var_to_stmt_map.Set(param_buf, allocate_var_to_stmt_map[load->buffer_var]);
176+
if (allocate_var_to_stmt_map_.count(load->buffer_var)) {
177+
allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[load->buffer_var]);
178178
}
179179
} else if (arg->IsInstance<VarNode>()) {
180180
auto var = Downcast<Var>(arg);
181-
if (allocate_var_to_stmt_map.count(var)) {
182-
allocate_var_to_stmt_map.Set(param_buf, allocate_var_to_stmt_map[var]);
181+
if (allocate_var_to_stmt_map_.count(var)) {
182+
allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[var]);
183183
}
184184
}
185185
}
186186
}
187187

188188
void BufferInfoExtractor::VisitExpr_(const CallNode* op) {
189189
if (op->op.same_as(builtin::call_extern())) {
190-
auto func = functions.at(Downcast<StringImm>(op->args[0])->value);
190+
auto func = functions_.at(Downcast<StringImm>(op->args[0])->value);
191191
auto actual_args = Array<PrimExpr>(op->args.begin() + 1, op->args.end());
192192
this->UpdateAliases(actual_args, func);
193193
this->VisitStmt(func->body);
@@ -217,26 +217,26 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
217217
};
218218

219219
std::vector<LivenessEvent> le_events;
220-
for (const auto& kv : buffer_info_map) {
220+
for (const auto& kv : buffer_info_map_) {
221221
if (!kv.second->IsInstance<AllocateNode>()) {
222222
continue;
223223
}
224224
auto allocate = Downcast<Allocate>(kv.second);
225225
auto buffer_info = Downcast<BufferInfo>(kv.first);
226226
// If the allocate is not used; we remove it from the analysis
227-
if (buffer_info_start_stmt_idx.count(allocate) == 0) {
227+
if (buffer_info_start_stmt_idx_.count(allocate) == 0) {
228228
continue;
229229
}
230230
LivenessEvent le_event_start;
231231
le_event_start.buffer_info = buffer_info;
232232
le_event_start.le_type = START;
233-
le_event_start.tick = buffer_info_start_stmt_idx[allocate];
233+
le_event_start.tick = buffer_info_start_stmt_idx_[allocate];
234234
le_events.push_back(le_event_start);
235235

236236
LivenessEvent le_event_end;
237237
le_event_end.buffer_info = buffer_info;
238238
le_event_end.le_type = END;
239-
le_event_end.tick = buffer_info_end_stmt_idx[allocate];
239+
le_event_end.tick = buffer_info_end_stmt_idx_[allocate];
240240
le_events.push_back(le_event_end);
241241
}
242242

@@ -262,7 +262,7 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
262262
open_set.erase(le_event.buffer_info);
263263
}
264264
}
265-
return this->buffer_info_map;
265+
return this->buffer_info_map_;
266266
}
267267

268268
Map<BufferInfo, tir::Stmt> ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {

0 commit comments

Comments
 (0)