@@ -74,6 +74,10 @@ std::string CodeGenCUDA::Finish() {
7474 decl_stream << " #include <math_constants.h>\n " ;
7575 }
7676
77+ if (need_mma_h_) {
78+ decl_stream << " #include <mma.h>\n " ;
79+ }
80+
7781 return CodeGenC::Finish ();
7882}
7983
@@ -102,14 +106,22 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) { // NOLINT(*)
102106 bool fail = false ;
103107 if (t.is_float ()) {
104108 switch (t.bits ()) {
105- case 16 : os << " half " ;
109+ case 16 :
106110 enable_fp16_ = true ;
111+ if (lanes == 1 ) {
112+ os << " half" ;
113+ } else if (lanes <= 8 ) {
114+ CHECK_EQ (lanes % 2 , 0 ) << " only support even lane for half type" ;
115+ os << " float" << lanes / 2 ;
116+ } else {
117+ fail = true ;
118+ }
107119 break ;
108120 case 32 : os << " float" ; break ;
109121 case 64 : os << " double" ; break ;
110122 default : fail = true ; break ;
111123 }
112- if (!fail && lanes == 1 ) return ;
124+ if (!fail && ( lanes == 1 || t. bits () == 16 ) ) return ;
113125 if (!fail && (lanes >= 2 && lanes <= 4 )) {
114126 os << lanes; return ;
115127 }
@@ -290,6 +302,113 @@ void CodeGenCUDA::PrintStorageScope(
290302 }
291303}
292304
305+ void CodeGenCUDA::VisitExpr_ (const Call *op, std::ostream& os) {
306+ if (op->is_intrinsic (intrinsic::tvm_fill_fragment)) {
307+ need_mma_h_ = true ;
308+ CHECK_EQ (op->args .size (), 6U );
309+ os << " nvcuda::wmma::fill_fragment(" ;
310+ this ->PrintExpr (op->args [0 ], os);
311+ os << " [" ;
312+ this ->PrintExpr (op->args [4 ], os);
313+ os << " ], " ;
314+ this ->PrintExpr (op->args [5 ], os);
315+ os << " )" ;
316+ } else if (op->is_intrinsic (intrinsic::tvm_load_matrix_sync)) {
317+ need_mma_h_ = true ;
318+ CHECK_EQ (op->args .size (), 8U );
319+ os << " nvcuda::wmma::load_matrix_sync(" ;
320+ this ->PrintExpr (op->args [0 ], os);
321+ os << " [" ;
322+ this ->PrintExpr (op->args [4 ], os);
323+ os << " ], " ;
324+ this ->PrintExpr (op->args [5 ], os);
325+ os << " , " ;
326+ this ->PrintExpr (op->args [6 ], os);
327+ os << " )" ;
328+ } else if (op->is_intrinsic (intrinsic::tvm_store_matrix_sync)) {
329+ need_mma_h_ = true ;
330+ CHECK_EQ (op->args .size (), 8U );
331+ os << " nvcuda::wmma::store_matrix_sync(" ;
332+ this ->PrintExpr (op->args [5 ], os);
333+ os << " , " ;
334+ this ->PrintExpr (op->args [0 ], os);
335+ os << " [" ;
336+ this ->PrintExpr (op->args [4 ], os);
337+ os << " ], " ;
338+ this ->PrintExpr (op->args [6 ], os);
339+ if (const StringImm *str = op->args [7 ].as <StringImm>()) {
340+ os << " , nvcuda::wmma::mem_" << str->value ;
341+ } else {
342+ LOG (FATAL) << " Invalid parameters" ;
343+ }
344+ os << " )" ;
345+ } else if (op->is_intrinsic (intrinsic::tvm_mma_sync)) {
346+ need_mma_h_ = true ;
347+ CHECK_EQ (op->args .size (), 8U );
348+ os << " nvcuda::wmma::mma_sync(" ;
349+ for (int i = 0 ; i < 4 ; ++i) {
350+ this ->PrintExpr (op->args [i * 2 ], os);
351+ os << " [" ;
352+ this ->PrintExpr (op->args [i * 2 + 1 ], os);
353+ os << " ]" << ((i < 3 ) ? " , " : " )" );
354+ }
355+ } else {
356+ CodeGenC::VisitExpr_ (op, os);
357+ }
358+ }
359+
360+ void CodeGenCUDA::VisitStmt_ (const AttrStmt* op) {
361+ if (op->attr_key == attr::fragment_shape) {
362+ const Variable* buffer = op->node .as <Variable>();
363+ const StringImm* shape_str = op->value .as <StringImm>();
364+ fragment_shapes[buffer] = shape_str->value ;
365+ } else if (op->attr_key == attr::fragment_layout) {
366+ const Variable* buffer = op->node .as <Variable>();
367+ const StringImm* layout_str = op->value .as <StringImm>();
368+ fragment_layouts[buffer] = layout_str->value ;
369+ }
370+ CodeGenC::VisitStmt_ (op);
371+ }
372+
373+ void CodeGenCUDA::VisitStmt_ (const Allocate* op) {
374+ CHECK (!is_zero (op->condition ));
375+ std::string vid = AllocVarID (op->buffer_var .get ());
376+ if (op->new_expr .defined ()) {
377+ // Prefer global static allocation for the program
378+ CHECK_EQ (op->free_function , " nop" );
379+ std::string new_data = PrintExpr (op->new_expr );
380+ this ->PrintIndent ();
381+ PrintType (op->type , stream);
382+ stream << " * " << vid << ' =' << new_data << " ;\n " ;
383+ } else {
384+ this ->PrintIndent ();
385+ int32_t constant_size = op->constant_allocation_size ();
386+ CHECK_GT (constant_size, 0 )
387+ << " Can only handle constant size stack allocation for now" ;
388+ const Variable* buffer = op->buffer_var .as <Variable>();
389+ std::string scope = alloc_storage_scope_.at (buffer);
390+ if (scope.find (" wmma." ) == 0 ) {
391+ if (scope == " wmma.matrix_a" || scope == " wmma.matrix_b" ) {
392+ CHECK (op->type == Float (16 ) || op->type == Int (8 ) || op->type == UInt (8 ))
393+ << " Matrix_a and matrix_b only support half or char or unsigned char type for now" ;
394+ } else {
395+ CHECK (op->type == Float (16 ) || op->type == Float (32 ) || op->type == Int (32 ))
396+ << " Accumulator only support half, float and int type for now" ;
397+ }
398+ constant_size = GetWmmaFragmentSize (scope, buffer, constant_size);
399+ PrintWmmaScope (scope, op->type , buffer, stream);
400+ } else {
401+ PrintStorageScope (scope, stream);
402+ stream << ' ' ;
403+ PrintType (op->type , stream);
404+ }
405+ stream << ' ' << vid << ' ['
406+ << constant_size << " ];\n " ;
407+ }
408+ RegisterHandleType (op->buffer_var .get (), op->type );
409+ this ->PrintStmt (op->body );
410+ }
411+
293412void CodeGenCUDA::VisitStmt_ (const Evaluate *op) {
294413 if (is_const (op->value )) return ;
295414 const Call* call = op->value .as <Call>();
@@ -392,5 +511,49 @@ void CodeGenCUDA::VisitExpr_(const FloatImm *op, std::ostream& os) { // NOLINT(*
392511 PrintConst (op, os, this );
393512}
394513
514+ void CodeGenCUDA::PrintWmmaScope (const std::string &scope, Type t,
515+ const Variable* variable, std::ostream &os) {
516+ std::stringstream type;
517+ PrintType (t, type);
518+ std::string shape_str = fragment_shapes[variable];
519+ if (scope == " wmma.matrix_a" ) {
520+ need_mma_h_ = true ;
521+ std::string layout_str = fragment_layouts[variable];
522+ os << " nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, "
523+ << shape_str << " , " << type.str () << " , nvcuda::wmma::" << layout_str <<" >" ;
524+ } else if (scope == " wmma.matrix_b" ) {
525+ need_mma_h_ = true ;
526+ std::string layout_str = fragment_layouts[variable];
527+ os << " nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, "
528+ << shape_str << " , " << type.str () << " , nvcuda::wmma::" << layout_str <<" >" ;
529+ } else if (scope == " wmma.accumulator" ) {
530+ need_mma_h_ = true ;
531+ os << " nvcuda::wmma::fragment<nvcuda::wmma::accumulator, "
532+ << shape_str << " , " << type.str () << " >" ;
533+ }
534+ }
535+
536+ int32_t CodeGenCUDA::GetWmmaFragmentSize (const std::string &scope,
537+ const Variable* variable, int32_t size) {
538+ std::string shape_str = fragment_shapes[variable];
539+ size_t m, n, k;
540+ size_t last_pos = 0 , pos = 0 ;
541+ pos = shape_str.find (" , " , last_pos);
542+ m = std::stoi (shape_str.substr (last_pos, pos - last_pos));
543+ last_pos = pos + 2 ;
544+ pos = shape_str.find (" , " , last_pos);
545+ n = std::stoi (shape_str.substr (last_pos, pos - last_pos));
546+ last_pos = pos + 2 ;
547+ k = std::stoi (shape_str.substr (last_pos, shape_str.length () - last_pos));
548+ if (scope == " wmma.matrix_a" ) {
549+ return size / m / k;
550+ } else if (scope == " wmma.matrix_b" ) {
551+ return size / n / k;
552+ } else if (scope == " wmma.accumulator" ) {
553+ return size / m / n;
554+ }
555+ return 0 ;
556+ }
557+
395558} // namespace codegen
396559} // namespace tvm
0 commit comments