[gfx1201] Add tuned kernel configs and FP8 attention support for AMD RDNA4 (Radeon AI PRO R9700)#2242

Draft
vllmellm wants to merge 9 commits into ROCm:main from EmbeddedLLM:rdna4-ck

Conversation

@vllmellm
Contributor

Motivation

Following the correctness fixes in #1681 (ISA patches for v_pk_mul_f32, DPP broadcast, buffer_load_lds, and RMSNorm kernel operand syntax), this PR adds the tuning configs and FP8 attention enablement needed to run performant FP8 inference on gfx1201 (AMD Radeon AI PRO R9700) via vLLM.

Evaluated on Qwen3-0.6B-FP8 (dense) and Qwen3-30B-A3B-FP8 (MoE): up to +33% total token throughput and a +25% TPOT improvement over the vLLM default.

Technical Details

1. FP8 Triton Attention (aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py)

Added "gfx1201" to the FP8_ARCHS frozenset. Without this, the flash attention Triton dispatcher silently falls back to BF16/FP16 arithmetic even when FP8 tensors are provided, missing the FP8 fast-path.
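The gating logic can be sketched as follows. This is a minimal illustration, not aiter's exact code: the other arch entries and the `use_fp8_path` helper are hypothetical; only the `gfx1201` addition to `FP8_ARCHS` comes from this PR.

```python
# Illustrative sketch of arch-gated FP8 dispatch. The non-gfx1201
# entries below are placeholders, not the real contents of FP8_ARCHS.
FP8_ARCHS = frozenset({"gfx942", "gfx1201"})  # "gfx1201" is the new entry

def use_fp8_path(arch: str, inputs_are_fp8: bool) -> bool:
    # Without the arch in the allow-list, FP8 tensors are silently
    # handled by the BF16/FP16 path instead of the FP8 fast-path.
    return inputs_are_fp8 and arch in FP8_ARCHS
```

With the entry present, FP8 inputs on gfx1201 take the fast-path; without it, the same call quietly runs the higher-precision kernels.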

2. CK GEMM Tuned Configs (aiter/configs/a8w8_blockscale_tuned_gemm.csv, a8w8_blockscale_untuned_gemm.csv)

Added ~370 tuned kernel configurations for gfx1201 covering common M/N/K dimensions from Qwen3-0.6B-FP8 and Qwen3-30B-A3B-FP8. Each entry maps a tensor shape to an optimal CK (Composable Kernel) instance for block-scaled FP8 GEMM. Without these, the dispatcher has no valid mapping and falls back to suboptimal kernels or errors.
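The shape-to-kernel mapping works roughly like the sketch below. Column names and kernel-instance strings are assumptions for illustration, not aiter's exact CSV schema; the point is the lookup-with-fallback behavior described above.

```python
import csv
from io import StringIO

# Hypothetical two-row excerpt in the spirit of the tuned-GEMM CSV;
# the real files carry ~370 gfx1201 entries.
TUNED_CSV = """M,N,K,kernel_instance
128,1024,1024,ck_blockscale_inst_a
2048,6144,1024,ck_blockscale_inst_b
"""

# Build a (M, N, K) -> kernel-instance lookup table.
table = {
    (int(r["M"]), int(r["N"]), int(r["K"])): r["kernel_instance"]
    for r in csv.DictReader(StringIO(TUNED_CSV))
}

def pick_kernel(m: int, n: int, k: int, fallback: str = "generic_fallback") -> str:
    # Tuned instance if this shape was profiled, otherwise a generic kernel.
    return table.get((m, n, k), fallback)
```

Shapes absent from the table land on the fallback, which is why covering the common Qwen3 M/N/K dimensions matters for performance.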

3. Triton GEMM Configs (aiter/ops/triton/configs/gemm/gfx1201-*.json)

Added 37 Triton GEMM config files for gfx1201 across the following variants:

  • GEMM-A8W8 / GEMM-A8W8_BLOCKSCALE / GEMM-A8W8_BLOCKSCALE_PRESHUFFLED / GEMM-A8W8_PER_TOKEN_SCALE
  • GEMM-A16W16 / GEMM-A16W16-ATOMIC
  • BATCHED_GEMM-A8W8

Includes both default configs and shape-specific overrides for common N/K dimensions (1024/2048/3072/4096/6144 × 1024).
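The default-plus-override structure can be sketched as below. The JSON layout and key format (`"N_K"`) are assumptions based on the description, not the exact aiter config schema.

```python
import json

# Hypothetical Triton GEMM config file: one default tuning entry plus a
# shape-specific override keyed by "N_K".
CONFIG_JSON = """{
  "default": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64, "num_warps": 4},
  "4096_1024": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8}
}"""

cfgs = json.loads(CONFIG_JSON)

def get_config(n: int, k: int) -> dict:
    # A shape-specific override wins; anything else uses the default.
    return cfgs.get(f"{n}_{k}", cfgs["default"])
```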

Test Plan

  • Unit tests: op_tests/test_gemm_a8w8_blockscale.py, op_tests/test_gemm_a8w8.py, op_tests/triton_tests/gemm/basic/test_gemm_a8w8_blockscale.py, op_tests/triton_tests/gemm/basic/test_gemm_a8w8_per_token_scale.py on gfx1201
  • End-to-end: vLLM serving benchmark (benchmarks/benchmark_serving.py) on Qwen3-0.6B-FP8 and Qwen3-30B-A3B-FP8

Test Result

Benchmarked on AMD Radeon AI PRO R9700 (gfx1201) via vLLM serving benchmark.

Qwen3-0.6B-FP8 (dense)

Mean TTFT (s) — lower is better

| ISL/OSL | 1. Default | 2. aiter CK | 3. aiter CK + Attn | 4. aiter CK + Attn + Norm |
|---|---|---|---|---|
| 1024/1024 | 0.482 | 0.307 | 0.657 | 0.340 |
| 2048/2048 | 0.551 | 0.608 | 0.647 | 0.680 |
| 4096/4096 | 2.007 | 2.020 | 2.410 | 2.522 |
| 8192/1024 | 12.306 | 10.451 | 12.137 | 12.294 |
| 16384/2048 | 141.984 | 130.585 | 134.655 | 136.155 |

Mean TPOT (s) — lower is better

| ISL/OSL | 1. Default | 2. aiter CK | 3. aiter CK + Attn | 4. aiter CK + Attn + Norm |
|---|---|---|---|---|
| 1024/1024 | 0.0221 | 0.0168 | 0.0158 | 0.0167 |
| 2048/2048 | 0.0319 | 0.0266 | 0.0248 | 0.0257 |
| 4096/4096 | 0.0508 | 0.0456 | 0.0436 | 0.0445 |
| 8192/1024 | 0.0624 | 0.0590 | 0.0644 | 0.0654 |
| 16384/2048 | 0.0660 | 0.0610 | 0.0612 | 0.0623 |

Total Token Throughput (tok/s) — higher is better

| ISL/OSL | 1. Default | 2. aiter CK | 3. aiter CK + Attn | 4. aiter CK + Attn + Norm |
|---|---|---|---|---|
| 1024/1024 | 2830 | 3749 | 3900 | 3766 |
| 2048/2048 | 1989 | 2376 | 2543 | 2455 |
| 4096/4096 | 1214 | 1365 | 1424 | 1395 |
| 8192/1024 | 3434 | 3753 | 3454 | 3397 |
| 16384/2048 | 1793 | 1952 | 1924 | 1896 |

Qwen3-30B-A3B-FP8 (MoE)

Mean TTFT (s) — lower is better

| ISL/OSL | 1. Default | 2. aiter CK | 3. aiter CK + Attn | 4. aiter CK + Attn + Norm |
|---|---|---|---|---|
| 1024/1024 | 0.926 | 0.886 | 0.942 | 0.940 |
| 2048/2048 | 1.565 | 1.585 | 1.575 | 1.566 |
| 4096/4096 | 5.333 | 5.387 | 5.300 | 5.291 |
| 8192/1024 | 14.066 | 13.701 | 13.112 | 13.088 |
| 16384/2048 | 144.387 | 136.012 | 131.844 | 129.665 |

Mean TPOT (s) — lower is better

| ISL/OSL | 1. Default | 2. aiter CK | 3. aiter CK + Attn | 4. aiter CK + Attn + Norm |
|---|---|---|---|---|
| 1024/1024 | 0.0378 | 0.0334 | 0.0322 | 0.0313 |
| 2048/2048 | 0.0377 | 0.0367 | 0.0333 | 0.0335 |
| 4096/4096 | 0.0456 | 0.0427 | 0.0413 | 0.0410 |
| 8192/1024 | 0.0788 | 0.0765 | 0.0761 | 0.0742 |
| 16384/2048 | 0.0715 | 0.0675 | 0.0694 | 0.0701 |

Total Token Throughput (tok/s) — higher is better

| ISL/OSL | 1. Default | 2. aiter CK | 3. aiter CK + Attn | 4. aiter CK + Attn + Norm |
|---|---|---|---|---|
| 1024/1024 | 1652 | 1868 | 1930 | 1986 |
| 2048/2048 | 1662 | 1706 | 1879 | 1869 |
| 4096/4096 | 1360 | 1451 | 1498 | 1511 |
| 8192/1024 | 2856 | 2990 | 3067 | 3133 |
| 16384/2048 | 1715 | 1845 | 1858 | 1856 |

Notes

Submission Checklist

