@qianfengz qianfengz commented Dec 24, 2025

About qr_ks_vs_whole_k_prefetch pipeline

The qr_ks_vs_whole_k_prefetch pipeline is mainly intended for situations where the total number of work-groups is not enough to occupy all CUs. When the work-group count is low, using an MTile size (kM0) of 64 rather than 128 improves CU occupancy. With kM0=64, fewer registers are consumed to hold P and O, so enough vgprs are left to prefetch the whole k_tile of the next main-loop iteration, which improves performance compared to the usual kM0=128 approach.
Besides the whole-k-tile prefetch path for kM0=64, the pipeline also has a kM0=128 path, in which half of the n0_loops slices of the k tile are prefetched for the next iteration. The kM0=128 path can be used as a replacement for the qr_ks_vs_async pipeline.

What this PR does

  1. Update the pipeline policy to ensure the best mfma instructions are used on MI350
  2. Add the qr_ks_vs_whole_k_prefetch_trload pipeline instance so that V can be loaded with transposed loads on MI350, avoiding the need for many shuffle instructions
  3. Use an n0_loop to implement Gemm0 instead of the commonly used k0_loop. The n0_loop requires fewer move_tile_window() calls and removes the need to clear_tile(s_acc) in the main loop.
  4. Complete support for naive tile loading of hdim96 and hdim160, i.e. loading hdim96/hdim160 tiles without padding them to hdim128/hdim256
  5. Other fine-grained improvements (e.g. using an explicit partition_index to guarantee warp_id is allocated in a vgpr for store_tile/load_tile to/from an LDS tile_window)

Performance results

  1. For attention shapes that lead to kM0=64, qr_ks_vs_whole_k_prefetch_trload shows much better performance than qr_ks_vs_async_trload on the same case (execution time 41.02 ms with whole_k_prefetch_trload vs. 58.50 ms with async_load)
  2. For attention shapes that lead to kM0=128, qr_ks_vs_whole_k_prefetch_trload shows slightly better performance than qr_ks_vs_async on MI350 (execution time 104.50 ms with whole_k_prefetch_trload vs. 106.50 ms with qr_ks_vs_async), and the two show on-par performance on MI300

Test/Verify

  1. Use the ROCm xformers branch test_whole_k_prefetch_n0loop to test/verify the qr_ks_vs_whole_k_prefetch pipeline, since this pipeline cannot yet be exercised through the ck_tile fmha example
  2. Use the following command-line for building/testing xformers
#> git clone -b test_whole_k_prefetch_n0loop https://github.com/ROCm/xformers
#> git submodule update --init --recursive
#> pip install --no-build-isolation -e ./
#> pytest tests/test_mem_eff_attention.py::test_forward
  3. Any script that runs on xformers can be used to evaluate the qr_ks_vs_whole_k_prefetch pipeline. Use the following two environment variables to switch between pipelines:

#> export FMHA_DISABLE_SPECIAL_TREATMENT=1   # disable the FAV3 and qr_ks_vs_async_trload pipelines
#> export FMHA_DISABLE_ASYNC_PIPELINE=1      # disable the qr_ks_vs_async pipeline

Discussion

… next iteration in the non-whole-k-prefetch path
