@@ -189,24 +189,27 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
189189 bgemm = T .alloc_buffer ([6 , 6 , 9 , 128 ], elem_offset = 0 , align = 128 , offset_factor = 1 )
190190 A = T .alloc_buffer ([6 , 4 ], elem_offset = 0 , align = 128 , offset_factor = 1 )
191191 inverse = T .alloc_buffer ([4 , 4 , 9 , 128 ], elem_offset = 0 , align = 128 , offset_factor = 1 )
192- for i0_1 , i1_1 , i2_1 , i3_1 in T .grid (1 , 16 , 16 , 128 ):
192+ for i0 , i1 , i2 , i3 in T .grid (1 , 16 , 16 , 128 ):
193193 with T .block ("data_pad" ):
194+ i0_1 , i1_1 , i2_1 , i3_1 = T .axis .remap ("SSSS" , [i0 , i1 , i2 , i3 ])
194195 T .reads ([placeholder [i0_1 , i1_1 , i2_1 , i3_1 ]])
195196 T .writes ([data_pad [i0_1 , i1_1 , i2_1 , i3_1 ]])
196197 T .block_attr ({
197198 "schedule_rule" : "None" ,
198199 })
199200 data_pad [i0_1 , i1_1 , i2_1 , i3_1 ] = T .if_then_else (((((0 <= i1_1 ) and (i1_1 < 14 )) and (0 <= i2_1 )) and (i2_1 < 14 )), placeholder [i0_1 , i1_1 , i2_1 , i3_1 ], T .float32 (0 ), dtype = "float32" )
200- for eps , nu , p , ci in T .grid (6 , 6 , 9 , 128 ):
201+ for i0_2 , i1_2 , i2_2 , i3_2 in T .grid (6 , 6 , 9 , 128 ):
201202 with T .block ("input_tile" ):
203+ eps , nu , p , ci = T .axis .remap ("SSSS" , [i0_2 , i1_2 , i2_2 , i3_2 ])
202204 T .reads ([data_pad [T .floordiv (p , 9 ), ((T .floordiv (T .floormod (p , 9 ), 3 )* 4 ) + eps ), ((T .floormod (p , 3 )* 4 ) + nu ), ci ]])
203205 T .writes ([input_tile [eps , nu , p , ci ]])
204206 T .block_attr ({
205207 "schedule_rule" : "None" ,
206208 })
207209 input_tile [eps , nu , p , ci ] = data_pad [T .floordiv (p , 9 ), ((T .floordiv (T .floormod (p , 9 ), 3 )* 4 ) + eps ), ((T .floormod (p , 3 )* 4 ) + nu ), ci ]
208- for i , j in T .grid (6 , 6 ):
210+ for i0_3 , i1_3 in T .grid (6 , 6 ):
209211 with T .block ("B" ):
212+ i , j = T .axis .remap ("SS" , [i0_3 , i1_3 ])
210213 T .writes ([B [i , j ]])
211214 T .block_attr ({
212215 "const_matrix" : True ,
@@ -236,8 +239,9 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
236239 with T .init ():
237240 bgemm [eps_2 , nu_2 , p_2 , co ] = T .float32 (0 )
238241 bgemm [eps_2 , nu_2 , p_2 , co ] = (bgemm [eps_2 , nu_2 , p_2 , co ] + (data_pack [eps_2 , nu_2 , p_2 , ci_2 ]* placeholder_1 [eps_2 , nu_2 , co , ci_2 ]))
239- for i_1 , j_1 in T .grid (6 , 4 ):
242+ for i0_6 , i1_6 in T .grid (6 , 4 ):
240243 with T .block ("A" ):
244+ i_1 , j_1 = T .axis .remap ("SS" , [i0_6 , i1_6 ])
241245 T .writes ([A [i_1 , j_1 ]])
242246 T .block_attr ({
243247 "const_matrix" : True ,
@@ -256,8 +260,9 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
256260 with T .init ():
257261 inverse [vh , vw , p_3 , co_1 ] = T .float32 (0 )
258262 inverse [vh , vw , p_3 , co_1 ] = (inverse [vh , vw , p_3 , co_1 ] + ((bgemm [r_a_1 , r_b_1 , p_3 , co_1 ]* A [r_a_1 , vh ])* A [r_b_1 , vw ]))
259- for n , h , w , co_2 in T .grid (1 , 12 , 12 , 128 ):
263+ for i0_8 , i1_8 , i2_6 , i3_6 in T .grid (1 , 12 , 12 , 128 ):
260264 with T .block ("conv2d_winograd" ):
265+ n , h , w , co_2 = T .axis .remap ("SSSS" , [i0_8 , i1_8 , i2_6 , i3_6 ])
261266 T .reads ([inverse [T .floormod (h , 4 ), T .floormod (w , 4 ), (((n * 9 ) + (T .floordiv (h , 4 )* 3 )) + T .floordiv (w , 4 )), co_2 ]])
262267 T .writes ([conv2d_winograd [n , h , w , co_2 ]])
263268 T .block_attr ({
@@ -283,24 +288,27 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
283288 bgemm = T .alloc_buffer ([6 , 6 , 9 , 128 ], elem_offset = 0 , align = 128 , offset_factor = 1 )
284289 A = T .alloc_buffer ([6 , 4 ], elem_offset = 0 , align = 128 , offset_factor = 1 )
285290 inverse = T .alloc_buffer ([4 , 4 , 9 , 128 ], elem_offset = 0 , align = 128 , offset_factor = 1 )
286- for i0_1 , i1_1 , i2_1 , i3_1 in T .grid (1 , 16 , 16 , 128 ):
291+ for i0 , i1 , i2 , i3 in T .grid (1 , 16 , 16 , 128 ):
287292 with T .block ("data_pad" ):
293+ i0_1 , i1_1 , i2_1 , i3_1 = T .axis .remap ("SSSS" , [i0 , i1 , i2 , i3 ])
288294 T .block_attr ({
289295 "schedule_rule" : "None" ,
290296 })
291297 T .reads ([placeholder [i0_1 , i1_1 , i2_1 , i3_1 ]])
292298 T .writes ([data_pad [i0_1 , i1_1 , i2_1 , i3_1 ]])
293299 data_pad [i0_1 , i1_1 , i2_1 , i3_1 ] = T .if_then_else (((((0 <= i1_1 ) and (i1_1 < 14 )) and (0 <= i2_1 )) and (i2_1 < 14 )), placeholder [i0_1 , i1_1 , i2_1 , i3_1 ], T .float32 (0 ), dtype = "float32" )
294- for eps , nu , p , ci in T .grid (6 , 6 , 9 , 128 ):
300+ for i0_2 , i1_2 , i2_2 , i3_2 in T .grid (6 , 6 , 9 , 128 ):
295301 with T .block ("input_tile" ):
302+ eps , nu , p , ci = T .axis .remap ("SSSS" , [i0_2 , i1_2 , i2_2 , i3_2 ])
296303 T .block_attr ({
297304 "schedule_rule" : "None" ,
298305 })
299306 T .reads ([data_pad [T .floordiv (p , 9 ), ((T .floordiv (T .floormod (p , 9 ), 3 )* 4 ) + eps ), ((T .floormod (p , 3 )* 4 ) + nu ), ci ]])
300307 T .writes ([input_tile [eps , nu , p , ci ]])
301308 input_tile [eps , nu , p , ci ] = data_pad [T .floordiv (p , 9 ), ((T .floordiv (T .floormod (p , 9 ), 3 )* 4 ) + eps ), ((T .floormod (p , 3 )* 4 ) + nu ), ci ]
302- for i , j in T .grid (6 , 6 ):
309+ for i0_3 , i1_3 in T .grid (6 , 6 ):
303310 with T .block ("B" ):
311+ i , j = T .axis .remap ("SS" , [i0_3 , i1_3 ])
304312 T .writes ([B [i , j ]])
305313 T .block_attr ({
306314 "const_matrix" :True ,
@@ -330,8 +338,9 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
330338 with T .init ():
331339 bgemm [eps_2 , nu_2 , p_2 , co ] = T .float32 (0 )
332340 bgemm [eps_2 , nu_2 , p_2 , co ] = (bgemm [eps_2 , nu_2 , p_2 , co ] + (data_pack [eps_2 , nu_2 , p_2 , ci_2 ]* placeholder_1 [eps_2 , nu_2 , co , ci_2 ]))
333- for i_1 , j_1 in T .grid (6 , 4 ):
341+ for i0_6 , i1_6 in T .grid (6 , 4 ):
334342 with T .block ("A" ):
343+ i_1 , j_1 = T .axis .remap ("SS" , [i0_6 , i1_6 ])
335344 T .writes ([A [i_1 , j_1 ]])
336345 T .block_attr ({
337346 "const_matrix" :True ,
@@ -350,8 +359,9 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
350359 with T .init ():
351360 inverse [vh , vw , p_3 , co_1 ] = T .float32 (0 )
352361 inverse [vh , vw , p_3 , co_1 ] = (inverse [vh , vw , p_3 , co_1 ] + ((bgemm [r_a_1 , r_b_1 , p_3 , co_1 ]* A [r_a_1 , vh ])* A [r_b_1 , vw ]))
353- for n , h , w , co_2 in T .grid (1 , 12 , 12 , 128 ):
362+ for i0_8 , i1_8 , i2_6 , i3_6 in T .grid (1 , 12 , 12 , 128 ):
354363 with T .block ("conv2d_winograd" ):
364+ n , h , w , co_2 = T .axis .remap ("SSSS" , [i0_8 , i1_8 , i2_6 , i3_6 ])
355365 T .block_attr ({
356366 "schedule_rule" : "None" ,
357367 })
@@ -617,7 +627,6 @@ def test_meta_schedule_post_order_apply_custom_search_space_none_rule():
617627 _ = post_order_apply .generate_design_space (mod )
618628
619629
620- @pytest .mark .xfail # for compute_at bug
621630def test_meta_schedule_post_order_apply_custom_search_space_winograd ():
622631 @register_func ("tvm.meta_schedule.test.custom_search_space.winograd" )
623632 def custom_search_space_winograd_func (sch : Schedule , block : BlockRV ) -> List [Schedule ]:
@@ -682,8 +691,7 @@ def custom_search_space_winograd_func(sch: Schedule, block: BlockRV) -> List[Sch
682691 sch .annotate (block_or_loop = b76 , ann_key = "auto_unroll_explicit" , ann_val = v77 )
683692
684693 b78 = sch .get_block (name = "input_tile" )
685- (b79 ,) = sch .get_consumers (block = b78 )
686- l80 = sch .sample_compute_location (block = b79 , decision = 4 )
694+ l80 = sch .sample_compute_location (block = b78 , decision = 4 )
687695 sch .compute_at (block = b78 , loop = l80 , preserve_unit_loops = True )
688696
689697 b81 = sch .get_block (name = "data_pad" )
@@ -771,16 +779,16 @@ def custom_search_space_winograd_func(sch: Schedule, block: BlockRV) -> List[Sch
771779 "v85 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=1)" ,
772780 'sch.annotate(block_or_loop=b84, ann_key="auto_unroll_explicit", ann_val=v85)' ,
773781 'b86 = sch.get_block(name="input_tile", func_name="main")' ,
774- "l87 = sch.sample_compute_location(block=b86, decision=-1 )" ,
782+ "l87 = sch.sample_compute_location(block=b86, decision=4 )" ,
775783 "sch.compute_at(block=b86, loop=l87, preserve_unit_loops=True)" ,
776784 'b88 = sch.get_block(name="data_pad", func_name="main")' ,
777- "l89 = sch.sample_compute_location(block=b88, decision=-1)" ,
778- "sch.compute_at(block=b88, loop=l89, preserve_unit_loops=True)" ,
785+ "b89, = sch.get_consumers(block=b88)" ,
786+ "l90 = sch.sample_compute_location(block=b89, decision=-2)" ,
787+ "sch.compute_at(block=b88, loop=l90, preserve_unit_loops=True)" ,
779788 ],
780789 )
781790
782791
783- @pytest .mark .xfail # for compute_at bug
784792def test_meta_schedule_post_order_apply_custom_search_space_winograd_cuda ():
785793 @register_func ("tvm.meta_schedule.test.custom_search_space.winograd.cuda" )
786794 def custom_search_space_winograd_func_cuda (sch : Schedule , block : BlockRV ) -> List [Schedule ]:
0 commit comments