@@ -262,8 +262,7 @@ class MultiLevelTilingVNNINode : public MultiLevelTilingNode {
262262 return MultiLevelTilingNode::ApplySubRules (states);
263263 }
264264
265- public:
266-
265+ public:
267266 static constexpr const char * _type_key = " meta_schedule.MultiLevelTilingVNNI" ;
268267 TVM_DECLARE_FINAL_OBJECT_INFO (MultiLevelTilingVNNINode, MultiLevelTilingNode);
269268};
@@ -276,28 +275,8 @@ ScheduleRule ScheduleRule::MultiLevelTilingVNNI(String structure,
276275 Optional<Array<Integer>> vector_load_lens,
277276 Optional<Map<String, ObjectRef>> reuse_read,
278277 Optional<Map<String, ObjectRef>> reuse_write) {
279- ObjectPtr<MultiLevelTilingVNNINode> n = make_object<MultiLevelTilingVNNINode>();
280- n->structure = structure;
281- n->tile_binds = tile_binds.value_or ({});
282- n->max_innermost_factor = max_innermost_factor.value_or (Integer (-1 ))->value ;
283- n->vector_load_lens = vector_load_lens.defined ()
284- ? support::AsVector<Integer, int >(vector_load_lens.value ())
285- : std::vector<int >();
286- n->reuse_read_ = reuse_read.defined () ? ReuseConfig (reuse_read.value ()) : ReuseConfig ();
287- n->reuse_write_ = reuse_write.defined () ? ReuseConfig (reuse_write.value ()) : ReuseConfig ();
288- for (int i = 0 , len = structure.size (); i < len; ++i) {
289- char c = structure.data ()[i];
290- if (c == ' S' ) {
291- n->s_indices_ .push_back (i);
292- } else if (c == ' R' ) {
293- n->r_indices_ .push_back (i);
294- } else {
295- LOG (FATAL) << " ValueError: Invalid tiling structure: " << structure;
296- }
297- }
298- n->thread_warp_size_ = -1 ;
299- n->max_threads_per_block_ = -1 ;
300- return ScheduleRule (n);
278+ return MultiLevelTilingInitCommon<MultiLevelTilingVNNINode>(
279+ structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
301280}
302281
303282TVM_REGISTER_NODE_TYPE (MultiLevelTilingVNNINode);
0 commit comments