Skip to content

Commit 966d018

Browse files
authored
[PTX] ldmatrix builtin to accelerate copying data from shared memory to warp memory (#10855)
We already have PTX `mma` and `mma.sp` builtin support from #9909 and #10339. However, we do not yet support the corresponding data-movement builtins for these mma instructions, so data movement is not as fast as with wmma. This PR adds the `ldmatrix` builtin, a native PTX warp-level instruction (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix), which loads several (1/2/4) 8x8 matrices from shared memory to warp memory.
1 parent afe6793 commit 966d018

File tree

6 files changed

+263
-40
lines changed

6 files changed

+263
-40
lines changed

include/tvm/tir/builtin.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,15 @@ TVM_DLL const Op& ptx_mma();
623623
*/
624624
TVM_DLL const Op& ptx_mma_sp();
625625

626+
/*!
627+
* \brief tvm intrinsic for ptx load matrix from shared memory.
628+
*
629+
* void ptx_ldmatrix(Bool trans, IntImm num, StringImm type,
630+
* Var local_ptr, Expr local_offset,
631+
* Var smem_ptr, Expr smem_offset);
632+
*/
633+
TVM_DLL const Op& ptx_ldmatrix();
634+
626635
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
627636
/*!
628637
* \brief Get the high level half of the vector

src/target/source/codegen_cuda.cc

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include <vector>
3434

3535
#include "literal/cuda_half_t.h"
36-
#include "ptx_mma.h"
36+
#include "ptx.h"
3737

3838
namespace tvm {
3939
namespace codegen {
@@ -772,11 +772,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
772772
// arg 3: A precision: fp16, fp32, ...
773773
// arg 4: B precision: fp16, fp32, ...
774774
// arg 5: C precision: fp16, fp32, ...
775-
// arg 6: A multiplicand
775+
// arg 6: A multiplicand pointer
776776
// arg 7: A multiplicand index
777-
// arg 8: B multiplicand
777+
// arg 8: B multiplicand pointer
778778
// arg 9: B multiplicand index
779-
// arg 10: C accumulator
779+
// arg 10: C accumulator pointer
780780
// arg 11: C accumulator index
781781
// arg 12: metadata
782782
// arg 13: metadata index
@@ -803,6 +803,24 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
803803
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
804804
c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
805805
this->stream << asm_code;
806+
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
807+
// arg 0: whether the matrix is loaded in column major format or not.
808+
// arg 1: number of matrices to load.
809+
// arg 2: The data type in the matrix, .b16 is the only accepted data type.
810+
// arg 3: pointer to local buffer.
811+
// arg 4: The offset of the element to store in the local buffer.
812+
// arg 5: pointer to the shared memory buffer to load.
813+
// arg 6: The offset of the start element of the row to load in shared memory.
814+
ICHECK_EQ(op->args.size(), 7U);
815+
bool trans = Downcast<Bool>(op->args[0])->value;
816+
int num = Downcast<Integer>(op->args[1])->value;
817+
std::string type = Downcast<StringImm>(op->args[2])->value;
818+
std::string local_ptr = this->PrintExpr(op->args[3]);
819+
std::string local_elem_offset = this->PrintExpr(op->args[4]);
820+
std::string smem_ptr = this->PrintExpr(op->args[5]);
821+
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
822+
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
823+
smem_ptr, smem_elem_offset);
806824
} else {
807825
CodeGenC::VisitExpr_(op, os);
808826
}
Lines changed: 101 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
*/
1919

2020
/*!
21-
* \file ptx_mma.cc
21+
* \file ptx.cc
2222
*/
2323

24-
#include "ptx_mma.h"
24+
#include "ptx.h"
2525

2626
#include <algorithm>
2727
#include <string>
@@ -60,13 +60,18 @@ enum class DataType : int {
6060
kFloat32 = 13,
6161
kTensorFloat32 = 14,
6262
kFloat64 = 15,
63-
kBit1 = 16
63+
kBit1 = 16,
64+
kBit8 = 17,
65+
kBit16 = 18,
66+
kBit32 = 19,
67+
kBit64 = 20,
6468
};
6569

66-
static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16",
67-
".s32", ".u32", ".s64", ".u64", ".f16", ".bf16",
68-
".f16x2", ".f32", ".tf32", ".f64", ".b1"};
69-
static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1};
70+
static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32",
71+
".u32", ".s64", ".u64", ".f16", ".bf16", ".f16x2", ".f32",
72+
".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"};
73+
static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16,
74+
16, 32, 32, 32, 64, 1, 8, 16, 32, 64};
7075

7176
/*!
7277
* \brief Create PTX data type from string.
@@ -106,6 +111,14 @@ inline DataType DTypeFromString(const std::string str) {
106111
return DataType::kFloat64;
107112
} else if (str == "int1" || str == ".b1") {
108113
return DataType::kBit1;
114+
} else if (str == ".b8") {
115+
return DataType::kBit8;
116+
} else if (str == ".b16") {
117+
return DataType::kBit16;
118+
} else if (str == ".b32") {
119+
return DataType::kBit32;
120+
} else if (str == ".b64") {
121+
return DataType::kBit64;
109122
} else {
110123
LOG(FATAL) << "Unrecognized PTX data type " << str;
111124
return DataType(0);
@@ -360,6 +373,7 @@ inline FragAttrs GetFragAttrs(DataType dtype) {
360373
case DataType::kUInt4:
361374
case DataType::kInt8:
362375
case DataType::kUInt8:
376+
case DataType::kBit16:
363377
case DataType::kFloat16: // .f16x2 register
364378
case DataType::kBFloat16:
365379
case DataType::kTensorFloat32:
@@ -508,9 +522,9 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
508522
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
509523
const std::string& B_layout, const std::string& A_dtype,
510524
const std::string& B_dtype, const std::string& C_dtype,
511-
const std::string& a_ref, const std::string& a_offset,
512-
const std::string& b_ref, const std::string& b_offset,
513-
const std::string& c_ref, const std::string& c_offset,
525+
const std::string& a_ptr, const std::string& a_elem_offset,
526+
const std::string& b_ptr, const std::string& b_elem_offset,
527+
const std::string& c_ptr, const std::string& c_elem_offset,
514528
const std::string& metadata, const std::string& metadata_offset,
515529
const std::string& sparsity_selector, const std::string& bit_op,
516530
bool sparse, bool saturate) {
@@ -525,7 +539,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
525539
std::string asm_code = R"(
526540
{
527541
__asm__ __volatile__(
528-
"mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}"
542+
"mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}"
529543
"{templates};\n"
530544
: {outputs}
531545
: {inputs});
@@ -537,30 +551,92 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
537551

538552
// replace patterns
539553
Replacer replacer;
540-
replacer.register_rule("{sparse}", sparse ? ".sp" : "");
541-
replacer.register_rule("{shape}", shape);
542-
replacer.register_rule("{saturate}", saturate ? ".satfinite" : "");
543-
replacer.register_rule("{alayout}", A_layout);
544-
replacer.register_rule("{blayout}", B_layout);
545-
replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a));
546-
replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b));
547-
replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c));
548-
replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c));
549-
replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
554+
replacer.register_rule("{.sparse}", sparse ? ".sp" : "");
555+
replacer.register_rule("{.shape}", "." + shape);
556+
replacer.register_rule("{.saturate}", saturate ? ".satfinite" : "");
557+
replacer.register_rule("{.alayout}", "." + A_layout);
558+
replacer.register_rule("{.blayout}", "." + B_layout);
559+
replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a));
560+
replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b));
561+
replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c));
562+
replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c));
563+
replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
550564
replacer.register_rule("{templates}", templates_str);
551565
replacer.register_rule("{outputs}", outputs_str);
552566
replacer.register_rule("{inputs}", inputs_str);
553567
asm_code = replacer.rewrite(asm_code);
554568
replacer.empty_rules();
555-
replacer.register_rule("A", a_ref + " + " + a_offset);
556-
replacer.register_rule("B", b_ref + " + " + b_offset);
557-
replacer.register_rule("C", c_ref + " + " + c_offset);
558-
replacer.register_rule("D", c_ref + " + " + c_offset);
569+
replacer.register_rule("A", a_ptr + " + " + a_elem_offset);
570+
replacer.register_rule("B", b_ptr + " + " + b_elem_offset);
571+
replacer.register_rule("C", c_ptr + " + " + c_elem_offset);
572+
replacer.register_rule("D", c_ptr + " + " + c_elem_offset);
559573
replacer.register_rule("E", metadata + " + " + metadata_offset);
560574
replacer.register_rule("F", sparsity_selector);
561575
asm_code = replacer.rewrite(asm_code);
562576
return asm_code;
563577
}
564578

579+
/*!
 * \brief Build the operand strings used by the ldmatrix inline assembly.
 * \param num Number of destination operands to emit.
 * \param local_ptr Expression text of the destination pointer.
 * \param local_elem_offset Expression text of the offset added to local_ptr.
 * \return A tuple (templates, outputs):
 *   - templates: "{%0, ..., %<num-1>}, [%<num>]" -- the placeholder list for the
 *     destination operands followed by the bracketed address placeholder.
 *   - outputs: a comma-separated list of "=r" output constraints, the i-th one
 *     being "((unsigned *)(local_ptr + local_elem_offset))[i]".
 */
inline std::tuple<std::string, std::string> GetLoadMatrixOperands(
    int num, const std::string& local_ptr, const std::string& local_elem_offset) {
  // Placeholder list: %0 is always present, the rest follow up to %<num-1>.
  std::string operand_template = "{%0";
  int next_operand = 1;
  for (; next_operand < num; ++next_operand) {
    operand_template += ", %" + std::to_string(next_operand);
  }
  // The placeholder right after the destination operands is the address operand.
  operand_template += "}, [%" + std::to_string(next_operand) + "]";

  // Every output indexes the same base expression, reinterpreted as unsigned*.
  const std::string dst_base = "((unsigned *)(" + local_ptr + " + " + local_elem_offset + "))[";
  std::string output_constraints;
  for (int reg = 0; reg < num; ++reg) {
    if (reg > 0) {
      output_constraints += ", ";
    }
    output_constraints += "\"=r\"(" + dst_base + std::to_string(reg) + "])";
  }
  return std::make_tuple(operand_template, output_constraints);
}
600+
601+
// Builds the inline-assembly source text for a PTX `ldmatrix` load.
//
// Validation (via CHECK): `num` must be 1, 2 or 4, and `type` must be a string
// that DTypeFromString maps to DataType::kBit16.
//
// The returned text holds two asm statements inside one brace scope:
//   1. a conversion of `smem_ptr + smem_elem_offset` into the 32-bit variable
//      `addr`, using cvta.to.shared.u64 followed by cvt.u32.u64;
//   2. the `ldmatrix.sync.aligned` instruction with the suffixes
//      .m8n8, .x<num>, optional .trans, .shared and the type string, taking
//      `addr` as its only input operand.
// The operand template and the output constraint list of the second statement
// come from GetLoadMatrixOperands(num, local_ptr, local_elem_offset).
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
602+
const std::string& local_ptr,
603+
const std::string& local_elem_offset,
604+
const std::string& smem_ptr,
605+
const std::string& smem_elem_offset) {
606+
// Reject unsupported matrix counts and element types before emitting anything.
CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accept loading 1/2/4 matrices.";
607+
ptx::DataType data_type = ptx::DTypeFromString(type);
608+
CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
609+
std::string asm_code = R"(
610+
{
611+
unsigned int addr;
612+
__asm__ __volatile__(
613+
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
614+
: "=r"(addr)
615+
: "l"((void *)({smem_addr}))
616+
);
617+
__asm__ __volatile__(
618+
"ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
619+
"{templates};\n"
620+
: {outputs}
621+
: "r"(addr)
622+
);
623+
}
624+
)";
625+
std::string templates_str, outputs_str;
626+
std::tie(templates_str, outputs_str) = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);
627+
628+
// Substitute every {placeholder} in the template above with its concrete text.
Replacer replacer;
629+
replacer.register_rule("{.shape}", ".m8n8");
630+
replacer.register_rule("{.num}", ".x" + std::to_string(num));
631+
replacer.register_rule("{.trans}", trans ? ".trans" : "");
632+
replacer.register_rule("{.ss}", ".shared");
633+
replacer.register_rule("{.type}", ptx::DTypeToString(data_type));
634+
replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset);
635+
replacer.register_rule("{templates}", templates_str);
636+
replacer.register_rule("{outputs}", outputs_str);
637+
asm_code = replacer.rewrite(asm_code);
638+
return asm_code;
639+
}
640+
565641
} // namespace codegen
566642
} // namespace tvm
Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
*/
1919

2020
/*!
21-
* \file ptx_mma.h
22-
* \brief MMA code generation with inlined PTX code.
21+
* \file ptx.h
22+
* \brief Code generation with inlined PTX code.
2323
*/
24-
#ifndef TVM_TARGET_SOURCE_PTX_MMA_H_
25-
#define TVM_TARGET_SOURCE_PTX_MMA_H_
24+
#ifndef TVM_TARGET_SOURCE_PTX_H_
25+
#define TVM_TARGET_SOURCE_PTX_H_
2626

2727
#include <tvm/runtime/logging.h>
2828

@@ -40,11 +40,11 @@ namespace codegen {
4040
* \param A_dtype The data type of multiplicand A.
4141
* \param B_dtype The data type of multiplicand B.
4242
* \param C_dtype The data type of multiplicand C.
43-
* \param a_ref Pointer to buffer A.
43+
* \param a_ptr Pointer to buffer A.
4444
* \param a_offset The offset of element in A.
45-
* \param b_ref Pointer to buffer B.
45+
* \param b_ptr Pointer to buffer B.
4646
* \param b_offset The offset of element in B.
47-
* \param c_ref Pointer to buffer C.
47+
* \param c_ptr Pointer to buffer C.
4848
* \param c_offset The offset of element in C.
4949
* \param metadata Pointer to metadata buffer (only used for sparse mma).
5050
* \param metadata_offset The offset of element in metadata.
@@ -56,14 +56,30 @@ namespace codegen {
5656
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
5757
const std::string& B_layout, const std::string& A_dtype,
5858
const std::string& B_dtype, const std::string& C_dtype,
59-
const std::string& a_ref, const std::string& a_offset,
60-
const std::string& b_ref, const std::string& b_offset,
61-
const std::string& c_ref, const std::string& c_offset,
59+
const std::string& a_ptr, const std::string& a_offset,
60+
const std::string& b_ptr, const std::string& b_offset,
61+
const std::string& c_ptr, const std::string& c_offset,
6262
const std::string& metadata, const std::string& metadata_offset,
6363
const std::string& sparsity_selector, const std::string& bit_op,
6464
bool sparse, bool saturate);
6565

66+
/*!
67+
* \brief Print ldmatrix assembly string given parameters.
68+
* \param trans: whether the matrix is loaded in column major format or not.
69+
* \param num: number of matrices to load.
70+
* \param type: The data type in the matrix, .b16 is the only accepted data type.
71+
* \param local_ptr: pointer to local buffer.
72+
* \param local_elem_offset: The offset of the element to store in the local buffer.
73+
* \param smem_ptr: pointer to the shared memory buffer to load.
74+
* \param smem_elem_offset: The offset of the start element of the row to load in shared memory.
75+
*/
76+
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
77+
const std::string& local_ptr,
78+
const std::string& local_elem_offset,
79+
const std::string& smem_ptr,
80+
const std::string& smem_elem_offset);
81+
6682
} // namespace codegen
6783
} // namespace tvm
6884

69-
#endif // TVM_TARGET_SOURCE_PTX_MMA_H_
85+
#endif // TVM_TARGET_SOURCE_PTX_H_

src/tir/op/builtin.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
244244
TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
245245
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
246246

247+
// Defines the ptx_ldmatrix builtin through TIR_DEFINE_BUILTIN_FUNC and sets its
// "TCallEffectKind" attribute to CallEffectKind::kOpaque, the same value used
// for ptx_mma_sp just above.
TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
248+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
249+
247250
TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
248251
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
249252

0 commit comments

Comments
 (0)