Skip to content

Commit b0205c7

Browse files
yzh119 authored and MasterJH5574 committed
Add "sparse" block attribute. (apache#26)
1 parent c714529 commit b0205c7

4 files changed

Lines changed: 26 additions & 8 deletions

File tree

src/tir/schedule/analysis/analysis.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,17 @@ Definition of a scope that is a stage pipeline:
163163
!IsReductionBlock(self, block_sref, scope_root_sref)) {
164164
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
165165
// NOTE(Zihao): check if the block has atomic attribute.
166-
auto&& it = block->annotations.find("atomic");
166+
auto&& it_atomic = block->annotations.find("atomic");
167167
bool is_atomic = false;
168-
if (it != block->annotations.end()) {
169-
is_atomic = ((*it).second).as<IntImmNode>()->value;
168+
if (it_atomic != block->annotations.end()) {
169+
is_atomic = ((*it_atomic).second).as<IntImmNode>()->value;
170170
}
171-
if (!is_atomic) {
171+
auto&& it_sparse = block->annotations.find("sparse");
172+
bool is_sparse = false;
173+
if (it_sparse != block->annotations.end()) {
174+
is_sparse = ((*it_sparse).second).as<IntImmNode>()->value;
175+
}
176+
if (!is_sparse && !is_atomic) {
172177
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
173178
GetRef<Block>(block));
174179
}

src/tir/transforms/lower_sparse_tir.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,8 +308,10 @@ class IndexTransformer : public StmtExprMutator {
308308
GenerateReadWriteRegions(sp_block, &reads, &writes);
309309

310310
// Step 5. Create the block and block-realize
311+
Map<String, ObjectRef> mapping;
312+
mapping.Set("sparse", Bool(true));
311313
Block block(block_iters, std::move(reads), std::move(writes), sp_block->name, std::move(body),
312-
std::move(init));
314+
std::move(init), {}, {}, std::move(mapping));
313315
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));
314316

315317
// Step 6. Create outer loops and the block binding.

tests/python/sparsetir/test_tir_sparse_correctness.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha
3131
B = T.match_sparse_buffer(b, (T.to_dense(J), K), n * k, "float32")
3232
C = T.match_sparse_buffer(c, (I, K), m * k, "float32")
3333
with T.iter([T.cord(I), T.cord(J), T.cord(K)], "SRS", "csrmm") as [vi, vj, vk]:
34+
T.block_attr({"sparse": True})
3435
with T.init():
3536
C[vi, vk] = 0.0
3637
C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk]
@@ -51,6 +52,7 @@ def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
5152
C[vi * K + vk] = 0.
5253
for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]):
5354
with T.block("spmm_inner"):
55+
T.block_attr({"sparse": True})
5456
vj = T.axis.R(NNZ, j + A_indptr[vi])
5557
C[vi * K + vk] = C[vi * K + vk] + \
5658
A_data[vj] * B[A_indices[vj] * K + vk]
@@ -71,6 +73,7 @@ def bsrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
7173
C[(vio * BLOCK_SIZE + vii) * K + vk] = 0.
7274
for jo in T.serial(0, A_indptr[vio + 1] - A_indptr[vio]):
7375
with T.block("spmm_inner"):
76+
T.block_attr({"sparse": True})
7477
vjo = T.axis.R(NNZB, jo + A_indptr[vio])
7578
C[(vio * BLOCK_SIZE + vii) * K + vk] = C[(vio * BLOCK_SIZE + vii) * K + vk] + A_data[(
7679
vjo * BLOCK_SIZE + vii) * BLOCK_SIZE + vji] * B[(A_indices[vjo] * BLOCK_SIZE + vji) * K + vk]
@@ -85,6 +88,7 @@ def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int
8588
A_indices = T.match_buffer(indices, (M * NNZ_COLS,), "int32")
8689
for i, j, k in T.grid(M, NNZ_COLS, K):
8790
with T.block("spmm"):
91+
T.block_attr({"sparse": True})
8892
vi, vj, vk = T.axis.remap("SRS", [i, j, k])
8993
with T.init():
9094
C[vi * K + vk] = 0.
@@ -102,6 +106,7 @@ def sddmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices:
102106
C_indices = T.match_buffer(indices, (NNZ,), "int32")
103107
for ij, k in T.grid(NNZ, K):
104108
with T.block("sddmm"):
109+
T.block_attr({"sparse": True})
105110
vij, vk = T.axis.remap("SR", [ij, k])
106111
T.reads([A[0: M * K], B[0: N * K], C_data[vij], C_indices[vij], C_indptr[0: M + 1]])
107112
T.writes([C_data[vij]])
@@ -262,10 +267,10 @@ def test_sddmm():
262267
)
263268
blk = sch.get_block("sddmm")
264269
ij, k = sch.get_loops(blk)
265-
#sch.decompose_reduction(blk, ij)
270+
# TODO(zihao): fix the behavior in the future.
271+
# sch.decompose_reduction(blk, ij)
266272
sch.bind(ij, "blockIdx.x")
267-
ko, ki = sch.split(k, [None, 1])
268-
sch.bind(ki, "threadIdx.x")
273+
sch.bind(k, "threadIdx.x")
269274

270275
# convert numpy tensor to tvm ndarray
271276
C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
@@ -276,6 +281,7 @@ def test_sddmm():
276281

277282
# build function
278283
f = tvm.build(sch.mod['main'], target="cuda")
284+
# print(f.imported_modules[0].get_source())
279285
f(X_nd, Y_nd, C_data, C_indptr, C_indices)
280286

281287
# assertion

tests/python/sparsetir/test_tir_sparse_lower.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def lowered_csrmm(
6969
for v_vi in T.serial(0, n):
7070
for v_vj, v_vk in T.grid(J_indptr[v_vi + 1] - J_indptr[v_vi], k):
7171
with T.block("csrmm"):
72+
T.block_attr({"sparse": True})
7273
vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk])
7374
T.reads(
7475
[
@@ -125,6 +126,7 @@ def lowered_csr_reduce(
125126
for v_vi in T.serial(0, n):
126127
for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
127128
with T.block("csr_reduce"):
129+
T.block_attr({"sparse": True})
128130
vi, vj = T.axis.remap("SR", [v_vi, v_vj])
129131
T.reads([J_indptr[0 : n + 1], J_indices[0:nnz], A_data[0:nnz], B_data[0:n]])
130132
T.writes([B_data[0:n]])
@@ -190,6 +192,7 @@ def lowered_bsrmm(
190192
J_indptr[v_vi + 1] - J_indptr[v_vi], blk, blk, feat_size
191193
):
192194
with T.block("bsrmm"):
195+
T.block_attr({"sparse": True})
193196
vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf])
194197
T.reads(
195198
[
@@ -263,6 +266,7 @@ def lowered_ellpack_mm(
263266
J_indices = T.match_buffer(indices, [nnz], dtype="int32")
264267
for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size):
265268
with T.block("bsrmm"):
269+
T.block_attr({"sparse": True})
266270
vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf])
267271
T.reads(
268272
[
@@ -359,6 +363,7 @@ def lowered_csr_element_wise(
359363
for v_vi in T.serial(0, m):
360364
for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
361365
with T.block("csr_element_wise"):
366+
T.block_attr({"sparse": True})
362367
vi, vj = T.axis.remap("SS", [v_vi, v_vj])
363368
T.reads([J_indptr[0 : m + 1], J_indices[0:nnz], A_data[0:nnz]])
364369
T.writes([B_data[0:nnz]])

0 commit comments

Comments
 (0)