@@ -297,6 +297,45 @@ def test_parallel_alloc():
297297
298298 assert isinstance (body .body .body .body .body , tvm .tir .Allocate )
299299
300+ ib = tvm .tir .ir_builder .create ()
301+ n = te .var ("n" )
302+ with ib .for_range (0 , n , name = "i" , kind = "parallel" ) as i :
303+ j = ib .allocate ("int32" , 1 , name = "j" , scope = "global" )
304+ j [0 ] = 0
305+ with ib .while_loop (j [0 ] < 10 ):
306+ A = ib .allocate ("float32" , n , name = "A" , scope = "global" )
307+ A [j [0 ]] = A [j [0 ]] + 2
308+ j [0 ] += j [0 ] + 1
309+
310+ body = ib .get ()
311+ # parallel (i, 0, n) {
312+ # // attr [j] storage_scope = "global"
313+ # allocate j[int32 * 1]
314+ # j[0] = 0
315+ # while((j[0] < 10)){
316+ # // attr [A] storage_scope = "global"
317+ # allocate A[float32 * n]
318+ # A[j[0]] = (A[j[0]] + 2f)
319+ # j[0] = (j[0] + (j[0] + 1))
320+ # }
321+ # }
322+
323+ mod = tvm .IRModule .from_expr (tvm .tir .PrimFunc ([n ], body ))
324+ body = tvm .tir .transform .StorageRewrite ()(mod )["main" ].body
325+
326+ # parallel (i, 0, n) {
327+ # // attr [j] storage_scope = "global"
328+ # allocate j[int32 * 1]
329+ # // attr [A] storage_scope = "global"
330+ # allocate A[float32 * n]
331+ # j[0] = 0
332+ # while((j[0] < 10)){
333+ # A[j[0]] = (A[j[0]] + 2f)
334+ # j[0] = (j[0] + (j[0] + 1))
335+ # }
336+ # }
337+ assert isinstance (body .body .body , tvm .tir .Allocate )
338+
300339
301340def test_inplace_rule2 (scope_tb = "local_TB2" , max_bits = 1024 * 1024 * 1024 ):
302341 # Test Buffer
0 commit comments