1818 */
1919
2020/* !
21- * \file ptx_mma .cc
21+ * \file ptx .cc
2222 */
2323
24- #include " ptx_mma .h"
24+ #include " ptx .h"
2525
2626#include < algorithm>
2727#include < string>
@@ -60,13 +60,18 @@ enum class DataType : int {
6060 kFloat32 = 13 ,
6161 kTensorFloat32 = 14 ,
6262 kFloat64 = 15 ,
63- kBit1 = 16
63+ kBit1 = 16 ,
64+ kBit8 = 17 ,
65+ kBit16 = 18 ,
66+ kBit32 = 19 ,
67+ kBit64 = 20 ,
6468};
6569
/*!
 * \brief PTX type suffix for each DataType.
 *
 * Entry i belongs to the enumerator whose integer value is i
 * (e.g. kFloat32 = 13 -> ".f32", kBit1 = 16 -> ".b1", kBit64 = 20 -> ".b64").
 */
static const char* dtype_str[] = {".s4",   ".u4",  ".s8",  ".u8",  ".s16",  ".u16",   ".s32",
                                  ".u32",  ".s64", ".u64", ".f16", ".bf16", ".f16x2", ".f32",
                                  ".tf32", ".f64", ".b1",  ".b8",  ".b16",  ".b32",   ".b64"};

/*!
 * \brief Width in bits of each DataType, using the same indexing as dtype_str.
 */
static const uint32_t num_bits[] = {4,  4,  8,  8,  16, 16, 32, 32, 64, 64, 16,
                                    16, 32, 32, 32, 64, 1,  8,  16, 32, 64};

// The two tables are parallel arrays; adding a type to one but not the other
// would silently shift every later lookup, so fail the build instead.
static_assert(sizeof(dtype_str) / sizeof(dtype_str[0]) == sizeof(num_bits) / sizeof(num_bits[0]),
              "dtype_str and num_bits must describe the same set of data types");
7075
7176/* !
7277 * \brief Create PTX data type from string.
@@ -106,6 +111,14 @@ inline DataType DTypeFromString(const std::string str) {
106111 return DataType::kFloat64 ;
107112 } else if (str == " int1" || str == " .b1" ) {
108113 return DataType::kBit1 ;
114+ } else if (str == " .b8" ) {
115+ return DataType::kBit8 ;
116+ } else if (str == " .b16" ) {
117+ return DataType::kBit16 ;
118+ } else if (str == " .b32" ) {
119+ return DataType::kBit32 ;
120+ } else if (str == " .b64" ) {
121+ return DataType::kBit64 ;
109122 } else {
110123 LOG (FATAL) << " Unrecognized PTX data type " << str;
111124 return DataType (0 );
@@ -360,6 +373,7 @@ inline FragAttrs GetFragAttrs(DataType dtype) {
360373 case DataType::kUInt4 :
361374 case DataType::kInt8 :
362375 case DataType::kUInt8 :
376+ case DataType::kBit16 :
363377 case DataType::kFloat16 : // .f16x2 register
364378 case DataType::kBFloat16 :
365379 case DataType::kTensorFloat32 :
@@ -508,9 +522,9 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
508522std::string PrintMMAAssembly (const std::string& shape, const std::string& A_layout,
509523 const std::string& B_layout, const std::string& A_dtype,
510524 const std::string& B_dtype, const std::string& C_dtype,
511- const std::string& a_ref , const std::string& a_offset ,
512- const std::string& b_ref , const std::string& b_offset ,
513- const std::string& c_ref , const std::string& c_offset ,
525+ const std::string& a_ptr , const std::string& a_elem_offset ,
526+ const std::string& b_ptr , const std::string& b_elem_offset ,
527+ const std::string& c_ptr , const std::string& c_elem_offset ,
514528 const std::string& metadata, const std::string& metadata_offset,
515529 const std::string& sparsity_selector, const std::string& bit_op,
516530 bool sparse, bool saturate) {
@@ -525,7 +539,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
525539 std::string asm_code = R"(
526540 {
527541 __asm__ __volatile__(
528- "mma{sparse}.sync.aligned.{ shape}.{ alayout}.{ blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}"
542+ "mma{. sparse}.sync.aligned{. shape}{. alayout}{. blayout}{. saturate}{. dtype}{. atype}{. btype}{. ctype}{. bitop}"
529543 "{templates};\n"
530544 : {outputs}
531545 : {inputs});
@@ -537,30 +551,92 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
537551
538552 // replace patterns
539553 Replacer replacer;
540- replacer.register_rule (" {sparse}" , sparse ? " .sp" : " " );
541- replacer.register_rule (" {shape}" , shape);
542- replacer.register_rule (" {saturate}" , saturate ? " .satfinite" : " " );
543- replacer.register_rule (" {alayout}" , A_layout);
544- replacer.register_rule (" {blayout}" , B_layout);
545- replacer.register_rule (" {atype}" , ptx::DTypeToString (dtype_a));
546- replacer.register_rule (" {btype}" , ptx::DTypeToString (dtype_b));
547- replacer.register_rule (" {ctype}" , ptx::DTypeToString (dtype_c));
548- replacer.register_rule (" {dtype}" , ptx::DTypeToString (dtype_c));
549- replacer.register_rule (" {bitop}" , bit_op.empty () ? " " : " ." + bit_op + " .popc" );
554+ replacer.register_rule (" {. sparse}" , sparse ? " .sp" : " " );
555+ replacer.register_rule (" {. shape}" , " . " + shape);
556+ replacer.register_rule (" {. saturate}" , saturate ? " .satfinite" : " " );
557+ replacer.register_rule (" {. alayout}" , " . " + A_layout);
558+ replacer.register_rule (" {. blayout}" , " . " + B_layout);
559+ replacer.register_rule (" {. atype}" , ptx::DTypeToString (dtype_a));
560+ replacer.register_rule (" {. btype}" , ptx::DTypeToString (dtype_b));
561+ replacer.register_rule (" {. ctype}" , ptx::DTypeToString (dtype_c));
562+ replacer.register_rule (" {. dtype}" , ptx::DTypeToString (dtype_c));
563+ replacer.register_rule (" {. bitop}" , bit_op.empty () ? " " : " ." + bit_op + " .popc" );
550564 replacer.register_rule (" {templates}" , templates_str);
551565 replacer.register_rule (" {outputs}" , outputs_str);
552566 replacer.register_rule (" {inputs}" , inputs_str);
553567 asm_code = replacer.rewrite (asm_code);
554568 replacer.empty_rules ();
555- replacer.register_rule (" A" , a_ref + " + " + a_offset );
556- replacer.register_rule (" B" , b_ref + " + " + b_offset );
557- replacer.register_rule (" C" , c_ref + " + " + c_offset );
558- replacer.register_rule (" D" , c_ref + " + " + c_offset );
569+ replacer.register_rule (" A" , a_ptr + " + " + a_elem_offset );
570+ replacer.register_rule (" B" , b_ptr + " + " + b_elem_offset );
571+ replacer.register_rule (" C" , c_ptr + " + " + c_elem_offset );
572+ replacer.register_rule (" D" , c_ptr + " + " + c_elem_offset );
559573 replacer.register_rule (" E" , metadata + " + " + metadata_offset);
560574 replacer.register_rule (" F" , sparsity_selector);
561575 asm_code = replacer.rewrite (asm_code);
562576 return asm_code;
563577}
564578
/*!
 * \brief Build the operand placeholder string and the output-constraint list
 *        used by the ldmatrix inline assembly.
 * \param num Number of destination operands to generate.
 * \param local_ptr Source text of the destination pointer expression.
 * \param local_elem_offset Source text of the offset added to local_ptr.
 * \return A tuple (templates, outputs):
 *         - templates: "{%0, ..., %<num-1>}, [%<num>]" -- num destination
 *           placeholders followed by one bracketed address placeholder.
 *         - outputs: a comma-separated list of num "=r" constraints, the i-th
 *           bound to element i of (local_ptr + local_elem_offset) viewed as
 *           an unsigned pointer.
 */
inline std::tuple<std::string, std::string> GetLoadMatrixOperands(
    int num, const std::string& local_ptr, const std::string& local_elem_offset) {
  std::stringstream templates, outputs;
  int arg_counter = 0;
  // Placeholders: the first destination is emitted unconditionally, the
  // remaining ones are comma-prefixed, then the address operand in brackets.
  templates << "{%" << arg_counter++;
  for (int i = 1; i < num; ++i) {
    templates << ", %" << arg_counter++;
  }
  templates << "}, [%" << arg_counter++ << "]";
  // Output constraints: one 32-bit register constraint per destination.
  const std::string ptr_type = "(unsigned *)";
  for (int i = 0; i < num; ++i) {
    if (i != 0) {
      outputs << ", ";
    }
    outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " << local_elem_offset << "))["
            << i << "])";
  }
  return std::make_tuple(templates.str(), outputs.str());
}
600+
601+ std::string PrintLoadMatrixAssembly (bool trans, int num, const std::string& type,
602+ const std::string& local_ptr,
603+ const std::string& local_elem_offset,
604+ const std::string& smem_ptr,
605+ const std::string& smem_elem_offset) {
606+ CHECK (num == 1 || num == 2 || num == 4 ) << " ldmatrix only accept loading 1/2/4 matrices." ;
607+ ptx::DataType data_type = ptx::DTypeFromString (type);
608+ CHECK (data_type == ptx::DataType::kBit16 ) << " ldmatrix only accept matrix with type .b16." ;
609+ std::string asm_code = R"(
610+ {
611+ unsigned int addr;
612+ __asm__ __volatile__(
613+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
614+ : "=r"(addr)
615+ : "l"((void *)({smem_addr}))
616+ );
617+ __asm__ __volatile__(
618+ "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
619+ "{templates};\n"
620+ : {outputs}
621+ : "r"(addr)
622+ );
623+ }
624+ )" ;
625+ std::string templates_str, outputs_str;
626+ std::tie (templates_str, outputs_str) = GetLoadMatrixOperands (num, local_ptr, local_elem_offset);
627+
628+ Replacer replacer;
629+ replacer.register_rule (" {.shape}" , " .m8n8" );
630+ replacer.register_rule (" {.num}" , " .x" + std::to_string (num));
631+ replacer.register_rule (" {.trans}" , trans ? " .trans" : " " );
632+ replacer.register_rule (" {.ss}" , " .shared" );
633+ replacer.register_rule (" {.type}" , ptx::DTypeToString (data_type));
634+ replacer.register_rule (" {smem_addr}" , smem_ptr + " + " + smem_elem_offset);
635+ replacer.register_rule (" {templates}" , templates_str);
636+ replacer.register_rule (" {outputs}" , outputs_str);
637+ asm_code = replacer.rewrite (asm_code);
638+ return asm_code;
639+ }
640+
565641} // namespace codegen
566642} // namespace tvm
0 commit comments