Skip to content

Commit deb4d66

Browse files
committed
update lower warp memory
1 parent 71fe5fe commit deb4d66

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

src/tir/transforms/lower_warp_memory.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,17 @@ class WarpAccessRewriter : protected StmtExprMutator {
281281
return GetRef<PrimExpr>(op);
282282
}
283283

284+
if (op->op.same_as(builtin::mma_store())) {
285+
// Array<PrimExpr> new_args = op->args;
286+
// PrimExpr local_index, group;
287+
// if (op->args[3].get() == buffer_) {
288+
// std::tie(local_index, group) = SplitIndexByGroup(op->args[4]);
289+
// new_args.Set(4, local_index);
290+
// return Call(op->dtype, op->op, new_args);
291+
// }
292+
return GetRef<PrimExpr>(op);
293+
}
294+
284295
return StmtExprMutator::VisitExpr_(op);
285296
}
286297

@@ -466,11 +477,13 @@ namespace transform {
466477
Pass LowerWarpMemory() {
467478
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
468479
auto* n = f.CopyOnWrite();
480+
// LOG(INFO) << f;
469481
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
470-
int warp_size = target.value()->GetAttr<Integer>("thread_warp_size", 1).value();
482+
int warp_size = 32;
471483
WarpMemoryRewriter warp_memory_rewriter(warp_size);
472484
auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
473485
n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
486+
LOG(INFO) << f;
474487
return f;
475488
};
476489
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});

tests/python/unittest/test_mma_16x8x8_4k_tune.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None:
184184

185185
@T.prim_func
186186
def mma_store_impl(a: T.handle, c: T.handle) -> None:
187-
C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp")
188-
C = T.match_buffer(c, [16, 8], dtype="float32", scope="global")
187+
C_warp = T.match_buffer(a, [32, 4], dtype="float32", scope="warp", offset_factor=1)
188+
C = T.match_buffer(c, [16, 8], dtype="float32", scope="global",offset_factor=1)
189189

190190
with T.block("root"):
191191
T.reads(C_warp[0:32, 0:4])
@@ -351,9 +351,8 @@ def lambda_b(i, j):
351351
)
352352

353353
if use_ldmatrix:
354-
# sch.tensorize(loop_a, "mma.ldmatrix_a")
355-
# sch.tensorize(loop_b, "mma.ldmatrix_b")
356-
pass
354+
sch.tensorize(loop_a, "mma.ldmatrix_a")
355+
sch.tensorize(loop_b, "mma.ldmatrix_b")
357356
else:
358357
warp_loop1, warp_loop2 = sch.get_loops(A_warp)[-2:]
359358
f_0, f_1 = sch.split(warp_loop1, factors=[None, 8])

0 commit comments

Comments
 (0)