Skip to content

Commit b835e63

Browse files
Add the vectorized matmul
1 parent 1c665b2 commit b835e63

File tree

1 file changed

+155
-14
lines changed

1 file changed

+155
-14
lines changed

examples/matmul/run.cpp

Lines changed: 155 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -228,19 +228,6 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
228228
}
229229
}
230230

231-
/**
232-
* @brief No-Op shader with matmul bindings for performance testing
233-
*/
234-
static const char *kShaderNoOp = R"(
235-
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
236-
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
237-
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
238-
@compute @workgroup_size({{workgroupSize}})
239-
fn main(
240-
@builtin(global_invocation_id) globalID : vec3<u32>) {
241-
}
242-
)";
243-
244231
/* 2D block-tiling
245232
*
246233
*/
@@ -357,6 +344,141 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
357344
}
358345
}
359346

347+
/* 2D block-tiling with vectorization
348+
*
349+
*/
350+
static const char *kShaderMatmulWithVectorization = R"(
351+
@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
352+
@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
353+
@group(0) @binding(2) var<storage, read_write> c: array<vec4<{{precision}}>>;
354+
var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
355+
var<workgroup> tileB: array<{{precision}}, {{BN}} * {{BK}}>;
356+
357+
@compute @workgroup_size({{workgroupSize}})
358+
fn main(
359+
@builtin(global_invocation_id) globalID : vec3<u32>,
360+
@builtin(local_invocation_id) localID : vec3<u32>,
361+
@builtin(workgroup_id) groupid : vec3<u32>) {
362+
363+
var threadResults: array<vec4<{{precision}}>, {{TM}} * {{TN4}}>;
364+
var localM: array<{{precision}}, {{TM}}>;
365+
var localN: array<vec4<{{precision}}>, {{TN4}}>;
366+
367+
let cRow: u32 = groupid.x;
368+
let cCol: u32 = groupid.y;
369+
let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
370+
371+
// position of the first c element computed by the thread
372+
let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
373+
let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
374+
375+
// aPtr and bPtr are the starting positions of the tiles in a and b,
376+
// incremented in the bkidx loop.
377+
// cPtr is the starting position of the tile in c which is fixed.
378+
379+
var aPtr = cRow * {{BM}} * {{K}};
380+
var bPtr = cCol * {{BN}} * {{K}};
381+
let cPtr = cRow * {{BM}} * {{N4}} + cCol * {{BN4}};
382+
383+
for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
384+
385+
// Load tile
386+
// Load BM x BK by numThread(BM * BN / (TM * TN))
387+
// The number of iteration == BM * BK / (BM * BN / (TM * TN))
388+
for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
389+
tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
390+
}
391+
// Load BK x BN by numThread(BM * BN / (TM * TN))
392+
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
393+
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
394+
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
395+
}
396+
397+
aPtr += {{BK}};
398+
bPtr += {{BK}};
399+
400+
workgroupBarrier();
401+
// Compute tile
402+
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
403+
for (var idx: u32 = 0; idx < {{TM}}; idx++) {
404+
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
405+
}
406+
for (var idx: u32 = 0; idx < {{TN4}}; idx++) {
407+
localN[idx] = vec4<{{precision}}>(tileB[(threadCol + idx*4 ) * {{BK}} + dotIdx],
408+
tileB[(threadCol + idx*4 + 1) * {{BK}} + dotIdx],
409+
tileB[(threadCol + idx*4 + 2) * {{BK}} + dotIdx],
410+
tileB[(threadCol + idx*4 + 3) * {{BK}} + dotIdx]);
411+
}
412+
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
413+
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
414+
threadResults[resIdxM * {{TN4}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
415+
}
416+
}
417+
}
418+
workgroupBarrier();
419+
}
420+
421+
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
422+
for (var resIdxN: u32 = 0; resIdxN < {{TN4}}; resIdxN++) {
423+
c[cPtr + (threadRow + resIdxM) * {{N4}} + (threadCol/4) + resIdxN] = threadResults[resIdxM * {{TN4}} + resIdxN];
424+
}
425+
}
426+
}
427+
)";
428+
429+
inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, const size_t M,
430+
const size_t K, const size_t N, const size_t BM,
431+
const size_t BK, const size_t BN,
432+
const size_t TM, const size_t TN,
433+
const Shape &workgroupSize = {256, 1, 1},
434+
NumType precision = kf32,
435+
bool unrolling = false) {
436+
assert(BM % TM == 0);
437+
assert(BN % TN == 0);
438+
assert(K % BK == 0);
439+
assert(M % BM == 0);
440+
assert(N % BN == 0);
441+
// # threads = tile A size == tile B size == # threads for computing C
442+
int num_threads = BM * BN / (TM * TN);
443+
std::string codeString(shaderTemplate);
444+
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
445+
{"{{precision}}", toString(precision)},
446+
{"{{M}}", toString(M)},
447+
{"{{K}}", toString(K)},
448+
{"{{N}}", toString(N)},
449+
{"{{BM}}", toString(BM)},
450+
{"{{BK}}", toString(BK)},
451+
{"{{BN}}", toString(BN)},
452+
{"{{TM}}", toString(TM)},
453+
{"{{TN}}", toString(TN)},
454+
{"{{NUM_TILEA}}", toString(BM * BK / num_threads)},
455+
{"{{NUM_TILEB}}", toString(BN * BK / num_threads)},
456+
{"{{TN4}}", toString(TN / 4)},
457+
{"{{N4}}", toString(N / 4)},
458+
{"{{BN4}}", toString(BN / 4)},
459+
});
460+
if (unrolling) {
461+
std::string unrolledCode = loopUnrolling(codeString);
462+
LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
463+
return {unrolledCode, workgroupSize};
464+
} else {
465+
return {codeString, workgroupSize};
466+
}
467+
}
468+
469+
/**
470+
* @brief No-Op shader with matmul bindings for performance testing
471+
*/
472+
static const char *kShaderNoOp = R"(
473+
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
474+
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
475+
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
476+
@compute @workgroup_size({{workgroupSize}})
477+
fn main(
478+
@builtin(global_invocation_id) globalID : vec3<u32>) {
479+
}
480+
)";
481+
360482
inline KernelCode createNoOp(const char *shaderTemplate,
361483
const Shape &workgroupSize = {256, 1, 1},
362484
NumType precision = kf32) {
@@ -448,6 +570,24 @@ Kernel selectMatmul(Context &ctx, int version,
448570
kernel = createKernel(ctx, matmul, bindings,
449571
/*nWorkgroups*/ nWorkgroups);
450572
} else if (version == 7) {
573+
static constexpr size_t BM = 64;
574+
static constexpr size_t BK = 16;
575+
static constexpr size_t BN = 64;
576+
static constexpr size_t TM = BM / BK;
577+
static constexpr size_t TN = BN / BK;
578+
Shape wgSize = {(BM / TM) * (BN / TN), 1, 1}; // This is the same as BK * BK.
579+
Shape nWorkgroups = {cdiv(M, BM), cdiv(N, BN), 1};
580+
LOG(kDefLog, kInfo, "M: %d, K: %d, N: %d", M, K, N);
581+
LOG(kDefLog, kInfo, "BM: %d, BK: %d, BN: %d, TM: %d, TN: %d", BM, BK, BN, TM, TN);
582+
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
583+
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
584+
KernelCode matmul = createMatmulWithVectorization(kShaderMatmulWithVectorization, M, K, N, BM, BK, BN, TM, TN,
585+
/*wgSize*/ wgSize,
586+
kf32,
587+
/*Loop unrolling*/ true);
588+
kernel = createKernel(ctx, matmul, bindings,
589+
/*nWorkgroups*/ nWorkgroups);
590+
} else if (version == 8) {
451591
Shape wgSize = {256, 1, 1};
452592
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
453593
KernelCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
@@ -528,7 +668,8 @@ int main() {
528668
// 4 == 2D blocktiling
529669
// 5 == 1D blocktiling with loop unrolling
530670
// 6 == 2D blocktiling with loop unrolling
531-
// 7 == No-Op
671+
// 7 == 2D blocktiling with loop unrolling and vectorization
672+
// 8 == No-Op
532673

533674
size_t M, K, N; // Matrix dimensions
534675
static constexpr int kTestSize = 2;

0 commit comments

Comments
 (0)