@@ -250,6 +250,50 @@ def transformed_strided_buffer_func(
250250 C [i0 * 4 + i1 , j ] = B [i1 , j ] * T .float32 (2 )
251251
252252
253+ @T .prim_func
254+ def compacted_symbolic_strided_buffer_func (a : T .handle ) -> None :
255+ n = T .int32 ()
256+ A = T .match_buffer (a , (1 , n , 10240 ))
257+ padded_size = T .meta_var (T .min ((n + 63 ) // 64 * 64 , 96 ))
258+ # with T.block("root"):
259+ for i , j , k in T .grid (((n + 63 ) // 64 * 4 + 7 ) // 8 , 2 , 160 ):
260+ with T .block ("" ):
261+ A_pad_shared_dyn = T .alloc_buffer (
262+ (1 , padded_size , 64 ), strides = (72 * padded_size , 72 , 1 ), scope = "shared.dyn"
263+ )
264+ for ax0 , ax1 in T .grid (96 , 64 ):
265+ with T .block ("A_pad_shared.dyn" ):
266+ T .where (i * 128 + j * 32 + ax0 < (n + 63 ) // 64 * 64 )
267+ A_pad_shared_dyn [0 , ax0 , ax1 ] = T .if_then_else (
268+ i * 128 + j * 32 + ax0 < n ,
269+ A [0 , i * 128 + j * 32 + ax0 , k * 64 + ax1 ],
270+ T .float32 (0 ),
271+ )
272+
273+
274+ @T .prim_func
275+ def transformed_symbolic_strided_buffer_func (a : T .handle ):
276+ n = T .int32 ()
277+ A = T .match_buffer (a , (1 , n , 10240 ))
278+ for i , j , k in T .grid (((n + 63 ) // 64 * 4 + 7 ) // 8 , 2 , 160 ):
279+ A_pad_shared_dyn = T .allocate (
280+ [1 , T .min ((n + 63 ) // 64 * 64 , 96 ), 72 ], "float32" , "shared.dyn"
281+ )
282+ A_pad_shared_dyn_1 = T .decl_buffer (
283+ (1 , T .min ((n + 63 ) // 64 * 64 , 96 ), 64 ),
284+ data = A_pad_shared_dyn ,
285+ strides = (72 * T .min ((n + 63 ) // 64 * 64 , 96 ), 72 , 1 ),
286+ scope = "shared.dyn" ,
287+ )
288+ for ax0 , ax1 in T .grid (96 , 64 ):
289+ if i * 128 + j * 32 + ax0 < (n + 63 ) // 64 * 64 :
290+ A_pad_shared_dyn_1 [0 , ax0 , ax1 ] = T .if_then_else (
291+ i * 128 + j * 32 + ax0 < n ,
292+ A [0 , i * 128 + j * 32 + ax0 , k * 64 + ax1 ],
293+ T .float32 (0 ),
294+ )
295+
296+
253297@T .prim_func
254298def annotated_loops (a : T .handle ) -> None :
255299 A = T .match_buffer (a , (16 ,), "float32" )
@@ -301,6 +345,10 @@ def test_strided_buffer():
301345 _check (compacted_strided_buffer_func , transformed_strided_buffer_func )
302346
303347
348+ def test_symbolic_strided_buffer ():
349+ _check (compacted_symbolic_strided_buffer_func , transformed_symbolic_strided_buffer_func )
350+
351+
304352def test_lower_te ():
305353 x = te .placeholder ((1 ,))
306354 y = te .compute ((1 ,), lambda i : x [i ] + 2 )
0 commit comments