Skip to content

Commit eb05d25

Browse files
committed
refactored init
1 parent 5e6b0a0 commit eb05d25

3 files changed

Lines changed: 35 additions & 46 deletions

File tree

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -251,28 +251,8 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str
251251
Optional<Array<Integer>> vector_load_lens,
252252
Optional<Map<String, ObjectRef>> reuse_read,
253253
Optional<Map<String, ObjectRef>> reuse_write) {
254-
ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>();
255-
n->structure = structure;
256-
n->tile_binds = tile_binds.value_or({});
257-
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
258-
n->vector_load_lens = vector_load_lens.defined()
259-
? support::AsVector<Integer, int>(vector_load_lens.value())
260-
: std::vector<int>();
261-
n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
262-
n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
263-
for (int i = 0, len = structure.size(); i < len; ++i) {
264-
char c = structure.data()[i];
265-
if (c == 'S') {
266-
n->s_indices_.push_back(i);
267-
} else if (c == 'R') {
268-
n->r_indices_.push_back(i);
269-
} else {
270-
LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
271-
}
272-
}
273-
n->thread_warp_size_ = -1;
274-
n->max_threads_per_block_ = -1;
275-
return ScheduleRule(n);
254+
return MultiLevelTilingInitCommon<MultiLevelTilingNode>(
255+
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
276256
}
277257

278258
TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode);

src/meta_schedule/schedule_rule/multi_level_tiling.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,5 +179,35 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
179179
TVM_DECLARE_BASE_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode);
180180
};
181181

182+
template <typename NodeType>
183+
ScheduleRule MultiLevelTilingInitCommon(String structure, Optional<Array<String>> tile_binds,
184+
Optional<Integer> max_innermost_factor,
185+
Optional<Array<Integer>> vector_load_lens,
186+
Optional<Map<String, ObjectRef>> reuse_read,
187+
Optional<Map<String, ObjectRef>> reuse_write) {
188+
ObjectPtr<NodeType> n = make_object<NodeType>();
189+
n->structure = structure;
190+
n->tile_binds = tile_binds.value_or({});
191+
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
192+
n->vector_load_lens = vector_load_lens.defined()
193+
? support::AsVector<Integer, int>(vector_load_lens.value())
194+
: std::vector<int>();
195+
n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
196+
n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
197+
for (int i = 0, len = structure.size(); i < len; ++i) {
198+
char c = structure.data()[i];
199+
if (c == 'S') {
200+
n->s_indices_.push_back(i);
201+
} else if (c == 'R') {
202+
n->r_indices_.push_back(i);
203+
} else {
204+
LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
205+
}
206+
}
207+
n->thread_warp_size_ = -1;
208+
n->max_threads_per_block_ = -1;
209+
return ScheduleRule(n);
210+
}
211+
182212
} // namespace meta_schedule
183213
} // namespace tvm

src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,7 @@ class MultiLevelTilingVNNINode : public MultiLevelTilingNode {
262262
return MultiLevelTilingNode::ApplySubRules(states);
263263
}
264264

265-
public:
266-
265+
public:
267266
static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingVNNI";
268267
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingVNNINode, MultiLevelTilingNode);
269268
};
@@ -276,28 +275,8 @@ ScheduleRule ScheduleRule::MultiLevelTilingVNNI(String structure,
276275
Optional<Array<Integer>> vector_load_lens,
277276
Optional<Map<String, ObjectRef>> reuse_read,
278277
Optional<Map<String, ObjectRef>> reuse_write) {
279-
ObjectPtr<MultiLevelTilingVNNINode> n = make_object<MultiLevelTilingVNNINode>();
280-
n->structure = structure;
281-
n->tile_binds = tile_binds.value_or({});
282-
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
283-
n->vector_load_lens = vector_load_lens.defined()
284-
? support::AsVector<Integer, int>(vector_load_lens.value())
285-
: std::vector<int>();
286-
n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
287-
n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
288-
for (int i = 0, len = structure.size(); i < len; ++i) {
289-
char c = structure.data()[i];
290-
if (c == 'S') {
291-
n->s_indices_.push_back(i);
292-
} else if (c == 'R') {
293-
n->r_indices_.push_back(i);
294-
} else {
295-
LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
296-
}
297-
}
298-
n->thread_warp_size_ = -1;
299-
n->max_threads_per_block_ = -1;
300-
return ScheduleRule(n);
278+
return MultiLevelTilingInitCommon<MultiLevelTilingVNNINode>(
279+
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
301280
}
302281

303282
TVM_REGISTER_NODE_TYPE(MultiLevelTilingVNNINode);

0 commit comments

Comments
 (0)