@@ -228,19 +228,6 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
228228 }
229229}
230230
231- /* *
232- * @brief No-Op shader with matmul bindings for performance testing
233- */
234- static const char *kShaderNoOp = R"(
235- @group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
236- @group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
237- @group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
238- @compute @workgroup_size({{workgroupSize}})
239- fn main(
240- @builtin(global_invocation_id) globalID : vec3<u32>) {
241- }
242- )" ;
243-
244231/* 2D block-tiling
245232 *
246233 */
@@ -357,6 +344,141 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
357344 }
358345}
359346
/* 2D block-tiling with vectorization.
 *
 * Each workgroup computes a BM x BN tile of C and each thread computes a
 * TM x TN sub-tile of it. The TN per-thread outputs are packed into TN/4
 * vec4 values (binding 2 views c as array<vec4>), so stores to c are
 * 4-wide. Note the indexing of b: bPtr = col * K and (col * K + k), i.e.
 * the B operand is expected in transposed (N x K) layout, which also makes
 * tileB a BN x BK tile.
 */
static const char *kShaderMatmulWithVectorization = R"(
@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> c: array<vec4<{{precision}}>>;
var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
var<workgroup> tileB: array<{{precision}}, {{BN}} * {{BK}}>;

@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) globalID : vec3<u32>,
    @builtin(local_invocation_id) localID : vec3<u32>,
    @builtin(workgroup_id) groupid : vec3<u32>) {

    var threadResults: array<vec4<{{precision}}>, {{TM}} * {{TN4}}>;
    var localM: array<{{precision}}, {{TM}}>;
    var localN: array<vec4<{{precision}}>, {{TN4}}>;

    let cRow: u32 = groupid.x;
    let cCol: u32 = groupid.y;
    let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});

    // position of the first c element computed by the thread
    let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
    let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};

    // aPtr and bPtr are the starting positions of the tiles in a and b,
    // incremented in the bkidx loop.
    // cPtr is the starting position of the tile in c which is fixed.

    var aPtr = cRow * {{BM}} * {{K}};
    var bPtr = cCol * {{BN}} * {{K}};
    let cPtr = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};

    for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {

      // Load tile
      // Load BM x BK by numThread(BM * BN / (TM * TN))
      // The number of iteration == BM * BK / (BM * BN / (TM * TN))
      for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
        tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
      }
      // Load BK x BN by numThread(BM * BN / (TM * TN))
      // The number of iteration == BK * BN / (BM * BN / (TM * TN))
      for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
        tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
      }

      aPtr += {{BK}};
      bPtr += {{BK}};

      workgroupBarrier();
      // Compute tile
      for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
        for (var idx: u32 = 0; idx < {{TM}}; idx++) {
          localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
        }
        for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
          localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) * {{BK}} + dotIdx],
                                           tileB[(threadCol + idx*4 + 1) * {{BK}} + dotIdx],
                                           tileB[(threadCol + idx*4 + 2) * {{BK}} + dotIdx],
                                           tileB[(threadCol + idx*4 + 3) * {{BK}} + dotIdx]);
        }
        for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
          for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
            threadResults[resIdxM * {{TN4}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
          }
        }
      }
      workgroupBarrier();
    }

    for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
      for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
        c[cPtr + (threadRow + resIdxM) * {{N4}} + (threadCol/4) + resIdxN] = threadResults[resIdxM * {{TN4}} + resIdxN];
      }
    }
}
)";
428+
429+ inline KernelCode createMatmulWithVectorization (const char *shaderTemplate, const size_t M,
430+ const size_t K, const size_t N, const size_t BM,
431+ const size_t BK, const size_t BN,
432+ const size_t TM, const size_t TN,
433+ const Shape &workgroupSize = {256 , 1 , 1 },
434+ NumType precision = kf32,
435+ bool unrolling = false ) {
436+ assert (BM % TM == 0 );
437+ assert (BN % TN == 0 );
438+ assert (K % BK == 0 );
439+ assert (M % BM == 0 );
440+ assert (N % BN == 0 );
441+ // # threads = tile A size == tile B size == # threads for computing C
442+ int num_threads = BM * BN / (TM * TN);
443+ std::string codeString (shaderTemplate);
444+ replaceAll (codeString, {{" {{workgroupSize}}" , toString (workgroupSize)},
445+ {" {{precision}}" , toString (precision)},
446+ {" {{M}}" , toString (M)},
447+ {" {{K}}" , toString (K)},
448+ {" {{N}}" , toString (N)},
449+ {" {{BM}}" , toString (BM)},
450+ {" {{BK}}" , toString (BK)},
451+ {" {{BN}}" , toString (BN)},
452+ {" {{TM}}" , toString (TM)},
453+ {" {{TN}}" , toString (TN)},
454+ {" {{NUM_TILEA}}" , toString (BM * BK / num_threads)},
455+ {" {{NUM_TILEB}}" , toString (BN * BK / num_threads)},
456+ {" {{TN4}}" , toString (TN / 4 )},
457+ {" {{N4}}" , toString (N / 4 )},
458+ {" {{BN4}}" , toString (BN / 4 )},
459+ });
460+ if (unrolling) {
461+ std::string unrolledCode = loopUnrolling (codeString);
462+ LOG (kDefLog , kInfo , " Unrolled code:\n %s" , unrolledCode.c_str ());
463+ return {unrolledCode, workgroupSize};
464+ } else {
465+ return {codeString, workgroupSize};
466+ }
467+ }
468+
/**
 * @brief No-Op shader with matmul bindings for performance testing.
 *
 * Binds the same three storage buffers as the matmul kernels but performs
 * no work in main(), so dispatching it measures pure kernel launch and
 * binding overhead as a baseline.
 */
static const char *kShaderNoOp = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) globalID : vec3<u32>) {
}
)";
481+
360482inline KernelCode createNoOp (const char *shaderTemplate,
361483 const Shape &workgroupSize = {256 , 1 , 1 },
362484 NumType precision = kf32) {
@@ -448,6 +570,24 @@ Kernel selectMatmul(Context &ctx, int version,
448570 kernel = createKernel (ctx, matmul, bindings,
449571 /* nWorkgroups*/ nWorkgroups);
450572 } else if (version == 7 ) {
573+ static constexpr size_t BM = 64 ;
574+ static constexpr size_t BK = 16 ;
575+ static constexpr size_t BN = 64 ;
576+ static constexpr size_t TM = BM / BK;
577+ static constexpr size_t TN = BN / BK;
578+ Shape wgSize = {(BM / TM) * (BN / TN), 1 , 1 }; // This is the same as BK * BK.
579+ Shape nWorkgroups = {cdiv (M, BM), cdiv (N, BN), 1 };
580+ LOG (kDefLog , kInfo , " M: %d, K: %d, N: %d" , M, K, N);
581+ LOG (kDefLog , kInfo , " BM: %d, BK: %d, BN: %d, TM: %d, TN: %d" , BM, BK, BN, TM, TN);
582+ LOG (kDefLog , kInfo , " wgSize: ( %s )" , toString (wgSize).c_str ());
583+ LOG (kDefLog , kInfo , " nWorkgroups: ( %s )" , toString (nWorkgroups).c_str ());
584+ KernelCode matmul = createMatmulWithVectorization (kShaderMatmulWithVectorization , M, K, N, BM, BK, BN, TM, TN,
585+ /* wgSize*/ wgSize,
586+ kf32,
587+ /* Loop unrolling*/ true );
588+ kernel = createKernel (ctx, matmul, bindings,
589+ /* nWorkgroups*/ nWorkgroups);
590+ } else if (version == 8 ) {
451591 Shape wgSize = {256 , 1 , 1 };
452592 Shape nWorkgroups = cdiv ({M, N, 1 }, {16 , 16 , 1 });
453593 KernelCode matmul = createNoOp (kShaderNoOp , /* wgsize*/ wgSize);
@@ -528,7 +668,8 @@ int main() {
528668 // 4 == 2D blocktiling
529669 // 5 == 1D blocktiling with loop unrolling
530670 // 6 == 2D blocktiling with loop unrolling
531- // 7 == No-Op
671+ // 7 == 2D blocktiling with loop unrolling and vectorization
672+ // 8 == No-Op
532673
533674 size_t M, K, N; // Matrix dimensions
534675 static constexpr int kTestSize = 2 ;
0 commit comments