diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index c59a4db3999..ea71cbad8b2 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -976,10 +976,20 @@ void launch_fattn( const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst); + int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst); const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; + //Todo: need to find a thresold based on tuning + constexpr int thr_blocks_stream_k = 16; + + // try reducing the number of stream-k blocks as + // flash_attn_stream_k_fixup has a non-negligible overhead for large number of stream-k blocks + // make sure to reduce only when more than 1 block per SM is used + if(use_stream_k && nblocks_stream_k / ntiles_dst > thr_blocks_stream_k && max_blocks_per_sm > 1) { + nblocks_stream_k = nblocks_stream_k / 2; + } + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst; blocks_num.y = 1; blocks_num.z = 1;