[CUDA ] Write an optimized flash_attn_stream_k_fixup kernel#21159
gaugarg-nv wants to merge 2 commits into ggml-org:master
Conversation
Write a specialized and more optimized kernel for cases where nblocks_stream_k is a multiple of ntiles_dst. Make nblocks_stream_k a multiple of ntiles_dst if nblocks_stream_k > 2 * ntiles_dst
…make sure we have enough concurrency on GPUs
I changed the threshold for the new kernel from `2 * ntiles_dst` to `4 * ntiles_dst`. Updated the PR description also.
JohannesGaessler
left a comment
Some historical context: I had initially written the mma FA kernel and the fixup with naive integer division/modulo. I later learned from @ORippler that you can significantly speed up these integer divisions via precomputations in host code, as is now done in fastdiv defined in common.cuh. However, I had put off adding this optimization to the FA code because it's currently being rewritten: I have already consolidated the tile and vector kernels, AMD support is being added to the mma kernel, and after that the WMMA kernel will be removed.
```cpp
const int tid = threadIdx.x;

// nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
const int blocks_per_tile = nblocks_stream_k / gridDim.x;
```
This operation can make use of fastdiv as defined in common.cuh.
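The precomputation idea behind fastdiv can be sketched on the host side as follows. This is a minimal sketch of the magic-multiplier technique (Lemire-style: `M = ceil(2^64 / d)`, then `n / d` is the high 64 bits of `n * M`), not the actual fastdiv API from common.cuh; the names `FastDiv`, `fast_div`, and `fast_mod` are made up for illustration:

```cpp
#include <cassert>
#include <cstdint>

// Precomputed "magic" multiplier for dividing 32-bit numerators by a fixed d.
// Host code computes this once per divisor; device code can then replace each
// `/ d` and `% d` with a wide multiply and shift instead of a hardware divide.
// Requires 2 <= d (for d == 1 the magic value wraps to 0 and the scheme breaks).
struct FastDiv {
    uint64_t magic;
    explicit FastDiv(uint32_t d) : magic(~uint64_t(0) / d + 1) {} // ceil(2^64 / d)
};

// n / d for any 32-bit n: high 64 bits of the 128-bit product n * magic.
static inline uint32_t fast_div(uint32_t n, const FastDiv & fd) {
    return uint32_t(((unsigned __int128) n * fd.magic) >> 64);
}

// n % d: multiply the (deliberately wrapping) low 64 bits of n * magic by d
// and take the high 64 bits.
static inline uint32_t fast_mod(uint32_t n, const FastDiv & fd, uint32_t d) {
    const uint64_t lowbits = fd.magic * n; // wraps mod 2^64 on purpose
    return uint32_t(((unsigned __int128) lowbits * d) >> 64);
}
```

On the device, the 128-bit products would typically map to `__umul64hi`; the key point is that the divisor-dependent work happens once on the host.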
```cpp
const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);

const int gqa_ratio = ne02 / ne12;
```
Preferably precompute this in host code.
```cpp
const int sequence = tile_idx / (iter_j*iter_z_gqa*ne12);
const int z_KV     = (tile_idx - iter_j*iter_z_gqa*ne12 * sequence) / (iter_j*iter_z_gqa);
const int zt_gqa   = (tile_idx - iter_j*iter_z_gqa*ne12 * sequence - iter_j*iter_z_gqa * z_KV) / iter_j;
const int jt       =  tile_idx - iter_j*iter_z_gqa*ne12 * sequence - iter_j*iter_z_gqa * z_KV - iter_j * zt_gqa;
```
Preferably use fastdiv.
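The chained subtractions above implement a row-major decomposition of `tile_idx` into `(sequence, z_KV, zt_gqa, jt)`, which is exactly where fastdiv would replace the divisions. A host-side sketch with plain div/mod (the field names mirror the kernel; the `decompose` helper and the concrete sizes are made up for illustration):

```cpp
#include <cassert>

// tile_idx enumerates (sequence, z_KV, zt_gqa, jt) in row-major order:
// tile_idx = ((sequence*ne12 + z_KV)*iter_z_gqa + zt_gqa)*iter_j + jt
struct TileCoords { int sequence, z_KV, zt_gqa, jt; };

static TileCoords decompose(int tile_idx, int iter_j, int iter_z_gqa, int ne12) {
    TileCoords c;
    c.sequence = tile_idx / (iter_j * iter_z_gqa * ne12);
    int rem    = tile_idx % (iter_j * iter_z_gqa * ne12);
    c.z_KV     = rem / (iter_j * iter_z_gqa);
    rem        = rem % (iter_j * iter_z_gqa);
    c.zt_gqa   = rem / iter_j;
    c.jt       = rem % iter_j;
    return c;
}
```

Each `/` and `%` pair here is a candidate for a precomputed fastdiv, since `iter_j`, `iter_z_gqa`, and `ne12` are known on the host at launch time.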
```cpp
// (blocks_num.x not a multiple of ntiles_dst)
template <int D, int ncols1, int ncols2>
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup_fallback(
```
Please change the suffix for this kernel to _general and add a suffix of _uniform for the special case.
```cpp
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);

// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks.
// Do this only if nblocks_stream_k_raw is at least 4x ntiles_dst to avoid excessive loss of occupancy.
const int nblocks_stream_k = nblocks_stream_k_raw > 4 * ntiles_dst
```
I don't think this is the correct heuristic. Again, for an infinitely long context we would never want to reduce the number of CUDA blocks, especially for Deepseek, where there are cases in which we can fit only a single block / SM. I think a better heuristic would be to always round down for parallel_blocks > 2, since there the impact should be negligible. For parallel_blocks == 2 and parallel_blocks == 1, check the minimum slice of K->ne[1] that a CUDA block would need to make rounding down worthwhile.
Many of the models I have listed above in the perf section actually have parallel_blocks of 2 or 1. I will try to collect more data and see at which sequence length we start seeing a negative speed-up. I can change the heuristic accordingly.
The reason for the nblocks_stream_k_raw > 4 * ntiles_dst heuristic was to reduce the tail effect and keep occupancy high. I can try increasing the threshold to a larger value like 16 * ntiles_dst, making sure low batch sizes continue to benefit from this change. This will reduce the number of SMs left idle by the tail effect.
Which approach do you prefer?
For example, with nblocks_stream_k_raw > 16 * ntiles_dst, the maximum reduction in the number of blocks will be 5.8% (0.99/16.99).
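The 0.99/16.99 figure can be checked with a small brute-force sketch: with threshold `T * ntiles_dst`, the worst case is a raw block count just below `(T+1) * ntiles_dst`, which rounds down to `T * ntiles_dst` and discards almost a full `ntiles_dst` worth of blocks. The `worst_loss` helper and the concrete tile counts below are made up for illustration:

```cpp
#include <cassert>

// Worst-case relative block loss when rounding nblocks_stream_k_raw down to a
// multiple of ntiles_dst, with rounding applied only above threshold*ntiles_dst.
static double worst_loss(int ntiles_dst, int threshold, int max_raw) {
    double worst = 0.0;
    for (int raw = threshold * ntiles_dst + 1; raw <= max_raw; ++raw) {
        const int rounded = (raw / ntiles_dst) * ntiles_dst;
        const double loss = double(raw - rounded) / raw;
        if (loss > worst) worst = loss;
    }
    return worst;
}
```

With `ntiles_dst = 100` and `threshold = 16`, the worst case is `raw = 1699` rounding down to 1600, a loss of 99/1699 ≈ 5.8%, matching the 0.99/16.99 figure above; with `threshold = 4` the worst case grows to roughly 1/5 of the blocks.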
I would prefer that you try to find the thresholds for parallel_blocks == 1, parallel_blocks == 2, and parallel_blocks > 2 within context window sizes that people would frequently use. Alternatively, try to define the maximum acceptable efficiency loss from rounding down and fuse this with the logic that already exists for tiling on Ampere.
This is a follow-up to PR #21086.

The observation was that `flash_attn_stream_k_fixup` takes significant time if `nblocks_stream_k` is significantly larger than `ntiles_dst`. The reason for this was that `flash_attn_stream_k_fixup` launches too many blocks, with either redundant or no work for many of the blocks.

Based on the idea from @JohannesGaessler at #21086 (comment), I have written a specialized and more optimized kernel for cases where `nblocks_stream_k` is a multiple of `ntiles_dst`. This PR also makes `nblocks_stream_k` a multiple of `ntiles_dst` if `nblocks_stream_k > 4 * ntiles_dst`. I'm seeing significant perf improvement with BS=1,2,4,8,16 and no change for BS=512.
Performance on RTX Pro 6000 Blackwell