Skip to content

Commit e5f1867

Browse files
committed
[TIR][USMP] adding the pass to convert to pool offsets
Fixing the references after changes in the memory planning algorithm. Change-Id: Id7c22356fd5de43d10a2b4fc70e978af2c6d599d
1 parent d66cd2c commit e5f1867

File tree

1 file changed

+53
-50
lines changed

1 file changed

+53
-50
lines changed

tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde
202202
# fmt: on
203203

204204

205-
def test_linear():
205+
def test_mobilenet_subgraph():
206206
target = Target("c")
207207
fast_memory_pool = usmp_utils.PoolInfo(
208208
pool_name="fast_memory",
@@ -231,6 +231,7 @@ def test_linear():
231231
)(tir_mod)
232232

233233
tir_mod_with_offsets_ref = LinearStructurePlanned
234+
tir_mod_with_offsets_ref = tvm.script.from_source(tir_mod_with_offsets_ref.script(show_meta=False))
234235
# The TIR produced fails on roundtrip TVMScript testing.
235236
# Therefore, indicates the TVMScript produced here and/or the parser
236237
# is lacking functionality. Thus for these tests, uses a string
@@ -365,40 +366,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
365366
@tvm.script.ir_module
366367
class ResnetStructurePlanned:
367368
@T.prim_func
368-
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None:
369-
placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
370-
placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
371-
placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
372-
T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
373-
global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
374-
# body
375-
PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 0), dtype="handle")
376-
for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
377-
T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
378-
for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
379-
Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 1478912), dtype="handle")
380-
for ff in T.serial(0, 64):
381-
T.store(Conv2dOutput_let, ff, 0, True)
382-
for rc in T.serial(0, 64):
383-
T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
384-
for ax3_inner_1 in T.serial(0, 64):
385-
T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)
386-
387-
@T.prim_func
388-
def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None:
389-
global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
369+
def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None:
370+
placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8")
371+
placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
372+
T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16")
373+
global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
390374
# body
391-
T.attr("default", "device_id", 0)
392-
T.attr("default", "device_type", 1)
393-
sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
394-
sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle")
395-
sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 758912), dtype="handle")
396-
sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 758912), dtype="handle")
397-
T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32"))
398-
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32"))
399-
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32"))
400-
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32"))
401-
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32"))
375+
for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
376+
T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True)
402377

403378
@T.prim_func
404379
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None:
@@ -407,13 +382,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
407382
placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32")
408383
placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32")
409384
T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8")
410-
global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
385+
global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
411386
# body
412-
PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 5760000), dtype="handle")
387+
PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle")
413388
for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
414389
T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True)
415390
for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
416-
Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle")
391+
Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 7200000), dtype="handle")
417392
for ax3_outer_2 in T.serial(0, 4):
418393
for ff_3 in T.serial(0, 64):
419394
T.store(Conv2dOutput_3_let, ff_3, 0, True)
@@ -428,13 +403,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
428403
placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16")
429404
placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32")
430405
T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32")
431-
global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
406+
global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
432407
# body
433-
PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 5760000), dtype="handle")
408+
PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7200000), dtype="handle")
434409
for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
435410
T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True)
436411
for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
437-
Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 6480000), dtype="handle")
412+
Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7920000), dtype="handle")
438413
for ax3_outer_1 in T.serial(0, 4):
439414
for ff_2 in T.serial(0, 64):
440415
T.store(Conv2dOutput_2_let, ff_2, 0, True)
@@ -444,38 +419,65 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
444419
T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True)
445420

446421
@T.prim_func
447-
def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None:
448-
placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8")
449-
placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
450-
T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16")
451-
global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
422+
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None:
423+
placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
424+
placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
425+
placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
426+
T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
427+
global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
452428
# body
453-
for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
454-
T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True)
429+
PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7200000), dtype="handle")
430+
for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
431+
T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
432+
for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
433+
Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7920000), dtype="handle")
434+
for ff in T.serial(0, 64):
435+
T.store(Conv2dOutput_let, ff, 0, True)
436+
for rc in T.serial(0, 64):
437+
T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
438+
for ax3_inner_1 in T.serial(0, 64):
439+
T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)
455440

456441
@T.prim_func
457442
def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None:
458443
placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16")
459444
placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16")
460445
placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32")
461446
T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16")
462-
global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7200000], dtype="uint8", strides=[1], elem_offset=1, align=16)
447+
global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
463448
# body
464449
PaddedInput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 0), dtype="handle")
465450
for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
466451
T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True)
467452
for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
468-
Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 1478912), dtype="handle")
453+
Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 7200000), dtype="handle")
469454
for ff_1 in T.serial(0, 64):
470455
T.store(Conv2dOutput_1_let, ff_1, 0, True)
471456
for ry, rx, rc_1 in T.grid(3, 3, 64):
472457
T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True)
473458
for ax3_inner_2 in T.serial(0, 64):
474459
T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
460+
461+
@T.prim_func
462+
def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None:
463+
global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
464+
# body
465+
T.attr("default", "device_id", 0)
466+
T.attr("default", "device_type", 1)
467+
sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 5760000), dtype="handle")
468+
sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle")
469+
sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
470+
sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
471+
T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32"))
472+
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32"))
473+
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32"))
474+
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32"))
475+
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32"))
476+
__tvm_meta__ = None
475477
# fmt: on
476478

477479

478-
def test_fanout():
480+
def test_resnet_subgraph():
479481
target = Target("c")
480482
global_workspace_pool = usmp_utils.PoolInfo(
481483
pool_name="global_workspace",
@@ -498,6 +500,7 @@ def test_fanout():
498500
)(tir_mod)
499501

500502
tir_mod_with_offsets_ref = ResnetStructurePlanned
503+
501504
# The TIR produced fails on roundtrip TVMScript testing.
502505
# Therefore, indicates the TVMScript produced here and/or the parser
503506
# is lacking functionality. Thus for these tests, uses a string

0 commit comments

Comments
 (0)