
[CUDA] Write an optimized flash_attn_stream_k_fixup kernel #21159

Open
gaugarg-nv wants to merge 2 commits into ggml-org:master from gaugarg-nv:fa_opt

Conversation

Contributor

@gaugarg-nv gaugarg-nv commented Mar 29, 2026

This is a follow-up to PR: #21086

The observation was that flash_attn_stream_k_fixup takes a significant amount of time when nblocks_stream_k is much larger than ntiles_dst.

The reason is that flash_attn_stream_k_fixup launches too many blocks, many of which do redundant work or no work at all.

Based on the idea from @JohannesGaessler at #21086 (comment), I have written a specialized, more optimized kernel for the case where nblocks_stream_k is a multiple of ntiles_dst. This PR also makes nblocks_stream_k a multiple of ntiles_dst when nblocks_stream_k > 4 * ntiles_dst.

I'm seeing a significant performance improvement for BS=1, 2, 4, 8, 16 and no change for BS=512.

Performance on RTX Pro 6000 Blackwell
model_type  n_ubatch  n_prompt  n_depth  master avg t/s  PR avg t/s  Speed-up
qwen3moe 30B.A3B Q4_0 1 512 8192 194.2906 244.9483 1.26
qwen3moe 30B.A3B Q4_0 2 512 8192 248.2377 357.0286 1.44
qwen3moe 30B.A3B Q4_0 4 512 8192 409.2129 548.2016 1.34
qwen3moe 30B.A3B Q4_0 8 512 8192 553.7515 661.8706 1.20
qwen3moe 30B.A3B Q4_0 16 512 8192 964.083 1041.194 1.08
gpt-oss 20B MXFP4 MoE 1 512 8192 349.3694 386.0867 1.11
gpt-oss 20B MXFP4 MoE 2 512 8192 485.0375 527.4232 1.09
gpt-oss 20B MXFP4 MoE 4 512 8192 738.8452 787.8727 1.07
gpt-oss 20B MXFP4 MoE 8 512 8192 1002.743 1044.264 1.04
gpt-oss 20B MXFP4 MoE 16 512 8192 1622.006 1665.784 1.03
gpt-oss 120B MXFP4 MoE 1 512 8192 235.1368 256.2041 1.09
gpt-oss 120B MXFP4 MoE 2 512 8192 315.9809 342.3444 1.08
gpt-oss 120B MXFP4 MoE 4 512 8192 476.9255 507.3815 1.06
gpt-oss 120B MXFP4 MoE 8 512 8192 562.2907 582.2067 1.04
gpt-oss 120B MXFP4 MoE 16 512 8192 882.1752 904.8635 1.03
qwen3 4B Q4_K - Medium 1 512 8192 273.4305 273.9293 1.00
qwen3 4B Q4_K - Medium 2 512 8192 452.7331 496.2888 1.10
qwen3 4B Q4_K - Medium 4 512 8192 702.9103 849.9596 1.21
qwen3 4B Q4_K - Medium 8 512 8192 1137.738 1318.802 1.16
qwen3 4B Q4_K - Medium 16 512 8192 2215.583 2553.496 1.15
llama 8B Q4_0 1 512 8192 234.5946 234.7103 1.00
llama 8B Q4_0 2 512 8192 408.5204 439.4525 1.08
llama 8B Q4_0 4 512 8192 701.5437 827.5419 1.18
llama 8B Q4_0 8 512 8192 1183.018 1354.725 1.15
llama 8B Q4_0 16 512 8192 2035.527 2281.44 1.12
qwen3 14B Q8_0 1 512 8192 82.25566 84.28179 1.02
qwen3 14B Q8_0 2 512 8192 145.7982 158.2092 1.09
qwen3 14B Q8_0 4 512 8192 286.2245 310.4999 1.08
qwen3 14B Q8_0 8 512 8192 529.7227 570.7132 1.08
qwen3 14B Q8_0 16 512 8192 981.9385 1053.469 1.07
qwen2 32B Q4_K - Medium 1 512 8192 61.83612 63.77031 1.03
qwen2 32B Q4_K - Medium 2 512 8192 106.7554 117.7693 1.10
qwen2 32B Q4_K - Medium 4 512 8192 178.1692 193.2558 1.08
qwen2 32B Q4_K - Medium 8 512 8192 235.8052 250.7318 1.06
qwen2 32B Q4_K - Medium 16 512 8192 675.7395 729.7434 1.08
llama 70B Q4_0 1 512 8192 33.82455 34.87174 1.03
llama 70B Q4_0 2 512 8192 62.55735 67.12654 1.07
llama 70B Q4_0 4 512 8192 120.0981 128.6658 1.07
llama 70B Q4_0 8 512 8192 180.365 190.3354 1.06
llama 70B Q4_0 16 512 8192 389.64 416.2356 1.07
qwen3moe 30B.A3B Q4_0 1 512 16384 184.7588 220.5698 1.19
qwen3moe 30B.A3B Q4_0 2 512 16384 232.5748 327.1713 1.41
qwen3moe 30B.A3B Q4_0 4 512 16384 386.5917 511.5847 1.32
qwen3moe 30B.A3B Q4_0 8 512 16384 528.2092 631.2642 1.20
qwen3moe 30B.A3B Q4_0 16 512 16384 921.2011 993.2606 1.08
gpt-oss 20B MXFP4 MoE 1 512 16384 337.7098 366.6163 1.09
gpt-oss 20B MXFP4 MoE 2 512 16384 468.9916 509.2715 1.09
gpt-oss 20B MXFP4 MoE 4 512 16384 717.6101 764.9691 1.07
gpt-oss 20B MXFP4 MoE 8 512 16384 982.2338 1021.7 1.04
gpt-oss 20B MXFP4 MoE 16 512 16384 1586.696 1641.923 1.03
gpt-oss 120B MXFP4 MoE 1 512 16384 224.2185 243.4467 1.09
gpt-oss 120B MXFP4 MoE 2 512 16384 305.4747 330.9628 1.08
gpt-oss 120B MXFP4 MoE 4 512 16384 464.3444 493.7532 1.06
gpt-oss 120B MXFP4 MoE 8 512 16384 553.3267 572.2047 1.03
gpt-oss 120B MXFP4 MoE 16 512 16384 862.2424 884.1049 1.03
qwen3 4B Q4_K - Medium 1 512 16384 225.3375 225.3917 1.00
qwen3 4B Q4_K - Medium 2 512 16384 387.1468 416.7638 1.08
qwen3 4B Q4_K - Medium 4 512 16384 615.1846 723.6386 1.18
qwen3 4B Q4_K - Medium 8 512 16384 1015.441 1160.138 1.14
qwen3 4B Q4_K - Medium 16 512 16384 1972.563 2238.497 1.13
llama 8B Q4_0 1 512 16384 202.8418 202.7286 1.00
llama 8B Q4_0 2 512 16384 361.155 382.9665 1.06
llama 8B Q4_0 4 512 16384 626.8773 725.8326 1.16
llama 8B Q4_0 8 512 16384 1069.192 1207.6 1.13
llama 8B Q4_0 16 512 16384 1859.911 2068.387 1.11
qwen3 14B Q8_0 1 512 16384 77.18114 78.75773 1.02
qwen3 14B Q8_0 2 512 16384 137.3252 148.1467 1.08
qwen3 14B Q8_0 4 512 16384 269.5644 290.6831 1.08
qwen3 14B Q8_0 8 512 16384 498.5193 533.9284 1.07
qwen3 14B Q8_0 16 512 16384 903.9419 990.7183 1.10
qwen2 32B Q4_K - Medium 1 512 16384 57.27158 58.74953 1.03
qwen2 32B Q4_K - Medium 2 512 16384 99.65511 108.9368 1.09
qwen2 32B Q4_K - Medium 4 512 16384 167.5539 178.8494 1.07
qwen2 32B Q4_K - Medium 8 512 16384 225.2119 234.5024 1.04
qwen2 32B Q4_K - Medium 16 512 16384 616.3309 679.8824 1.10
llama 70B Q4_0 1 512 16384 32.09406 32.95899 1.03
llama 70B Q4_0 2 512 16384 59.43024 63.47142 1.07
llama 70B Q4_0 4 512 16384 114.3537 121.5395 1.06
llama 70B Q4_0 8 512 16384 174.2445 179.8662 1.03
llama 70B Q4_0 16 512 16384 365.266 395.44 1.08
qwen3moe 30B.A3B Q4_0 1 512 32768 156.6724 182.3001 1.16
qwen3moe 30B.A3B Q4_0 2 512 32768 208.4212 280.1442 1.34
qwen3moe 30B.A3B Q4_0 4 512 32768 351.1808 449.6801 1.28
qwen3moe 30B.A3B Q4_0 8 512 32768 491.3763 579.1859 1.18
qwen3moe 30B.A3B Q4_0 16 512 32768 859.4219 926.6355 1.08
gpt-oss 20B MXFP4 MoE 1 512 32768 310.9347 334.4992 1.08
gpt-oss 20B MXFP4 MoE 2 512 32768 441.3317 476.0534 1.08
gpt-oss 20B MXFP4 MoE 4 512 32768 684.7815 725.451 1.06
gpt-oss 20B MXFP4 MoE 8 512 32768 950.1508 991.7097 1.04
gpt-oss 20B MXFP4 MoE 16 512 32768 1522.841 1589.66 1.04
gpt-oss 120B MXFP4 MoE 1 512 32768 206.7893 222.5515 1.08
gpt-oss 120B MXFP4 MoE 2 512 32768 286.6845 308.383 1.08
gpt-oss 120B MXFP4 MoE 4 512 32768 439.5258 464.2738 1.06
gpt-oss 120B MXFP4 MoE 8 512 32768 510.1555 529.1474 1.04
gpt-oss 120B MXFP4 MoE 16 512 32768 789.1286 814.6637 1.03
qwen3 4B Q4_K - Medium 1 512 32768 167.6451 167.4462 1.00
qwen3 4B Q4_K - Medium 2 512 32768 296.955 314.7754 1.06
qwen3 4B Q4_K - Medium 4 512 32768 495.3276 562.759 1.14
qwen3 4B Q4_K - Medium 8 512 32768 844.0766 939.2867 1.11
qwen3 4B Q4_K - Medium 16 512 32768 1640.37 1810.961 1.10
llama 8B Q4_0 1 512 32768 159.2656 159.1219 1.00
llama 8B Q4_0 2 512 32768 289.212 303.5063 1.05
llama 8B Q4_0 4 512 32768 515.4942 578.7951 1.12
llama 8B Q4_0 8 512 32768 898.2365 993.4768 1.11
llama 8B Q4_0 16 512 32768 1593.617 1736.464 1.09
qwen3 14B Q8_0 1 512 32768 68.21444 69.57448 1.02
qwen3 14B Q8_0 2 512 32768 123.1306 131.6828 1.07
qwen3 14B Q8_0 4 512 32768 242.1461 258.6565 1.07
qwen3 14B Q8_0 8 512 32768 449.4426 477.5033 1.06
qwen3 14B Q8_0 16 512 32768 790.2553 887.626 1.12
qwen2 32B Q4_K - Medium 1 512 32768 49.58599 50.76794 1.02
qwen2 32B Q4_K - Medium 2 512 32768 87.80284 94.93992 1.08
qwen2 32B Q4_K - Medium 4 512 32768 151.1674 160.6833 1.06
qwen2 32B Q4_K - Medium 8 512 32768 209.9186 217.9663 1.04
qwen2 32B Q4_K - Medium 16 512 32768 536.1302 600.895 1.12
llama 70B Q4_0 1 512 32768 28.94624 29.68453 1.03
llama 70B Q4_0 2 512 32768 54.05033 57.32374 1.06
llama 70B Q4_0 4 512 32768 104.4469 110.3038 1.06
llama 70B Q4_0 8 512 32768 162.8966 167.9878 1.03
llama 70B Q4_0 16 512 32768 327.3872 360.6217 1.10
qwen3moe 30B.A3B Q4_0 512 512 8192 7188.651 7187.102 1.00
qwen3moe 30B.A3B Q4_0 512 512 16384 6023.053 6013.735 1.00
qwen3moe 30B.A3B Q4_0 512 512 32768 4591.681 4587.174 1.00
gpt-oss 20B MXFP4 MoE 512 512 8192 13157.46 13118.31 1.00
gpt-oss 20B MXFP4 MoE 512 512 16384 11713.61 11712.45 1.00
gpt-oss 20B MXFP4 MoE 512 512 32768 9623.017 9651.042 1.00
gpt-oss 120B MXFP4 MoE 512 512 8192 6355.704 6356.626 1.00
gpt-oss 120B MXFP4 MoE 512 512 16384 5891.147 5889.586 1.00
gpt-oss 120B MXFP4 MoE 512 512 32768 5115.74 5113.34 1.00
qwen3 4B Q4_K - Medium 512 512 8192 14818.83 14858.16 1.00
qwen3 4B Q4_K - Medium 512 512 16384 11262.12 11278.12 1.00
qwen3 4B Q4_K - Medium 512 512 32768 6467.29 6467.709 1.00
llama 8B Q4_0 512 512 8192 12608.22 12727.42 1.01
llama 8B Q4_0 512 512 16384 10072.89 10103.35 1.00
llama 8B Q4_0 512 512 32768 6328.296 6347.458 1.00
qwen3 14B Q8_0 512 512 8192 6340.334 6317.186 1.00
qwen3 14B Q8_0 512 512 16384 4962.728 4922.053 0.99
qwen3 14B Q8_0 512 512 32768 3182.909 3182.51 1.00
qwen2 32B Q4_K - Medium 512 512 8192 3071.667 3061.586 1.00
qwen2 32B Q4_K - Medium 512 512 16384 2476.393 2472.166 1.00
qwen2 32B Q4_K - Medium 512 512 32768 1678.167 1676.694 1.00
llama 70B Q4_0 512 512 8192 1673.669 1672.007 1.00
llama 70B Q4_0 512 512 16384 1435.91 1434.845 1.00
llama 70B Q4_0 512 512 32768 1058.678 1062.875 1.00

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: Yes, to understand the implementation details of flash_attn_stream_k_fixup kernel and for code review

Write a specialized and more optimized kernel for cases where nblocks_stream_k is multiple of ntiles_dst.
Make nblocks_stream_k to multiple of ntiles_dst if nblocks_stream_k > 2 * ntiles_dst
@gaugarg-nv gaugarg-nv requested a review from a team as a code owner March 29, 2026 19:16
@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Mar 29, 2026
@gaugarg-nv gaugarg-nv changed the title Write an optimized flash_attn_stream_k_fixup kernel [CUDA ] Write an optimized flash_attn_stream_k_fixup kernel Mar 29, 2026
…make sure we have enough concurrency on GPUs
Contributor Author

I changed the threshold for the new kernel from nblocks_stream_k > 2 * ntiles_dst to nblocks_stream_k > 4 * ntiles_dst to make sure GPU occupancy remains good. Without this, I was seeing a 3% regression for llama-8b with tensor parallelism on 2x RTX Pro Blackwell at BS=512. This change just means that large batch sizes will continue to use the older kernel.

I have also updated the PR description.

Contributor

@JohannesGaessler JohannesGaessler left a comment

Some historical context: I had initially written the mma FA kernel and the fixup with naive integer division/modulo. I later learned from @ORippler that you can significantly speed up these integer divisions via precomputation in host code, as is now done in fastdiv defined in common.cuh. However, I had put off adding this optimization to the FA code because it's currently being rewritten: I have already consolidated the tile and vector kernels, AMD support is being added to the mma kernel, and after that the WMMA kernel will be removed.

const int tid = threadIdx.x;

// nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
const int blocks_per_tile = nblocks_stream_k / gridDim.x;
Contributor

This operation can make use of fastdiv as defined in common.cuh.


const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);

const int gqa_ratio = ne02 / ne12;
Contributor

Preferably precompute this in host code.

Comment on lines +704 to +707
const int sequence = tile_idx /(iter_j*iter_z_gqa*ne12);
const int z_KV = (tile_idx - iter_j*iter_z_gqa*ne12 * sequence)/(iter_j*iter_z_gqa);
const int zt_gqa = (tile_idx - iter_j*iter_z_gqa*ne12 * sequence - iter_j*iter_z_gqa * z_KV)/iter_j;
const int jt = tile_idx - iter_j*iter_z_gqa*ne12 * sequence - iter_j*iter_z_gqa * z_KV - iter_j * zt_gqa;
Contributor

Preferably use fastdiv.

// (blocks_num.x not a multiple of ntiles_dst)
template <int D, int ncols1, int ncols2>
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup_fallback(
Contributor

Please change the suffix for this kernel to _general and add a suffix of _uniform for the special case.

const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks.
// do this only if nblocks_stream_k_raw is at least 4x ntiles_dst to avoid excessive loss of occupancy
const int nblocks_stream_k = nblocks_stream_k_raw > 4 * ntiles_dst
Contributor

I don't think this is the correct heuristic. Again, for an infinitely long context we would never want to reduce the number of CUDA blocks, especially for DeepSeek, where there are cases where we can fit only a single block per SM. I think a better heuristic would be to always round down for parallel_blocks > 2, since the impact should be negligible there. For parallel_blocks == 2 and parallel_blocks == 1, check the minimum slice of K->ne[1] that a CUDA block would need for rounding down to be worthwhile.

Contributor Author

Many of the models listed above in the perf section actually have parallel_blocks of 2 or 1. I will try to collect more data to see at which sequence length we start seeing a negative speed-up, and I can change the heuristic accordingly.

The reason for the nblocks_stream_k_raw > 4 * ntiles_dst heuristic was to reduce the tail effect and keep occupancy high. Alternatively, I could raise the threshold to a larger value such as 16 * ntiles_dst, making sure low batch sizes continue to benefit from this change; this would reduce the number of SMs left idle by the tail effect.

Which approach do you prefer?

Contributor Author

For example, with nblocks_stream_k_raw > 16 * ntiles_dst, the maximum reduction in the number of blocks would be 5.8% (0.99/16.99).

Contributor

I would prefer that you try to find the thresholds for parallel_blocks == 1, parallel_blocks == 2, and parallel_blocks > 2 within context window sizes that people frequently use. Alternatively, try to define the maximum acceptable efficiency loss from rounding down and fuse this with the logic that already exists for tiling on Ampere.
