diff --git a/examples/matmul/run.cpp b/examples/matmul/run.cpp
index e76ce1f..0a5e3b3 100644
--- a/examples/matmul/run.cpp
+++ b/examples/matmul/run.cpp
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include <cstdlib>
 
 #include "gpu.h" // createContext, createTensor, createKernel, dispatchKernel,
 // wait, resetCommandBuffer, toCPU
@@ -233,6 +234,149 @@ fn main(
 }
 )";
 
+/* 2D block-tiling
+ *
+ * Each workgroup computes one BM x BN tile of c, stepping through K in
+ * chunks of BK. Every invocation accumulates a TM x TN sub-tile in
+ * threadResults and writes it to c after the K loop.
+ * a is indexed as row-major M x K; b is indexed with row stride K as well
+ * (b[bPtr + row * K + col]), i.e. as an N x K layout.
+ *
+ * Placeholders: {{precision}}, {{workgroupSize}}, {{K}}, {{N}}, {{BM}},
+ * {{BK}}, {{BN}}, {{TM}}, {{TN}}.
+ */
+static const char *kShaderMatmul4 = R"(
+@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> c: array<{{precision}}>;
+var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
+var<workgroup> tileB: array<{{precision}}, {{BN}} * {{BK}}>;
+
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+    @builtin(global_invocation_id) globalID : vec3<u32>,
+    @builtin(local_invocation_id) localID : vec3<u32>,
+    @builtin(workgroup_id) groupid : vec3<u32>) {
+
+    var threadResults: array<{{precision}}, {{TM}} * {{TN}}>;
+    var localM: array<{{precision}}, {{TM}}>;
+    var localN: array<{{precision}}, {{TN}}>;
+
+    let cRow: u32 = groupid.x;
+    let cCol: u32 = groupid.y;
+    let numThread: u32 = ({{BM}} * {{BN}}) / ({{TM}} * {{TN}});
+
+    // position of the first c element computed by the thread
+    let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
+    let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
+
+    let numIterA: u32 = {{BM}} * {{BK}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));
+    let numIterB: u32 = {{BK}} * {{BN}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));
+
+    // aPtr and bPtr are the starting positions of the tiles in a and b,
+    // incremented in the bkidx loop.
+    // cPtr is the starting position of the tile in c which is fixed.
+
+    var aPtr = cRow * {{BM}} * {{K}};
+    var bPtr = cCol * {{BN}} * {{K}};
+    let cPtr = cRow * {{BM}} * {{N}} + cCol * {{BN}};
+
+    for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
+
+      // Load tile
+      // Load BM x BK by numThread(BM * BN / (TM * TN))
+      // The number of iteration == BM * BK / (BM * BN / (TM * TN))
+      for (var i: u32 = 0; i < numIterA; i++) {
+        let loadColA: u32 = (localID.x + i * numThread) % {{BK}};
+        let loadRowA: u32 = (localID.x + i * numThread) / {{BK}};
+        tileA[loadRowA * {{BK}} + loadColA] = a[aPtr + loadRowA * {{K}} + loadColA];
+      }
+      // Load BK x BN by numThread(BM * BN / (TM * TN))
+      // The number of iteration == BK * BN / (BM * BN / (TM * TN))
+      for (var i: u32 = 0; i < numIterB; i++) {
+        let loadColB: u32 = (localID.x + i * numThread) % {{BK}};
+        let loadRowB: u32 = (localID.x + i * numThread) / {{BK}};
+        tileB[loadRowB * {{BK}} + loadColB] = b[bPtr + loadRowB * {{K}} + loadColB];
+      }
+
+      aPtr += {{BK}};
+      bPtr += {{BK}};
+
+      workgroupBarrier();
+      // Compute tile
+      for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
+        for (var i: u32 = 0; i < {{TM}}; i++) {
+          localM[i] = tileA[(threadRow + i) * {{BK}} + dotIdx];
+        }
+        for (var i: u32 = 0; i < {{TN}}; i++) {
+          localN[i] = tileB[(threadCol + i) * {{BK}} + dotIdx];
+        }
+        for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
+          for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
+            threadResults[resIdxM * {{TN}} + resIdxN] += localM[resIdxM] * localN[resIdxN];
+          }
+        }
+      }
+      workgroupBarrier();
+    }
+
+    for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
+      for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
+        c[cPtr + (threadRow + resIdxM) * {{N}} + threadCol + resIdxN] = threadResults[resIdxM * {{TN}} + resIdxN];
+      }
+    }
+}
+)";
+
+/**
+ * Fills in the placeholders of the 2D block-tiling matmul shader template.
+ *
+ * @param shaderTemplate Shader source containing {{...}} placeholders.
+ * @param M Substituted for {{M}}; must be a multiple of BM.
+ * @param K Substituted for {{K}}; must be a multiple of BK.
+ * @param N Substituted for {{N}}; must be a multiple of BN.
+ * @param BM Substituted for {{BM}}; must be a multiple of TM.
+ * @param BK Substituted for {{BK}}.
+ * @param BN Substituted for {{BN}}; must be a multiple of TN.
+ * @param TM Substituted for {{TM}}.
+ * @param TN Substituted for {{TN}}.
+ * @param workgroupSize Substituted for {{workgroupSize}} and stored in the
+ *        returned ShaderCode.
+ * @param precision Substituted for {{precision}}.
+ * @return ShaderCode built from the substituted source and workgroupSize.
+ */
+inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M,
+                                const size_t K, const size_t N, const size_t BM,
+                                const size_t BK, const size_t BN,
+                                const size_t TM, const size_t TN,
+                                const Shape &workgroupSize = {256, 1, 1},
+                                NumType precision = kf32) {
+  assert(BM > 0 && BK > 0 && BN > 0 && TM > 0 && TN > 0);
+  assert(BM % TM == 0);
+  assert(BN % TN == 0);
+  assert(K % BK == 0);
+  assert(M % BM == 0);
+  assert(N % BN == 0);
+  // The shader spreads each tile load over (BM * BN) / (TM * TN) threads using
+  // integer division, so both tile sizes must be whole multiples of that
+  // thread count; otherwise trailing tile elements would never be loaded.
+  [[maybe_unused]] const size_t numThreads = (BM * BN) / (TM * TN);
+  assert((BM * BK) % numThreads == 0);
+  assert((BK * BN) % numThreads == 0);
+  std::string codeString(shaderTemplate);
+  replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
+                          {"{{precision}}", toString(precision)},
+                          {"{{M}}", toString(M)},
+                          {"{{K}}", toString(K)},
+                          {"{{N}}", toString(N)},
+                          {"{{BM}}", toString(BM)},
+                          {"{{BK}}", toString(BK)},
+                          {"{{BN}}", toString(BN)},
+                          {"{{TM}}", toString(TM)},
+                          {"{{TN}}", toString(TN)}});
+  return ShaderCode{codeString, workgroupSize};
+}
+
 inline ShaderCode createNoOp(const char *shaderTemplate,
                              const Shape &workgroupSize = {256, 1, 1},
                              NumType precision = kf32) {
@@ -304,6 +448,26 @@ Kernel selectMatmul(Context &ctx, int version,
     kernel = createKernel(ctx, matmul, bindings,
                           /*nWorkgroups*/ nWorkgroups);
   } else if (version == 4) {
+    // 2D block-tiling: 64 x 64 output tiles, K consumed 16 at a time, and a
+    // (BM / BK) x (BN / BK) sub-tile per thread.
+    static constexpr size_t BM = 64;
+    static constexpr size_t BK = 16;
+    static constexpr size_t BN = 64;
+    static constexpr size_t TM = BM / BK;
+    static constexpr size_t TN = BN / BK;
+    Shape wgSize = {(BM / TM) * (BN / TN), 1, 1}; // This is the same as BK * BK.
+    Shape nWorkgroups = {cdiv(M, BM), cdiv(N, BN), 1};
+    // %zu is the printf-style conversion for size_t (assumes LOG forwards to a
+    // printf-style formatter and that M, K, N are size_t - TODO confirm).
+    LOG(kDefLog, kInfo, "M: %zu, K: %zu, N: %zu", M, K, N);
+    LOG(kDefLog, kInfo, "BM: %zu, BK: %zu, BN: %zu, TM: %zu, TN: %zu", BM, BK, BN, TM, TN);
+    LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
+    LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
+    ShaderCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
+                                      /*wgSize*/ wgSize);
+    kernel = createKernel(ctx, matmul, bindings,
+                          /*nWorkgroups*/ nWorkgroups);
+  } else if (version == 5) {
     Shape wgSize = {256, 1, 1};
     Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
     ShaderCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
@@ -371,10 +535,15 @@ void runTest(int version, size_t M, size_t K, size_t N,
 }
 
 int main() {
-  int version = 3; // 1 == naive matmul
-                   // 2 == tiling
-                   // 3 == 1D blocktiling
-                   // 4 == No-Op
+  // MATMUL_VERSION selects the kernel; version 3 is used when it is unset.
+  const char *version_str = std::getenv("MATMUL_VERSION");
+  int version = version_str == nullptr ? 3 : std::atoi(version_str);
+  // 1 == naive matmul
+  // 2 == tiling
+  // 3 == 1D blocktiling
+  // 4 == 2D blocktiling
+  // 5 == No-Op
+
   size_t M, K, N; // Matrix dimensions
   static constexpr int kTestSize = 2;
   if constexpr (kTestSize == 0) {