Commit 324a960

Hzfengsy authored and Laurawly committed
TensorCore Support using Intrinsic (#4136)
* add tensor core support
* avoid memory bank conflict
* fix thread sync & better performance
* better performance
* add schedule test for conv2d
* extend into BatchMatMul
* support config fragment shape and layout using intrinsic
* add TensorCore tutorial
* add int support and fix lint
* address comment
* add 32*16*8 TensorCore test
* fix wmma include logic
1 parent: 4ab7363

15 files changed: 1274 additions & 7 deletions

include/tvm/ir.h

Lines changed: 58 additions & 0 deletions
@@ -1310,6 +1310,16 @@ constexpr const char* opengl_stage_scope = "opengl_stage_scope";
  */
 constexpr const char* device_scope = "device_scope";
 
+/*!
+ * \brief Mark the shape of a TensorCore fragment.
+ */
+constexpr const char* fragment_shape = "fragment_shape";
+
+/*!
+ * \brief Mark the layout of a TensorCore fragment.
+ */
+constexpr const char* fragment_layout = "fragment_layout";
+
 /*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
@@ -1552,6 +1562,54 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
 * }
 */
 constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
+/*!
+ * \brief tvm intrinsic for tensor core load operators.
+ *
+ *  void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ *                            Expr index, Expr buffer_ptr, Expr stride,
+ *                            StringImm layout) {
+ *    // m, n, k are the shape of the wmma fragment.
+ *    // The layout argument determines the fragment layout (column-major or row-major).
+ *    // Fragments must be in the 'wmma.matrix_a' or 'wmma.matrix_b' scope.
+ *    nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride);
+ *  }
+ */
+constexpr const char* tvm_load_matrix_sync = "tvm_load_matrix_sync";
+/*!
+ * \brief tvm intrinsic for tensor core mma_sync operators.
+ *
+ *  void tvm_mma_sync(Var fragment_d, Expr index_d,
+ *                    Var fragment_a, Expr index_a,
+ *                    Var fragment_b, Expr index_b,
+ *                    Var fragment_c, Expr index_c) {
+ *    nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a],
+ *                           fragment_b[index_b], fragment_c[index_c]);
+ *  }
+ */
+constexpr const char* tvm_mma_sync = "tvm_mma_sync";
+/*!
+ * \brief tvm intrinsic for tensor core fill_fragment operators.
+ *
+ *  void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ *                         Expr index, Expr value) {
+ *    // m, n, k are the shape of the wmma fragment.
+ *    // Fragments must be in the 'wmma.accumulator' scope.
+ *    nvcuda::wmma::fill_fragment(fragment[index], value);
+ *  }
+ */
+constexpr const char* tvm_fill_fragment = "tvm_fill_fragment";
+/*!
+ * \brief tvm intrinsic for tensor core store operators.
+ *
+ *  void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm n, UIntImm k,
+ *                             Expr index, Expr buffer_ptr, Expr stride,
+ *                             StringImm layout) {
+ *    // m, n, k are the shape of the wmma fragment.
+ *    // Fragments must be in the 'wmma.accumulator' scope.
+ *    nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout);
+ *  }
+ */
+constexpr const char* tvm_store_matrix_sync = "tvm_store_matrix_sync";
 
 }  // namespace intrinsic

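For readers new to the wmma API that these four intrinsics wrap, here is a minimal hand-written CUDA sketch (not part of this commit; the 16x16x16 shape, layouts, and buffer names are illustrative only, and the leading dimension 16 assumes densely packed tiles). Each wmma call is annotated with the TVM intrinsic that lowers to it:

#include <mma.h>
using namespace nvcuda;

// One warp computes a single 16x16x16 tile: acc = A * B.
__global__ void wmma_tile_16x16x16(const half* a, const half* b, float* d) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> acc_frag;

  wmma::fill_fragment(acc_frag, 0.0f);                 // tvm_fill_fragment
  wmma::load_matrix_sync(a_frag, a, 16);               // tvm_load_matrix_sync
  wmma::load_matrix_sync(b_frag, b, 16);               // tvm_load_matrix_sync
  wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);  // tvm_mma_sync
  wmma::store_matrix_sync(d, acc_frag, 16,
                          wmma::mem_row_major);        // tvm_store_matrix_sync
}

Note that the intrinsics address fragments through an index (fragment[index]), since a single TVM allocation may hold an array of fragments; the codegen changes below turn that index into a subscript on the emitted fragment array.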
include/tvm/ir_pass.h

Lines changed: 17 additions & 0 deletions
@@ -513,6 +513,15 @@ LoweredFunc CombineContextCall(LoweredFunc f);
  */
 LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
 
+/*!
+ * \brief Lower attached storage access information on the device.
+ *  Run this pass after all storage access analysis has finished.
+ *
+ * \param func The device function to be lowered.
+ * \return Transformed function.
+ */
+LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);
+
 /*!
  * \brief Lower intrinsic function calls.
  * \param f The device function to be lowered.
@@ -532,6 +541,14 @@ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
  */
 LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
 
+/*!
+ * \brief Infer the TensorCore fragment information using tensor intrinsics.
+ *
+ * \param f The device function to be lowered.
+ * \return Transformed function.
+ */
+LoweredFunc InferFragment(LoweredFunc f);
+
 /*!
  * \brief Verify if memory accesses are legal for a specific target device type.
  *

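To place the two new passes, the following hypothetical wrapper (not code from this commit) sketches the device-side ordering that the build_module changes below implement; the declarations and the "after storage access analysis" constraint come from this header, while the wrapper function itself and warp_size plumbing are assumptions:

#include <string>
#include <tvm/ir_pass.h>

namespace tvm {
// Hedged sketch of the device lowering order in _build_for_device.
LoweredFunc LowerForDeviceSketch(LoweredFunc f, const std::string& target,
                                 int warp_size) {
  f = ir::ThreadSync(f, "shared");             // existing pass
  f = ir::InferFragment(f);                    // new: attach fragment_shape /
                                               // fragment_layout attributes
  f = ir::LowerThreadAllreduce(f, warp_size);  // existing pass
  f = ir::LowerDeviceStorageAccessInfo(f);     // new placement: only after
                                               // storage access analysis
  f = ir::LowerIntrin(f, target);              // existing pass
  return f;
}
}  // namespace tvm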
python/tvm/build_module.py

Lines changed: 3 additions & 1 deletion
@@ -413,7 +413,6 @@ def lower(sch,
 
     # Phase 3
     stmt = ir_pass.Simplify(stmt)
-    stmt = ir_pass.LowerStorageAccessInfo(stmt)
     stmt = ir_pass.RemoveNoOp(stmt)
     if not cfg.disable_select_rewriting:
         stmt = ir_pass.RewriteUnsafeSelect(stmt)
@@ -465,6 +464,7 @@ def _build_for_device(flist, target, target_host):
     func = ir_pass.ThreadSync(func, "global")
     func = ir_pass.ThreadSync(func, "shared")
     func = ir_pass.ThreadSync(func, "warp")
+    func = ir_pass.InferFragment(func)
     warp_size = target.thread_warp_size
     func = ir_pass.LowerThreadAllreduce(func, warp_size)
     fsplits = [s for s in ir_pass.SplitHostDevice(func)]
@@ -494,6 +494,8 @@ def _build_for_device(flist, target, target_host):
     assert not fdevice
 
     target_host = _target.create(target_host)
+    fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
+    fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
     fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
     fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
     fhost = [ir_pass.CombineContextCall(x) for x in fhost]

src/api/api_pass.cc

Lines changed: 10 additions & 0 deletions
@@ -118,6 +118,14 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
   });
 });
 
+TVM_REGISTER_API("ir_pass.LowerStorageAccess")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  LoweredFunc f = args[0];
+  auto n = make_node<LoweredFuncNode>(*f.operator->());
+  n->body = LowerStorageAccessInfo(f->body);
+  *ret = LoweredFunc(n);
+});
+
 // make from two arguments
 #define REGISTER_PASS(PassName)            \
   TVM_REGISTER_API("ir_pass."#PassName)    \
@@ -140,6 +148,7 @@ REGISTER_PASS(SplitHostDevice);
 REGISTER_PASS(StorageRewrite);
 REGISTER_PASS(CoProcSync);
 REGISTER_PASS(LowerStorageAccessInfo);
+REGISTER_PASS(LowerDeviceStorageAccessInfo);
 REGISTER_PASS(InjectVirtualThread);
 REGISTER_PASS(InjectPrefetch);
 REGISTER_PASS(InjectDoubleBuffer);
@@ -161,5 +170,6 @@ REGISTER_PASS(DecorateDeviceScope);
 REGISTER_PASS(InstrumentBoundCheckers);
 REGISTER_PASS(VerifyCompactBuffer);
 REGISTER_PASS(HoistIfThenElse);
+REGISTER_PASS(InferFragment);
 }  // namespace ir
 }  // namespace tvm

src/codegen/build_module.cc

Lines changed: 2 additions & 1 deletion
@@ -422,7 +422,6 @@ Stmt BuildStmt(Schedule sch,
 
   // Phase 2
   stmt = ir::Simplify(stmt);
-  stmt = ir::LowerStorageAccessInfo(stmt);
   stmt = ir::RemoveNoOp(stmt);
 
   if (!(config->disable_select_rewriting))
@@ -517,13 +516,15 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
   for (size_t i = 0; i < fhost.size(); ++i) {
     auto func = fhost[i];
     func = ir::BindDeviceType(func, target->device_type);
+    func = ir::LowerDeviceStorageAccessInfo(func);
     func = ir::LowerTVMBuiltin(func);
     fhost.Set(i, func);
   }
 
   for (size_t i = 0; i < fhost.size(); ++i) {
     auto func = fhost[i];
     func = ir::LowerIntrin(func, target_host->target_name);
+    func = ir::LowerDeviceStorageAccessInfo(func);
     func = ir::CombineContextCall(func);
     fhost.Set(i, func);
   }

src/codegen/codegen_cuda.cc

Lines changed: 165 additions & 2 deletions
@@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <math_constants.h>\n";
   }
 
+  if (need_mma_h_) {
+    decl_stream << "#include <mma.h>\n";
+  }
+
   return CodeGenC::Finish();
 }
 
@@ -102,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
   bool fail = false;
   if (t.is_float()) {
     switch (t.bits()) {
-      case 16: os << "half";
+      case 16:
         enable_fp16_ = true;
+        if (lanes == 1) {
+          os << "half";
+        } else if (lanes <= 8) {
+          CHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
+          os << "float" << lanes / 2;
+        } else {
+          fail = true;
+        }
         break;
       case 32: os << "float"; break;
       case 64: os << "double"; break;
       default: fail = true; break;
     }
-    if (!fail && lanes == 1) return;
+    if (!fail && (lanes == 1 || t.bits() == 16)) return;
     if (!fail && (lanes >= 2 && lanes <= 4)) {
       os << lanes; return;
     }
@@ -290,6 +302,113 @@ void CodeGenCUDA::PrintStorageScope(
   }
 }
 
+void CodeGenCUDA::VisitExpr_(const Call *op, std::ostream& os) {
+  if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 6U);
+    os << "nvcuda::wmma::fill_fragment(";
+    this->PrintExpr(op->args[0], os);
+    os << "[";
+    this->PrintExpr(op->args[4], os);
+    os << "], ";
+    this->PrintExpr(op->args[5], os);
+    os << ")";
+  } else if (op->is_intrinsic(intrinsic::tvm_load_matrix_sync)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 8U);
+    os << "nvcuda::wmma::load_matrix_sync(";
+    this->PrintExpr(op->args[0], os);
+    os << "[";
+    this->PrintExpr(op->args[4], os);
+    os << "], ";
+    this->PrintExpr(op->args[5], os);
+    os << ", ";
+    this->PrintExpr(op->args[6], os);
+    os << ")";
+  } else if (op->is_intrinsic(intrinsic::tvm_store_matrix_sync)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 8U);
+    os << "nvcuda::wmma::store_matrix_sync(";
+    this->PrintExpr(op->args[5], os);
+    os << ", ";
+    this->PrintExpr(op->args[0], os);
+    os << "[";
+    this->PrintExpr(op->args[4], os);
+    os << "], ";
+    this->PrintExpr(op->args[6], os);
+    if (const StringImm *str = op->args[7].as<StringImm>()) {
+      os << ", nvcuda::wmma::mem_" << str->value;
+    } else {
+      LOG(FATAL) << "Invalid parameters";
+    }
+    os << ")";
+  } else if (op->is_intrinsic(intrinsic::tvm_mma_sync)) {
+    need_mma_h_ = true;
+    CHECK_EQ(op->args.size(), 8U);
+    os << "nvcuda::wmma::mma_sync(";
+    for (int i = 0; i < 4; ++i) {
+      this->PrintExpr(op->args[i * 2], os);
+      os << "[";
+      this->PrintExpr(op->args[i * 2 + 1], os);
+      os << "]" << ((i < 3) ? ", ": ")");
+    }
+  } else {
+    CodeGenC::VisitExpr_(op, os);
+  }
+}
+
+void CodeGenCUDA::VisitStmt_(const AttrStmt* op) {
+  if (op->attr_key == attr::fragment_shape) {
+    const Variable* buffer = op->node.as<Variable>();
+    const StringImm* shape_str = op->value.as<StringImm>();
+    fragment_shapes[buffer] = shape_str->value;
+  } else if (op->attr_key == attr::fragment_layout) {
+    const Variable* buffer = op->node.as<Variable>();
+    const StringImm* layout_str = op->value.as<StringImm>();
+    fragment_layouts[buffer] = layout_str->value;
+  }
+  CodeGenC::VisitStmt_(op);
+}
+
+void CodeGenCUDA::VisitStmt_(const Allocate* op) {
+  CHECK(!is_zero(op->condition));
+  std::string vid = AllocVarID(op->buffer_var.get());
+  if (op->new_expr.defined()) {
+    // Prefer global static allocation for the program
+    CHECK_EQ(op->free_function, "nop");
+    std::string new_data = PrintExpr(op->new_expr);
+    this->PrintIndent();
+    PrintType(op->type, stream);
+    stream << "* " << vid << '=' << new_data << ";\n";
+  } else {
+    this->PrintIndent();
+    int32_t constant_size = op->constant_allocation_size();
+    CHECK_GT(constant_size, 0)
+        << "Can only handle constant size stack allocation for now";
+    const Variable* buffer = op->buffer_var.as<Variable>();
+    std::string scope = alloc_storage_scope_.at(buffer);
+    if (scope.find("wmma.") == 0) {
+      if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
+        CHECK(op->type == Float(16) || op->type == Int(8) || op->type == UInt(8))
+            << "Matrix_a and matrix_b only support half or char or unsigned char type for now";
+      } else {
+        CHECK(op->type == Float(16) || op->type == Float(32) || op->type == Int(32))
+            << "Accumulator only support half, float and int type for now";
+      }
+      constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
+      PrintWmmaScope(scope, op->type, buffer, stream);
+    } else {
+      PrintStorageScope(scope, stream);
+      stream << ' ';
+      PrintType(op->type, stream);
+    }
+    stream << ' ' << vid << '['
+           << constant_size << "];\n";
+  }
+  RegisterHandleType(op->buffer_var.get(), op->type);
+  this->PrintStmt(op->body);
+}
+
 void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
   if (is_const(op->value)) return;
   const Call* call = op->value.as<Call>();
@@ -392,5 +511,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) {  // NOLINT(*)
   PrintConst(op, os, this);
 }
 
+void CodeGenCUDA::PrintWmmaScope(const std::string &scope, Type t,
+                                 const Variable* variable, std::ostream &os) {
+  std::stringstream type;
+  PrintType(t, type);
+  std::string shape_str = fragment_shapes[variable];
+  if (scope == "wmma.matrix_a") {
+    need_mma_h_ = true;
+    std::string layout_str = fragment_layouts[variable];
+    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, "
+       << shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str << ">";
+  } else if (scope == "wmma.matrix_b") {
+    need_mma_h_ = true;
+    std::string layout_str = fragment_layouts[variable];
+    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, "
+       << shape_str << ", " << type.str() << ", nvcuda::wmma::" << layout_str << ">";
+  } else if (scope == "wmma.accumulator") {
+    need_mma_h_ = true;
+    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, "
+       << shape_str << ", " << type.str() << ">";
+  }
+}
+
+int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope,
+                                         const Variable* variable, int32_t size) {
+  std::string shape_str = fragment_shapes[variable];
+  size_t m, n, k;
+  size_t last_pos = 0, pos = 0;
+  pos = shape_str.find(", ", last_pos);
+  m = std::stoi(shape_str.substr(last_pos, pos - last_pos));
+  last_pos = pos + 2;
+  pos = shape_str.find(", ", last_pos);
+  n = std::stoi(shape_str.substr(last_pos, pos - last_pos));
+  last_pos = pos + 2;
+  k = std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
+  if (scope == "wmma.matrix_a") {
+    return size / m / k;
+  } else if (scope == "wmma.matrix_b") {
+    return size / n / k;
+  } else if (scope == "wmma.accumulator") {
+    return size / m / n;
+  }
+  return 0;
+}
+
 }  // namespace codegen
 }  // namespace tvm

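As a concrete illustration of the last two helpers (the buffer name and element count below are hypothetical, not output captured from this commit): for a 512-element half allocation in "wmma.matrix_a" scope annotated with fragment_shape "16, 16, 16" and fragment_layout "row_major", each matrix_a fragment covers an m*k = 16*16 tile, so GetWmmaFragmentSize returns 512 / 16 / 16 = 2, and the Allocate visitor emits roughly:

#include <mma.h>

// Hypothetical sketch of the emitted allocation; requires sm_70 or newer.
__global__ void emitted_allocation_sketch() {
  // PrintWmmaScope renders the fragment type from the scope, shape string,
  // and layout; GetWmmaFragmentSize supplies the array extent of 2.
  nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half,
                         nvcuda::wmma::row_major> A_frag[2];
  (void)A_frag;  // silence unused-variable warnings in this sketch
}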