@@ -31,6 +31,7 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha
3131 B = T .match_sparse_buffer (b , (T .to_dense (J ), K ), n * k , "float32" )
3232 C = T .match_sparse_buffer (c , (I , K ), m * k , "float32" )
3333 with T .iter ([T .cord (I ), T .cord (J ), T .cord (K )], "SRS" , "csrmm" ) as [vi , vj , vk ]:
34+ T .block_attr ({"sparse" : True })
3435 with T .init ():
3536 C [vi , vk ] = 0.0
3637 C [vi , vk ] = C [vi , vk ] + A [vi , vj ] * B [vj , vk ]
@@ -51,6 +52,7 @@ def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
5152 C [vi * K + vk ] = 0.
5253 for j in T .serial (0 , A_indptr [vi + 1 ] - A_indptr [vi ]):
5354 with T .block ("spmm_inner" ):
55+ T .block_attr ({"sparse" : True })
5456 vj = T .axis .R (NNZ , j + A_indptr [vi ])
5557 C [vi * K + vk ] = C [vi * K + vk ] + \
5658 A_data [vj ] * B [A_indices [vj ] * K + vk ]
@@ -71,6 +73,7 @@ def bsrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
7173 C [(vio * BLOCK_SIZE + vii ) * K + vk ] = 0.
7274 for jo in T .serial (0 , A_indptr [vio + 1 ] - A_indptr [vio ]):
7375 with T .block ("spmm_inner" ):
76+ T .block_attr ({"sparse" : True })
7477 vjo = T .axis .R (NNZB , jo + A_indptr [vio ])
7578 C [(vio * BLOCK_SIZE + vii ) * K + vk ] = C [(vio * BLOCK_SIZE + vii ) * K + vk ] + A_data [(
7679 vjo * BLOCK_SIZE + vii ) * BLOCK_SIZE + vji ] * B [(A_indices [vjo ] * BLOCK_SIZE + vji ) * K + vk ]
@@ -85,6 +88,7 @@ def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int
8588 A_indices = T .match_buffer (indices , (M * NNZ_COLS ,), "int32" )
8689 for i , j , k in T .grid (M , NNZ_COLS , K ):
8790 with T .block ("spmm" ):
91+ T .block_attr ({"sparse" : True })
8892 vi , vj , vk = T .axis .remap ("SRS" , [i , j , k ])
8993 with T .init ():
9094 C [vi * K + vk ] = 0.
@@ -102,6 +106,7 @@ def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
102106 C_indices = T .match_buffer (indices , (NNZ ,), "int32" )
103107 for ij , k in T .grid (NNZ , K ):
104108 with T .block ("sddmm" ):
109+ T .block_attr ({"sparse" : True })
105110 vij , vk = T .axis .remap ("SR" , [ij , k ])
106111 T .reads ([A [0 : M * K ], B [0 : N * K ], C_data [vij ], C_indices [vij ], C_indptr [0 : M + 1 ]])
107112 T .writes ([C_data [vij ]])
@@ -262,10 +267,10 @@ def test_sddmm():
262267 )
263268 blk = sch .get_block ("sddmm" )
264269 ij , k = sch .get_loops (blk )
265- #sch.decompose_reduction(blk, ij)
270+ # TODO(zihao): decompose_reduction is disabled below; re-enable it once its behavior on this block is fixed.
271+ # sch.decompose_reduction(blk, ij)
266272 sch .bind (ij , "blockIdx.x" )
267- ko , ki = sch .split (k , [None , 1 ])
268- sch .bind (ki , "threadIdx.x" )
273+ sch .bind (k , "threadIdx.x" )
269274
270275 # convert numpy tensor to tvm ndarray
271276 C_indices = tvm .nd .array (indices .astype ("int32" ), device = tvm .cuda (0 ))
@@ -276,6 +281,7 @@ def test_sddmm():
276281
277282 # build function
278283 f = tvm .build (sch .mod ['main' ], target = "cuda" )
284+ # print(f.imported_modules[0].get_source())
279285 f (X_nd , Y_nd , C_data , C_indptr , C_indices )
280286
281287 # assertion
0 commit comments