Skip to content

Commit 1f436d7

Browse files
authored
[BugFix] Fix Winograd Test Script (#25)
* Fix winograd test script. * Fix script.
1 parent 793307e commit 1f436d7

1 file changed

Lines changed: 25 additions & 17 deletions

File tree

tests/python/unittest/test_meta_schedule_post_order_apply.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -189,24 +189,27 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
189189
bgemm = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1)
190190
A = T.alloc_buffer([6, 4], elem_offset=0, align=128, offset_factor=1)
191191
inverse = T.alloc_buffer([4, 4, 9, 128], elem_offset=0, align=128, offset_factor=1)
192-
for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 16, 16, 128):
192+
for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
193193
with T.block("data_pad"):
194+
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
194195
T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
195196
T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
196197
T.block_attr({
197198
"schedule_rule": "None",
198199
})
199200
data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(((((0 <= i1_1) and (i1_1 < 14)) and (0 <= i2_1)) and (i2_1 < 14)), placeholder[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32")
200-
for eps, nu, p, ci in T.grid(6, 6, 9, 128):
201+
for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
201202
with T.block("input_tile"):
203+
eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
202204
T.reads([data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]])
203205
T.writes([input_tile[eps, nu, p, ci]])
204206
T.block_attr({
205207
"schedule_rule": "None",
206208
})
207209
input_tile[eps, nu, p, ci] = data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]
208-
for i, j in T.grid(6, 6):
210+
for i0_3, i1_3 in T.grid(6, 6):
209211
with T.block("B"):
212+
i, j = T.axis.remap("SS", [i0_3, i1_3])
210213
T.writes([B[i, j]])
211214
T.block_attr({
212215
"const_matrix" : True,
@@ -236,8 +239,9 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
236239
with T.init():
237240
bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
238241
bgemm[eps_2, nu_2, p_2, co] = (bgemm[eps_2, nu_2, p_2, co] + (data_pack[eps_2, nu_2, p_2, ci_2]*placeholder_1[eps_2, nu_2, co, ci_2]))
239-
for i_1, j_1 in T.grid(6, 4):
242+
for i0_6, i1_6 in T.grid(6, 4):
240243
with T.block("A"):
244+
i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
241245
T.writes([A[i_1, j_1]])
242246
T.block_attr({
243247
"const_matrix" : True,
@@ -256,8 +260,9 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
256260
with T.init():
257261
inverse[vh, vw, p_3, co_1] = T.float32(0)
258262
inverse[vh, vw, p_3, co_1] = (inverse[vh, vw, p_3, co_1] + ((bgemm[r_a_1, r_b_1, p_3, co_1]*A[r_a_1, vh])*A[r_b_1, vw]))
259-
for n, h, w, co_2 in T.grid(1, 12, 12, 128):
263+
for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
260264
with T.block("conv2d_winograd"):
265+
n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
261266
T.reads([inverse[T.floormod(h, 4), T.floormod(w, 4), (((n*9) + (T.floordiv(h, 4)*3)) + T.floordiv(w, 4)), co_2]])
262267
T.writes([conv2d_winograd[n, h, w, co_2]])
263268
T.block_attr({
@@ -283,24 +288,27 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
283288
bgemm = T.alloc_buffer([6, 6, 9, 128], elem_offset=0, align=128, offset_factor=1)
284289
A = T.alloc_buffer([6, 4], elem_offset=0, align=128, offset_factor=1)
285290
inverse = T.alloc_buffer([4, 4, 9, 128], elem_offset=0, align=128, offset_factor=1)
286-
for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 16, 16, 128):
291+
for i0, i1, i2, i3 in T.grid(1, 16, 16, 128):
287292
with T.block("data_pad"):
293+
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
288294
T.block_attr({
289295
"schedule_rule": "None",
290296
})
291297
T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]])
292298
T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]])
293299
data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(((((0 <= i1_1) and (i1_1 < 14)) and (0 <= i2_1)) and (i2_1 < 14)), placeholder[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32")
294-
for eps, nu, p, ci in T.grid(6, 6, 9, 128):
300+
for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128):
295301
with T.block("input_tile"):
302+
eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2])
296303
T.block_attr({
297304
"schedule_rule": "None",
298305
})
299306
T.reads([data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]])
300307
T.writes([input_tile[eps, nu, p, ci]])
301308
input_tile[eps, nu, p, ci] = data_pad[T.floordiv(p, 9), ((T.floordiv(T.floormod(p, 9), 3)*4) + eps), ((T.floormod(p, 3)*4) + nu), ci]
302-
for i, j in T.grid(6, 6):
309+
for i0_3, i1_3 in T.grid(6, 6):
303310
with T.block("B"):
311+
i, j = T.axis.remap("SS", [i0_3, i1_3])
304312
T.writes([B[i, j]])
305313
T.block_attr({
306314
"const_matrix":True,
@@ -330,8 +338,9 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
330338
with T.init():
331339
bgemm[eps_2, nu_2, p_2, co] = T.float32(0)
332340
bgemm[eps_2, nu_2, p_2, co] = (bgemm[eps_2, nu_2, p_2, co] + (data_pack[eps_2, nu_2, p_2, ci_2]*placeholder_1[eps_2, nu_2, co, ci_2]))
333-
for i_1, j_1 in T.grid(6, 4):
341+
for i0_6, i1_6 in T.grid(6, 4):
334342
with T.block("A"):
343+
i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6])
335344
T.writes([A[i_1, j_1]])
336345
T.block_attr({
337346
"const_matrix":True,
@@ -350,8 +359,9 @@ def main(var_placeholder: T.handle, var_placeholder_1: T.handle, var_conv2d_wino
350359
with T.init():
351360
inverse[vh, vw, p_3, co_1] = T.float32(0)
352361
inverse[vh, vw, p_3, co_1] = (inverse[vh, vw, p_3, co_1] + ((bgemm[r_a_1, r_b_1, p_3, co_1]*A[r_a_1, vh])*A[r_b_1, vw]))
353-
for n, h, w, co_2 in T.grid(1, 12, 12, 128):
362+
for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128):
354363
with T.block("conv2d_winograd"):
364+
n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6])
355365
T.block_attr({
356366
"schedule_rule": "None",
357367
})
@@ -617,7 +627,6 @@ def test_meta_schedule_post_order_apply_custom_search_space_none_rule():
617627
_ = post_order_apply.generate_design_space(mod)
618628

619629

620-
@pytest.mark.xfail # for compute_at bug
621630
def test_meta_schedule_post_order_apply_custom_search_space_winograd():
622631
@register_func("tvm.meta_schedule.test.custom_search_space.winograd")
623632
def custom_search_space_winograd_func(sch: Schedule, block: BlockRV) -> List[Schedule]:
@@ -682,8 +691,7 @@ def custom_search_space_winograd_func(sch: Schedule, block: BlockRV) -> List[Sch
682691
sch.annotate(block_or_loop=b76, ann_key="auto_unroll_explicit", ann_val=v77)
683692

684693
b78 = sch.get_block(name="input_tile")
685-
(b79,) = sch.get_consumers(block=b78)
686-
l80 = sch.sample_compute_location(block=b79, decision=4)
694+
l80 = sch.sample_compute_location(block=b78, decision=4)
687695
sch.compute_at(block=b78, loop=l80, preserve_unit_loops=True)
688696

689697
b81 = sch.get_block(name="data_pad")
@@ -771,16 +779,16 @@ def custom_search_space_winograd_func(sch: Schedule, block: BlockRV) -> List[Sch
771779
"v85 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=1)",
772780
'sch.annotate(block_or_loop=b84, ann_key="auto_unroll_explicit", ann_val=v85)',
773781
'b86 = sch.get_block(name="input_tile", func_name="main")',
774-
"l87 = sch.sample_compute_location(block=b86, decision=-1)",
782+
"l87 = sch.sample_compute_location(block=b86, decision=4)",
775783
"sch.compute_at(block=b86, loop=l87, preserve_unit_loops=True)",
776784
'b88 = sch.get_block(name="data_pad", func_name="main")',
777-
"l89 = sch.sample_compute_location(block=b88, decision=-1)",
778-
"sch.compute_at(block=b88, loop=l89, preserve_unit_loops=True)",
785+
"b89, = sch.get_consumers(block=b88)",
786+
"l90 = sch.sample_compute_location(block=b89, decision=-2)",
787+
"sch.compute_at(block=b88, loop=l90, preserve_unit_loops=True)",
779788
],
780789
)
781790

782791

783-
@pytest.mark.xfail # for compute_at bug
784792
def test_meta_schedule_post_order_apply_custom_search_space_winograd_cuda():
785793
@register_func("tvm.meta_schedule.test.custom_search_space.winograd.cuda")
786794
def custom_search_space_winograd_func_cuda(sch: Schedule, block: BlockRV) -> List[Schedule]:

0 commit comments

Comments
 (0)