@@ -202,7 +202,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde
202202# fmt: on
203203
204204
205- def test_linear ():
205+ def test_mobilenet_subgraph ():
206206 target = Target ("c" )
207207 fast_memory_pool = usmp_utils .PoolInfo (
208208 pool_name = "fast_memory" ,
@@ -231,6 +231,7 @@ def test_linear():
231231 )(tir_mod )
232232
233233 tir_mod_with_offsets_ref = LinearStructurePlanned
234+ tir_mod_with_offsets_ref = tvm .script .from_source (tir_mod_with_offsets_ref .script (show_meta = False ))
234235 # The TIR produced fails on roundtrip TVMScript testing.
235236 # Therefore, indicates the TVMScript produced here and/or the parser
236237 # is lacking functionality. Thus for these tests, uses a string
@@ -365,40 +366,14 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
365366@tvm .script .ir_module
366367class ResnetStructurePlanned :
367368 @T .prim_func
368- def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast (placeholder_4 : T .handle , placeholder_5 : T .handle , placeholder_6 : T .handle , T_cast_2 : T .handle , global_workspace_2_var : T .handle ) -> None :
369- placeholder_7 = T .match_buffer (placeholder_4 , [1 , 75 , 75 , 64 ], dtype = "int16" )
370- placeholder_8 = T .match_buffer (placeholder_5 , [1 , 1 , 64 , 64 ], dtype = "int16" )
371- placeholder_9 = T .match_buffer (placeholder_6 , [1 , 1 , 1 , 64 ], dtype = "int32" )
372- T_cast_3 = T .match_buffer (T_cast_2 , [1 , 75 , 75 , 64 ], dtype = "int16" )
373- global_workspace_2_buffer_var = T .match_buffer (global_workspace_2_var , [7200000 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
374- # body
375- PaddedInput_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_2_buffer_var .data , 0 ), dtype = "handle" )
376- for i0_i1_fused , i2 , i3 in T .grid (75 , 75 , 64 ):
377- T .store (PaddedInput_let , i0_i1_fused * 4800 + i2 * 64 + i3 , T .load ("int16" , placeholder_7 .data , i0_i1_fused * 4800 + i2 * 64 + i3 ), True )
378- for ax0_ax1_fused_ax2_fused in T .serial (0 , 5625 ):
379- Conv2dOutput_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_2_buffer_var .data , 1478912 ), dtype = "handle" )
380- for ff in T .serial (0 , 64 ):
381- T .store (Conv2dOutput_let , ff , 0 , True )
382- for rc in T .serial (0 , 64 ):
383- T .store (Conv2dOutput_let , ff , T .load ("int32" , Conv2dOutput_let , ff ) + T .cast (T .load ("int16" , PaddedInput_let , ax0_ax1_fused_ax2_fused * 64 + rc ), "int32" ) * T .cast (T .load ("int16" , placeholder_8 .data , rc * 64 + ff ), "int32" ), True )
384- for ax3_inner_1 in T .serial (0 , 64 ):
385- T .store (T_cast_3 .data , ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1 , T .cast (T .cast (T .max (T .min (T .q_multiply_shift (T .load ("int32" , Conv2dOutput_let , ax3_inner_1 ) + T .load ("int32" , placeholder_9 .data , ax3_inner_1 ), 1843106743 , 31 , - 6 , dtype = "int32" ), 255 ), 0 ), "uint8" ), "int16" ), True )
386-
387- @T .prim_func
388- def run_model (input : T .handle , output : T .handle , global_workspace_0_var : T .handle ) -> None :
389- global_workspace_0_buffer_var = T .match_buffer (global_workspace_0_var , [7200000 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
369+ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast (placeholder : T .handle , placeholder_1 : T .handle , T_cast : T .handle , global_workspace_1_var : T .handle ) -> None :
370+ placeholder_2 = T .match_buffer (placeholder , [1 , 75 , 75 , 64 ], dtype = "uint8" )
371+ placeholder_3 = T .match_buffer (placeholder_1 , [64 ], dtype = "int32" )
372+ T_cast_1 = T .match_buffer (T_cast , [1 , 75 , 75 , 64 ], dtype = "int16" )
373+ global_workspace_1_buffer_var = T .match_buffer (global_workspace_1_var , [7920256 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
390374 # body
391- T .attr ("default" , "device_id" , 0 )
392- T .attr ("default" , "device_type" , 1 )
393- sid_2_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_0_buffer_var .data , 6480000 ), dtype = "handle" )
394- sid_6_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_0_buffer_var .data , 0 ), dtype = "handle" )
395- sid_7_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_0_buffer_var .data , 758912 ), dtype = "handle" )
396- sid_8_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_0_buffer_var .data , 758912 ), dtype = "handle" )
397- T .evaluate (T .call_extern ("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast" , input , T .lookup_param ("p0" , dtype = "handle" ), sid_2_let , global_workspace_0_buffer_var .data , dtype = "int32" ))
398- T .evaluate (T .call_extern ("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast" , sid_2_let , T .lookup_param ("p3" , dtype = "handle" ), T .lookup_param ("p4" , dtype = "handle" ), sid_8_let , global_workspace_0_buffer_var .data , dtype = "int32" ))
399- T .evaluate (T .call_extern ("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1" , sid_8_let , T .lookup_param ("p5" , dtype = "handle" ), T .lookup_param ("p6" , dtype = "handle" ), sid_7_let , global_workspace_0_buffer_var .data , dtype = "int32" ))
400- T .evaluate (T .call_extern ("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_" , sid_7_let , T .lookup_param ("p7" , dtype = "handle" ), T .lookup_param ("p8" , dtype = "handle" ), sid_6_let , global_workspace_0_buffer_var .data , dtype = "int32" ))
401- T .evaluate (T .call_extern ("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_" , sid_2_let , T .lookup_param ("p1" , dtype = "handle" ), T .lookup_param ("p2" , dtype = "handle" ), sid_6_let , output , global_workspace_0_buffer_var .data , dtype = "int32" ))
375+ for ax0_ax1_fused , ax2 , ax3_outer , ax3_inner in T .grid (75 , 75 , 4 , 16 ):
376+ T .store (T_cast_1 .data , ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner , T .cast (T .cast (T .max (T .min (T .q_multiply_shift (T .cast (T .load ("uint8" , placeholder_2 .data , ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner ), "int32" ) - 94 , 1843157232 , 31 , 1 , dtype = "int32" ) + T .load ("int32" , placeholder_3 .data , ax3_outer * 16 + ax3_inner ), 255 ), 0 ), "uint8" ), "int16" ), True )
402377
403378 @T .prim_func
404379 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_ (placeholder_22 : T .handle , placeholder_23 : T .handle , placeholder_24 : T .handle , placeholder_25 : T .handle , T_cast_6 : T .handle , global_workspace_5_var : T .handle ) -> None :
@@ -407,13 +382,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
407382 placeholder_26 = T .match_buffer (placeholder_24 , [1 , 1 , 1 , 256 ], dtype = "int32" )
408383 placeholder_28 = T .match_buffer (placeholder_25 , [1 , 75 , 75 , 256 ], dtype = "int32" )
409384 T_cast_7 = T .match_buffer (T_cast_6 , [1 , 75 , 75 , 256 ], dtype = "uint8" )
410- global_workspace_5_buffer_var = T .match_buffer (global_workspace_5_var , [7200000 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
385+ global_workspace_5_buffer_var = T .match_buffer (global_workspace_5_var , [7920256 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
411386 # body
412- PaddedInput_3_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_5_buffer_var .data , 5760000 ), dtype = "handle" )
387+ PaddedInput_3_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_5_buffer_var .data , 6480000 ), dtype = "handle" )
413388 for i0_i1_fused_3 , i2_3 , i3_3 in T .grid (75 , 75 , 64 ):
414389 T .store (PaddedInput_3_let , i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3 , T .load ("int16" , placeholder_29 .data , i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3 ), True )
415390 for ax0_ax1_fused_ax2_fused_3 in T .serial (0 , 5625 ):
416- Conv2dOutput_3_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_5_buffer_var .data , 6480000 ), dtype = "handle" )
391+ Conv2dOutput_3_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_5_buffer_var .data , 7200000 ), dtype = "handle" )
417392 for ax3_outer_2 in T .serial (0 , 4 ):
418393 for ff_3 in T .serial (0 , 64 ):
419394 T .store (Conv2dOutput_3_let , ff_3 , 0 , True )
@@ -428,13 +403,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
428403 placeholder_20 = T .match_buffer (placeholder_17 , [1 , 1 , 64 , 256 ], dtype = "int16" )
429404 placeholder_21 = T .match_buffer (placeholder_18 , [1 , 1 , 1 , 256 ], dtype = "int32" )
430405 T_add_1 = T .match_buffer (T_add , [1 , 75 , 75 , 256 ], dtype = "int32" )
431- global_workspace_4_buffer_var = T .match_buffer (global_workspace_4_var , [7200000 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
406+ global_workspace_4_buffer_var = T .match_buffer (global_workspace_4_var , [7920256 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
432407 # body
433- PaddedInput_2_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_4_buffer_var .data , 5760000 ), dtype = "handle" )
408+ PaddedInput_2_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_4_buffer_var .data , 7200000 ), dtype = "handle" )
434409 for i0_i1_fused_2 , i2_2 , i3_2 in T .grid (75 , 75 , 64 ):
435410 T .store (PaddedInput_2_let , i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2 , T .load ("int16" , placeholder_19 .data , i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2 ), True )
436411 for ax0_ax1_fused_ax2_fused_2 in T .serial (0 , 5625 ):
437- Conv2dOutput_2_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_4_buffer_var .data , 6480000 ), dtype = "handle" )
412+ Conv2dOutput_2_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_4_buffer_var .data , 7920000 ), dtype = "handle" )
438413 for ax3_outer_1 in T .serial (0 , 4 ):
439414 for ff_2 in T .serial (0 , 64 ):
440415 T .store (Conv2dOutput_2_let , ff_2 , 0 , True )
@@ -444,38 +419,65 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s
444419 T .store (T_add_1 .data , ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3 , T .q_multiply_shift (T .cast (T .cast (T .max (T .min (T .q_multiply_shift (T .load ("int32" , Conv2dOutput_2_let , ax3_inner_3 ) + T .load ("int32" , placeholder_21 .data , ax3_outer_1 * 64 + ax3_inner_3 ), 1711626602 , 31 , - 8 , dtype = "int32" ) + 132 , 255 ), 0 ), "uint8" ), "int32" ) - 132 , 2094289803 , 31 , - 2 , dtype = "int32" ) + 136 , True )
445420
446421 @T .prim_func
447- def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast (placeholder : T .handle , placeholder_1 : T .handle , T_cast : T .handle , global_workspace_1_var : T .handle ) -> None :
448- placeholder_2 = T .match_buffer (placeholder , [1 , 75 , 75 , 64 ], dtype = "uint8" )
449- placeholder_3 = T .match_buffer (placeholder_1 , [64 ], dtype = "int32" )
450- T_cast_1 = T .match_buffer (T_cast , [1 , 75 , 75 , 64 ], dtype = "int16" )
451- global_workspace_1_buffer_var = T .match_buffer (global_workspace_1_var , [7200000 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
422+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast (placeholder_4 : T .handle , placeholder_5 : T .handle , placeholder_6 : T .handle , T_cast_2 : T .handle , global_workspace_2_var : T .handle ) -> None :
423+ placeholder_7 = T .match_buffer (placeholder_4 , [1 , 75 , 75 , 64 ], dtype = "int16" )
424+ placeholder_8 = T .match_buffer (placeholder_5 , [1 , 1 , 64 , 64 ], dtype = "int16" )
425+ placeholder_9 = T .match_buffer (placeholder_6 , [1 , 1 , 1 , 64 ], dtype = "int32" )
426+ T_cast_3 = T .match_buffer (T_cast_2 , [1 , 75 , 75 , 64 ], dtype = "int16" )
427+ global_workspace_2_buffer_var = T .match_buffer (global_workspace_2_var , [7920256 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
452428 # body
453- for ax0_ax1_fused , ax2 , ax3_outer , ax3_inner in T .grid (75 , 75 , 4 , 16 ):
454- T .store (T_cast_1 .data , ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner , T .cast (T .cast (T .max (T .min (T .q_multiply_shift (T .cast (T .load ("uint8" , placeholder_2 .data , ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner ), "int32" ) - 94 , 1843157232 , 31 , 1 , dtype = "int32" ) + T .load ("int32" , placeholder_3 .data , ax3_outer * 16 + ax3_inner ), 255 ), 0 ), "uint8" ), "int16" ), True )
429+ PaddedInput_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_2_buffer_var .data , 7200000 ), dtype = "handle" )
430+ for i0_i1_fused , i2 , i3 in T .grid (75 , 75 , 64 ):
431+ T .store (PaddedInput_let , i0_i1_fused * 4800 + i2 * 64 + i3 , T .load ("int16" , placeholder_7 .data , i0_i1_fused * 4800 + i2 * 64 + i3 ), True )
432+ for ax0_ax1_fused_ax2_fused in T .serial (0 , 5625 ):
433+ Conv2dOutput_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_2_buffer_var .data , 7920000 ), dtype = "handle" )
434+ for ff in T .serial (0 , 64 ):
435+ T .store (Conv2dOutput_let , ff , 0 , True )
436+ for rc in T .serial (0 , 64 ):
437+ T .store (Conv2dOutput_let , ff , T .load ("int32" , Conv2dOutput_let , ff ) + T .cast (T .load ("int16" , PaddedInput_let , ax0_ax1_fused_ax2_fused * 64 + rc ), "int32" ) * T .cast (T .load ("int16" , placeholder_8 .data , rc * 64 + ff ), "int32" ), True )
438+ for ax3_inner_1 in T .serial (0 , 64 ):
439+ T .store (T_cast_3 .data , ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1 , T .cast (T .cast (T .max (T .min (T .q_multiply_shift (T .load ("int32" , Conv2dOutput_let , ax3_inner_1 ) + T .load ("int32" , placeholder_9 .data , ax3_inner_1 ), 1843106743 , 31 , - 6 , dtype = "int32" ), 255 ), 0 ), "uint8" ), "int16" ), True )
455440
456441 @T .prim_func
457442 def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1 (placeholder_10 : T .handle , placeholder_11 : T .handle , placeholder_12 : T .handle , T_cast_4 : T .handle , global_workspace_3_var : T .handle ) -> None :
458443 placeholder_13 = T .match_buffer (placeholder_10 , [1 , 75 , 75 , 64 ], dtype = "int16" )
459444 placeholder_14 = T .match_buffer (placeholder_11 , [3 , 3 , 64 , 64 ], dtype = "int16" )
460445 placeholder_15 = T .match_buffer (placeholder_12 , [1 , 1 , 1 , 64 ], dtype = "int32" )
461446 T_cast_5 = T .match_buffer (T_cast_4 , [1 , 75 , 75 , 64 ], dtype = "int16" )
462- global_workspace_3_buffer_var = T .match_buffer (global_workspace_3_var , [7200000 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
447+ global_workspace_3_buffer_var = T .match_buffer (global_workspace_3_var , [7920256 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
463448 # body
464449 PaddedInput_1_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_3_buffer_var .data , 0 ), dtype = "handle" )
465450 for i0_i1_fused_1 , i2_1 , i3_1 in T .grid (77 , 77 , 64 ):
466451 T .store (PaddedInput_1_let , i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1 , T .if_then_else (1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76 , T .load ("int16" , placeholder_13 .data , i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864 ), T .int16 (0 ), dtype = "int16" ), True )
467452 for ax0_ax1_fused_ax2_fused_1 in T .serial (0 , 5625 ):
468- Conv2dOutput_1_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_3_buffer_var .data , 1478912 ), dtype = "handle" )
453+ Conv2dOutput_1_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_3_buffer_var .data , 7200000 ), dtype = "handle" )
469454 for ff_1 in T .serial (0 , 64 ):
470455 T .store (Conv2dOutput_1_let , ff_1 , 0 , True )
471456 for ry , rx , rc_1 in T .grid (3 , 3 , 64 ):
472457 T .store (Conv2dOutput_1_let , ff_1 , T .load ("int32" , Conv2dOutput_1_let , ff_1 ) + T .cast (T .load ("int16" , PaddedInput_1_let , ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1 ), "int32" ) * T .cast (T .load ("int16" , placeholder_14 .data , ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1 ), "int32" ), True )
473458 for ax3_inner_2 in T .serial (0 , 64 ):
474459 T .store (T_cast_5 .data , ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2 , T .cast (T .cast (T .max (T .min (T .q_multiply_shift (T .load ("int32" , Conv2dOutput_1_let , ax3_inner_2 ) + T .load ("int32" , placeholder_15 .data , ax3_inner_2 ), 1608879842 , 31 , - 7 , dtype = "int32" ), 255 ), 0 ), "uint8" ), "int16" ), True )
460+
461+ @T .prim_func
462+ def run_model (input : T .handle , output : T .handle , global_workspace_0_var : T .handle ) -> None :
463+ global_workspace_0_buffer_var = T .match_buffer (global_workspace_0_var , [7920256 ], dtype = "uint8" , strides = [1 ], elem_offset = 1 , align = 16 )
464+ # body
465+ T .attr ("default" , "device_id" , 0 )
466+ T .attr ("default" , "device_type" , 1 )
467+ sid_2_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_0_buffer_var .data , 5760000 ), dtype = "handle" )
468+ sid_6_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_0_buffer_var .data , 0 ), dtype = "handle" )
469+ sid_7_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_0_buffer_var .data , 6480000 ), dtype = "handle" )
470+ sid_8_let : T .handle = T .address_of (T .load ("uint8" , global_workspace_0_buffer_var .data , 6480000 ), dtype = "handle" )
471+ T .evaluate (T .call_extern ("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast" , input , T .lookup_param ("p0" , dtype = "handle" ), sid_2_let , global_workspace_0_buffer_var .data , dtype = "int32" ))
472+ T .evaluate (T .call_extern ("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast" , sid_2_let , T .lookup_param ("p3" , dtype = "handle" ), T .lookup_param ("p4" , dtype = "handle" ), sid_8_let , global_workspace_0_buffer_var .data , dtype = "int32" ))
473+ T .evaluate (T .call_extern ("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1" , sid_8_let , T .lookup_param ("p5" , dtype = "handle" ), T .lookup_param ("p6" , dtype = "handle" ), sid_7_let , global_workspace_0_buffer_var .data , dtype = "int32" ))
474+ T .evaluate (T .call_extern ("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_" , sid_7_let , T .lookup_param ("p7" , dtype = "handle" ), T .lookup_param ("p8" , dtype = "handle" ), sid_6_let , global_workspace_0_buffer_var .data , dtype = "int32" ))
475+ T .evaluate (T .call_extern ("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_" , sid_2_let , T .lookup_param ("p1" , dtype = "handle" ), T .lookup_param ("p2" , dtype = "handle" ), sid_6_let , output , global_workspace_0_buffer_var .data , dtype = "int32" ))
476+ __tvm_meta__ = None
475477# fmt: on
476478
477479
478- def test_fanout ():
480+ def test_resnet_subgraph ():
479481 target = Target ("c" )
480482 global_workspace_pool = usmp_utils .PoolInfo (
481483 pool_name = "global_workspace" ,
@@ -498,6 +500,7 @@ def test_fanout():
498500 )(tir_mod )
499501
500502 tir_mod_with_offsets_ref = ResnetStructurePlanned
503+
501504 # The TIR produced fails on roundtrip TVMScript testing.
502505 # Therefore, indicates the TVMScript produced here and/or the parser
503506 # is lacking functionality. Thus for these tests, uses a string
0 commit comments