Merged
498 commits
6387433
[FIX] Allow m_block_size == 192 and mma_pv_is_rs == False in Sm90 CuT…
reubenconducts Sep 2, 2025
afc97c6
make FA3 compatible with CUDA 13 Builds (#1860)
johnnynunez Sep 4, 2025
dfb6649
[BUILD] SBSA wheels + CUDA 13 Support (#1865)
johnnynunez Sep 5, 2025
e8c7344
benchmark: qualify all attention backends by methods list (#1881)
rajesh-s Sep 12, 2025
b3846b0
ABI stable fa3 (#1791)
mikaylagawarecki Sep 12, 2025
7bdb426
[NVIDIA] Enable Blackwell Family Specific (#1882)
johnnynunez Sep 12, 2025
e980f0f
fix typo in flops calculation for local attention (#1883)
henrylhtsang Sep 13, 2025
2cc6fd6
flash-attn-cute bwd sm90 (#1868)
tzadouri Sep 13, 2025
8ecf128
[Cute] Make testing utils standlone for cute (#1892)
drisspg Sep 17, 2025
589cc20
Bump pin for CuTeDSL (#1891)
drisspg Sep 17, 2025
5c1627a
Improve causal backward determinism perf with SPT schedule (#1893)
jayhshah Sep 17, 2025
1ceaa98
Upgrade to cutlass v4.2.1 (#1905)
johnnynunez Sep 23, 2025
3b24b08
switch to use cutlass.utils.get_smem_capacity_in_bytes instead of dep…
brandon-yujie-sun Sep 24, 2025
0165c96
Add Missing None Gradient in FA3 QKVPacked (#1908)
JackCharlesZhang Sep 24, 2025
add1756
C++11 fix warnings (#1904)
johnnynunez Sep 25, 2025
cc0a79b
[Cute] Write ex2 emulation in a more readable form
tridao Sep 27, 2025
5059fd5
[Cute] Simplify utils.py a bit
tridao Sep 27, 2025
c485eea
[Cute] Remove arith & vector import in utils.py
tridao Oct 1, 2025
cbd2490
[CuteDSL] Fix test (#1925)
drisspg Oct 7, 2025
5183de4
Refactors to enable FlexAttention (#1840)
drisspg Oct 8, 2025
a38d69d
[Cute] Fix softmax for cutlass-dsl==4.2.1
tridao Oct 11, 2025
437b35a
[Cute] Fix softmax for fwd_sm100
tridao Oct 12, 2025
ea03e06
[Cute,Bwd] Simplify bwd_preprocessing kernel
tridao Oct 12, 2025
fbdba01
[Cute,Fwd,Sm90] Simplify by passing around functions
tridao Oct 12, 2025
b528f4b
[Cute,Fwd,Sm90] Simplify score mode by passing around partial fn
tridao Oct 12, 2025
13f2077
[Cute] Optionally dump cubin and sass
tridao Oct 12, 2025
c172985
[Cute,Fwd,Sm90] Rename m_block_size->tile_m, n_block_size->tile_n
tridao Oct 12, 2025
9eee089
[Cute,Bwd,Sm90] Format file w ruff
tridao Oct 12, 2025
42e4e3e
[Cute,Bwd,Sm90] Fix bwd dK & dV, more async
tridao Oct 13, 2025
093b935
[Cute,Bwd,Sm90] Use cp.async.bulk instead of TMA for LSE & dPsum
tridao Oct 13, 2025
9be4a62
[Cute,Bwd,Sm90] Use 1 barrier for loading both K & V
tridao Oct 13, 2025
5576480
[Cute,Bwd,Sm90] Don't clear dK & dV, use zero_init mma flag instead
tridao Oct 13, 2025
5a5a65b
[Cute,Bwd,Sm90] Use TMA to store dK & dV
tridao Oct 13, 2025
66fd2a4
[Cute,Bwd,Sm90] Load K together w Q & LSE in the first iteration
tridao Oct 13, 2025
35384ec
[Cute,Sm90] Move gemm helper functions to hopper_helpers.py
tridao Oct 13, 2025
7c0e373
Swap masking to not use R2P
imbr92 Oct 13, 2025
60eb1ea
Pre-indent to make commit diffs readable
imbr92 Oct 13, 2025
25f5d09
Adding varlen support + tests
imbr92 Oct 13, 2025
b4e5896
Remove self refs in softmax for loop (#1924)
kevin-tong-augment Oct 13, 2025
13afe0d
[Cute,Bwd,Sm90] Make postprocessing kernel work
tridao Oct 13, 2025
d2c8a6c
[Cute] Run ruff format on bwd files
tridao Oct 13, 2025
ee3a533
[CI] Add pre-commit GH action
tridao Oct 13, 2025
93e433b
[Cute,Bwd,Sm90] Try dO_stage=1, PdS_stage=1
tridao Oct 14, 2025
57d0ce9
[Cute,Bwd,Sm90] Make causal work
tridao Oct 14, 2025
89b94f8
[Cute,Bwd,Sm90] Implement dQ_swapAB
tridao Oct 14, 2025
54d8aa6
[Cute,Bwd,Sm90] Implement SdP_swapAB
tridao Oct 14, 2025
72b793a
[AMD] Torch Compile Issues (#1756)
micmelesse Oct 14, 2025
5685ace
[Cute,Bwd,Sm90] Implement mma_dkv_is_rs
tridao Oct 14, 2025
a76e692
[Cute,Bwd,Sm90] Use block size 80x128
tridao Oct 14, 2025
6bc3d1f
[CUTE] Enable Pack GQA for score mods (#1937)
drisspg Oct 15, 2025
04adaf0
Add precommit list and then uncomment in chunks (#1941)
drisspg Oct 15, 2025
48ecd14
[ROCm] prepare CK sources for pytorch hipify v2 APIs (#1944)
jeffdaily Oct 18, 2025
cc843a2
[Cute] Add flake8 config file
tridao Oct 18, 2025
c712d43
[Cute,Fwd,Sm90] Load Q & K using the same mbarrier
tridao Oct 18, 2025
752c263
[Cute,Bwd,Sm90] Use the same producer states if Q_stage == dO_stage
tridao Oct 18, 2025
71ec343
[Cute,Bwd,Sm90] Split sdQaccum layout into 2 warp groups
tridao Oct 18, 2025
7a3a8fe
[Cute,Bwd,Sm90] Implement masking
tridao Oct 19, 2025
75fcbf2
[Cute,Fwd,Sm100] Parse swizzle from pointer, don't need to pass in
tridao Oct 19, 2025
b5e9a71
[Cute,Fwd,Sm100] Clean up
tridao Oct 19, 2025
b4fac7d
[Cute,Fwd,Sm100] Clean up mask
tridao Oct 19, 2025
9c14873
[Cute] Reformat blackwell_helpers.py, block_info.py
tridao Oct 19, 2025
aae355e
[Cute] Format mma_sm100_desc.py, seqlen_info.py
tridao Oct 19, 2025
83eb8d6
sm100 bwd add kernel and update postprocess mask and barriers (#1945)
tzadouri Oct 19, 2025
5fa6e8d
[Cute,Bwd,Sm100] Format flash_bwd_sm100.py and flash_bwd_postprocess
tridao Oct 19, 2025
498bfe6
[Cute,Bwd,Sm100] Rename var {m,n}_block_size->tile_{m,n}
tridao Oct 19, 2025
94f50b0
[Cute,Bwd,Sm100] Clean up a bit
tridao Oct 19, 2025
e925d10
add barrier module (#1946)
tzadouri Oct 19, 2025
d0d8adb
[Cute,Bwd,Sm100] Have a separate function to set up the mma
tridao Oct 19, 2025
796564d
[Cute,Bwd,Sm100] Load LSE with cpasync_bulk
tridao Oct 19, 2025
d0399b6
[Cute,Bwd,Sm100] Load dPsum with cpasync_bulk
tridao Oct 19, 2025
372f3e2
[Cute,Bwd,Sm100] Use copy_utils functions to load Q & dO
tridao Oct 19, 2025
c0c8c2d
[Cute,Bwd,Sm100] Load K & Q, V & dO in the first iteration
tridao Oct 19, 2025
7b17cd8
[Cute,Bwd,Sm100] Simplify mma by using functools.partial
tridao Oct 19, 2025
5c685ea
[Cute,Bwd,Sm100] Don't need q_dk_consumer_state
tridao Oct 19, 2025
8790c6e
[Cute,Bwd,Sm100] Simplify dQacc_reduce, don't need mbarrier
tridao Oct 20, 2025
7254904
[Cute,Bwd,Sm100] Iterate from m_block_min -> m_block_max
tridao Oct 20, 2025
2187695
[Cute,Bwd,Sm100] Try direct atomicadd rmem -> gmem
tridao Oct 20, 2025
12e1c04
[Cute,Bwd,Sm100] Combine pipeline_dK and pipeline_dV into one
tridao Oct 20, 2025
d101fa7
[Cute,Bwd,Sm100] All compute warps wait for lse_barrier
tridao Oct 20, 2025
82c9cbb
[Cute,Bwd,Sm100] sdQaccum doesn't need swizzle
tridao Oct 20, 2025
91f14ca
[Cute,Bwd,Sm100] Try gemm_ptx
tridao Oct 20, 2025
53c884b
[Cute,Bwd,Sm100] Clean up compute fn
tridao Oct 21, 2025
0f56550
[Cute,Bwd,Sm100] Combine pipeline_S and pipeline_P into 1
tridao Oct 21, 2025
22f7daa
[Cute,Bwd,Sm100] Don't shuffle LSE & dPsum, reduce state variables
tridao Oct 21, 2025
3cac07a
[Cute,Bwd,Sm100] Hardcode dS_stage = 1
tridao Oct 21, 2025
f29df7a
[Cute,Bwd,Sm100] Add option for delay tma store
tridao Oct 21, 2025
933b2c3
Fix hopper cuda 13 build (#1949)
kevmo314 Oct 21, 2025
a098f98
[CuteDSL] Fix hash function for cute.jit decorator (#1953)
drisspg Oct 21, 2025
143b0ba
Block Sparsity and Flex Attention mask mod support (#1942)
reubenconducts Oct 21, 2025
16c7f0f
cutlass v4.3.0 (#1952)
johnnynunez Oct 21, 2025
9dbed03
[Cute,Bwd,Sm100] Use CopyBulkG2SOp copy op instead of calling ptx
tridao Oct 21, 2025
1b8e1e6
[Cute,Bwd,Sm100] More cleanup
tridao Oct 22, 2025
e4d25a4
[CuTe DSL] Update "buffers" name to "aux_tensors"; fix flex bugs (#1961)
reubenconducts Oct 24, 2025
3effce8
Fix FA3 segfault with custom CUDA streams in ABI stable build (#1957)
kevmo314 Oct 24, 2025
9450df6
[Cute,Fwd,Sm100] Fix interface w score mod to get it to run
tridao Oct 24, 2025
7ef1a6f
[Cute,Sm100] In gemm ptx, add to base smem_address instead
tridao Oct 24, 2025
b3f437f
[Cute,Bwd,Sm100] Make postprocessing work, add interface
tridao Oct 25, 2025
6eb7c80
[Cute,Bwd,Sm100] Simplify layouts in compute_loop
tridao Oct 25, 2025
93a0afe
[Cute,Bwd,Sm100] Causal mask
tridao Oct 25, 2025
662cf9c
[Cute,Bwd,Sm100] Enable bwd tests
tridao Oct 25, 2025
79b9030
[Cute,Bwd] Enable bwd benchmarks
tridao Oct 25, 2025
510fe92
[Cute] Add store_shared_remote_fp32x4 util function
tridao Oct 26, 2025
b634499
[Cute,Bwd,Sm100] Tune registers
tridao Oct 26, 2025
e873ad0
[Cute,Sm100] acc_tmem_addr is Int32 instead of constexpr
tridao Oct 26, 2025
2c7177d
[Cute,Bwd,Sm100] Reduce sync
tridao Oct 26, 2025
6c56a0c
[Cute] Change utils.view_transpose back
tridao Oct 26, 2025
285bf12
[Cute,Bwd,Sm100] Remove delay_tma_store option
tridao Oct 26, 2025
c59ecd8
[Cute,Bwd,Sm100] Implement cluster
tridao Oct 26, 2025
25e6d94
[Cute] Copy benchmark util functions to cute directory
tridao Oct 27, 2025
53d3a99
[Cute,Bwd,Sm100] Use pipeline class for LSE and dPsum
tridao Oct 28, 2025
a5d545d
[Cute,Bwd,Sm100] Remove stage from sK, sV, tP, sdS
tridao Oct 28, 2025
b3f1b6a
[Cute,Bwd,Sm100] Fix wrong LSE and dPsum indexing in load
tridao Oct 28, 2025
67e8865
[Cute] Blocks tweaks (#1964)
drisspg Oct 28, 2025
7f7a497
[Cute,Bwd,Sm100] Use TS MMA for dK
tridao Oct 28, 2025
b613d9e
[Cute,Blocksparse] Group block sparse input torch tensors
tridao Oct 28, 2025
11336b7
[Cute,Bwd,Sm100] Separate mma_S and mma_dP
tridao Oct 29, 2025
419bdb7
[Cute,Bwd,Sm100] Try LPTBwdScheduler
tridao Oct 29, 2025
de1584b
[Cute,Bwd,Sm100] Try separating warps loading Q and dO
tridao Oct 29, 2025
0256114
BlockSparse Tweaks (#1970)
drisspg Oct 31, 2025
6c9eef9
[Cute] Fix main (#1982)
drisspg Nov 3, 2025
e724e25
[Cute,Fwd,Sm100] Implement SplitKV (#1940)
timmy-feng Nov 5, 2025
ad70a00
[Cute] Extract block-sparse utilities from SM80/90 (#1984)
drisspg Nov 5, 2025
c8abdd4
Enable python-3.10+ (#1998)
drisspg Nov 9, 2025
2ef346b
[Cute, Bwd, Sm100] Add GQA support (#2004)
jayhshah Nov 12, 2025
1338006
[Cute,Fwd,Sm100] fix major regression with split kv (#2006)
jayhshah Nov 12, 2025
16d78bb
[CuTe DSL] Block sparsity computation kernel (#1983)
reubenconducts Nov 12, 2025
fbf24f6
[NVIDIA] bump github actions (#1996)
johnnynunez Nov 13, 2025
5d2cd3b
[Cute,Fwd,Sm100] Support paged attention (#1999)
timmy-feng Nov 14, 2025
c7697bb
Add torch.compile support to flash attention 3
guilhermeleobas Jul 16, 2025
e1944ba
Don't return mutated variables in mha_bwd
guilhermeleobas Jul 24, 2025
a760ca3
Change fake_check flag to be opt-in; Remove build.sh and remove if-el…
guilhermeleobas Jul 25, 2025
24cc2b2
Remove print statements and update exception message
guilhermeleobas Jul 30, 2025
5e114d5
Fix flash_attn_backward_fake
guilhermeleobas Aug 6, 2025
734bc43
Add `safe_aot_autograd_check`
guilhermeleobas Aug 7, 2025
fde4bc0
Update namespace to flash_attn_3
guilhermeleobas Aug 19, 2025
ab79ae2
Add `flash_attn_forward.register_autograd`
guilhermeleobas Aug 22, 2025
6250fbe
Fix bug in `flash_attn_backward_fake`
guilhermeleobas Aug 22, 2025
1e3539e
Add support and tests for torch.export and aoti_compile_and_package
guilhermeleobas Sep 2, 2025
f174bd6
format code
guilhermeleobas Sep 3, 2025
6fe1c8c
update flash_api_stable.cpp
guilhermeleobas Sep 19, 2025
b555ac7
Fix flash_api_stable.cpp build
guilhermeleobas Oct 13, 2025
0aa4fa1
Only run schema_check if dtype is not float8_e4m3fn
guilhermeleobas Oct 13, 2025
47d7137
Correctly compute kBlockM for sm88/86/80
guilhermeleobas Oct 13, 2025
49fb775
Fix bug in boxed_mha_bwd
guilhermeleobas Oct 13, 2025
65dd580
don't run autograd_check when num_splits > 0
guilhermeleobas Nov 12, 2025
b4555bf
[Cute] Add block-sparsity support to SM100 (#1985)
drisspg Nov 18, 2025
43375aa
[Cute,Sm100,Fwd] use correction warps for epi when not using TMA (#2014)
jayhshah Nov 19, 2025
3fcde4b
Raise TypeError if out is specified when compiling _flash_attn_forward
guilhermeleobas Nov 21, 2025
052015a
add fastdivmod for oob reads in mask_mods (#2020)
drisspg Nov 21, 2025
d063b33
don't pass mask_fn to softmax_step generically (#2026)
jayhshah Nov 22, 2025
a986d01
swap order of decorators (#2029)
anakinxc Nov 24, 2025
20cda05
[Cute,Bwd,Sm100] enable deterministic mode for sm100 bwd and fix race…
jayhshah Nov 25, 2025
9194297
[NFC] Trivial fix to silence linter (#1928)
jduprat Nov 25, 2025
5cc6fa4
Add LICENSE and AUTHORS to flash_attn/cute (#2032)
jduprat Nov 25, 2025
63b66f2
[Cute] Add authors
tridao Nov 25, 2025
92ca9da
[Cute,Fwd] enable mask mod without blocksparsity (#2031)
reubenconducts Nov 25, 2025
672381f
Bump pin (#2025)
drisspg Nov 25, 2025
91ba87d
ruff all the smaller files (#2040)
drisspg Dec 2, 2025
de6a6ad
[Flash] Fix head dim 64 bwd (#2035)
drisspg Dec 2, 2025
26ba559
Add headdim64 tests (#2041)
drisspg Dec 2, 2025
59df2f9
Merge pull request #1769 from guilhermeleobas/guilhermeleobas/fa3-com…
v0i0 Dec 4, 2025
56fdf3e
[Cute,Bwd,Sm100] Add local for sm100 bwd (#2046)
jayhshah Dec 6, 2025
0d1ad61
Add hash attr to shortcut expensive check (#2048)
drisspg Dec 7, 2025
6328432
[AMD ROCm] Update to latest composable_kernel to improve performance …
rocking5566 Dec 7, 2025
c783ab2
fixing cute bwd func def (#2056)
liangel-02 Dec 9, 2025
bc0e4ac
Fix use-after-free in FA3 deterministic mode. The pytorch caching all…
skarupke Dec 12, 2025
e240e0f
[CUTE] Allow grads to be preallocated (#2065)
drisspg Dec 15, 2025
fd8d5eb
[Cute,Fwd] Extend score_mod to variable sequence length (#2043)
reubenconducts Dec 15, 2025
179f793
[CUTE] Seeing if tvvm reduces cpu overhead (#2042)
drisspg Dec 15, 2025
0a5339f
[FIRST] Fix softcap scoremod kwargs typo. (#2072)
LeoZDong Dec 16, 2025
ac9b5f1
basics working (#2070)
drisspg Dec 16, 2025
eacbc56
Blocksparse impl (#2085)
drisspg Dec 18, 2025
bba578d
Fix IMA in fwd on m boundary (#2091)
drisspg Dec 20, 2025
ceb4110
Update to dsl 3.4.3 (#2092)
drisspg Dec 22, 2025
5663adf
README for AMD ROCm (#2068)
seungrokj Dec 23, 2025
58fe37f
fix shuffle sync for pack gqa epilogue (#2097)
jayhshah Dec 24, 2025
11b32fd
improve paged cpasync
v0i0 Dec 24, 2025
d234051
Enable Thor (#2108)
johnnynunez Dec 29, 2025
4fd123e
[Cute] Add quack as dependency
tridao Dec 31, 2025
f3423a8
[Cute,Fwd,Sm90] Change PipelineTMAAsync sublass to signal per warp
tridao Jan 1, 2026
9b6dbac
Add pack-gqa support for blcoksparse impl w/ braodcasted H dim (#2098)
drisspg Jan 4, 2026
f98d345
[Cute,Fwd] improved block sparsity (#2100)
reubenconducts Jan 5, 2026
bb2efb3
[Cute] Fix minor lint issue in shuffle_sync
tridao Jan 5, 2026
f472175
Misc tests that should be xfailed for now (#2127)
drisspg Jan 5, 2026
3e87e42
Update cutlass to fix undefined symbol: cuDriverGetVersion. (#2142)
HydraQYH Jan 7, 2026
3c8ca4e
[Cute,Fwd,Sm100] Support `q_stage=1` for inference (#1993)
timmy-feng Jan 8, 2026
6dd7e74
[Cute] Fix two tests that were failing (#2149)
henrylhtsang Jan 8, 2026
c15ffe3
cleanup
v0i0 Jan 8, 2026
ed6a82f
[Cute, Bwd, Sm100] Add varlen for sm100 bwd (#2150)
jayhshah Jan 9, 2026
27a3b54
block-sparse backward SM90 (#2136)
drisspg Jan 10, 2026
844b10f
score-mod backward SM90 (#2137)
drisspg Jan 10, 2026
e317aa4
[Cute] Clarify and fix subtle cachekey bug (#2143)
drisspg Jan 10, 2026
26d4ee9
[CUTE][SM100] Fix backward gqa on sm100 post mask-mod semantic change…
drisspg Jan 10, 2026
8eff546
[CUTE][SM90]Enable pack-gqa with broadcasted maskmods (#2145)
drisspg Jan 10, 2026
5d4c953
[CUTE][SM90] GQA backward non deterministic (#2158)
drisspg Jan 10, 2026
ea8f735
[Cute,Bwd,Sm100] fix seqused in varlen bwd (#2167)
jayhshah Jan 10, 2026
ef7343b
[CUTE] Bump cutedsl to 4.3.5 (#2170)
drisspg Jan 12, 2026
dbf08eb
Merge pull request #2156 from v0i0/v0i0/improve-paged-ldgsts
v0i0 Jan 12, 2026
4cb272e
[Cute,Flex] Add option to create and cache __cute_hash__ (#2171)
reubenconducts Jan 12, 2026
4894657
[Cute][Flex] Remove no longer needed contig (#2172)
drisspg Jan 12, 2026
13696f2
[Cute] update row_max before safe overwrite for online_softmax (#2174)
jayhshah Jan 13, 2026
506441a
[Cute][Flex] add back in contig (#2177)
drisspg Jan 15, 2026
68649fb
[Cute][Flex]Add pack-gqa divmod (#2180)
drisspg Jan 15, 2026
88067b0
baseline local flops
henrylhtsang Jan 15, 2026
fffabc3
[Cute,Fwd,Sm100] distributed offset calculation for paged KV (#2104)
timmy-feng Jan 15, 2026
a512bd8
Add R2P dual bound masking for local attention
henrylhtsang Jan 15, 2026
2020964
remove benchmark result, undo changes to benchmark
henrylhtsang Jan 15, 2026
7108d1c
Add R2P dual bound masking for local attention
henrylhtsang Jan 15, 2026
e4ec1ad
switch from xor to mask_right & ~ mask_left
henrylhtsang Jan 16, 2026
ac88858
flip in_bound to out_bound
henrylhtsang Jan 16, 2026
e34d840
remove zero logic for right_s and left_s
henrylhtsang Jan 16, 2026
08e6518
remove 24 clamp
henrylhtsang Jan 16, 2026
94f0348
doc
henrylhtsang Jan 16, 2026
e94012a
lint
henrylhtsang Jan 16, 2026
2e6ae05
added back clamp to avoid "OverflowError: Python int too large to con…
henrylhtsang Jan 16, 2026
137ad8e
add comment
henrylhtsang Jan 16, 2026
2d6b146
Merge pull request #2185 from henrylhtsang/test_local_r2p
v0i0 Jan 17, 2026
a0f9f41
[Cute][Flex] Fix expanded tensor bug (#2189)
drisspg Jan 17, 2026
04e6ee1
[Cute, SM90] fix fwd varlen Cute implementation bug for H100 (#2194)
KareemMusleh Jan 20, 2026
f15ccf5
reduce chance of build oom (#2079)
Qubitium Jan 21, 2026
2580b5a
[Cute][Flex] Allow q_offset 1 and add block-sizes to disambiguate edg…
drisspg Jan 22, 2026
57cef6c
ci: Use 1 ninja job for cu13 (#2195)
ko3n1g Jan 24, 2026
438325c
Update README to include 'psutil' package as build requirement (#2210)
wanglc02 Jan 25, 2026
4f89246
[Flex][SM100] Replay expand fix on sm100 (#2209)
drisspg Jan 26, 2026
99589e5
[DSL] Optionally patch cute-dsl to use system's ptxas
tridao Jan 27, 2026
701ebe0
[AMD] Triton Backend for ROCm #3 (#2178)
micmelesse Jan 28, 2026
514e63c
fix compute_block_sparsity usage in benchmark_mask_mod (#2221)
zhuochenKIDD Feb 2, 2026
188643b
Fix shared-memory race (#2229)
drisspg Feb 4, 2026
ef9e6a6
Use TORCH_TARGET_VERSION over TORCH_STABLE_ONLY (#2155)
janeyx99 Feb 4, 2026
24445c0
short readme for flex flash (#2231)
v0i0 Feb 5, 2026
e2743ab
[FA3] Mark current main version as v3.0.0 stable (#2223)
lw Feb 5, 2026
f1284cf
hdim 192 smem fix (#2235)
jayhshah Feb 5, 2026
912c6c4
Add `FLASH_ATTENTION_TRITON_AMD_CONFIG_JSON` env var support (#2239)
alexheretic Feb 7, 2026
abaa878
[CUTE]Bump to Cutedsl (#2216)
drisspg Feb 8, 2026
48af662
pytest-dist round robin to gpus (#2241)
drisspg Feb 8, 2026
a804a5a
[DSL] Replace old fence with cute.arch.fence_view_async_shared()
tridao Feb 8, 2026
5a66f2c
[DSL]Replace utils.{fma,mul,add}_packed_f32x2 with cute.arch version
tridao Feb 8, 2026
d39b629
[DSL] Remove coord_offset_i64, domain_offset_i64, elem_pointer_i64
tridao Feb 8, 2026
81f2c2d
[Sm90] Use functions from quack.sm90_utils
tridao Feb 8, 2026
7edcf59
[DSL] Use cute.arch.warp_reduction_{max,sum}
tridao Feb 8, 2026
b735ef2
[Layout] Use reshape_acc_to_mn and reshape_acc_to_frgA from quack
tridao Feb 8, 2026
8dd8019
[Layout] Use quack.layout_utils.mma_partition_C_vec
tridao Feb 8, 2026
90f10fa
[DSL] Use cute.math.{exp2,log2,log}
tridao Feb 8, 2026
b9148ce
[Layout] Use layout_utils.transpose_view and select from quack
tridao Feb 8, 2026
c912a37
[Bwd,Sm90] Use quack.copy_utils
tridao Feb 8, 2026
deb1830
[Bwd,Sm100] Shorten PipelineTmaUmma create
tridao Feb 8, 2026
17d2943
[Bwd,Sm90] Have score_mod and score_mod_bwd as partial functions
tridao Feb 8, 2026
2a8d39c
[DSL] warpgroup_reg_alloc -> setmaxregister_increase
tridao Feb 8, 2026
72c7ba4
Fix Hopper tests (#2242)
drisspg Feb 8, 2026
fc9e426
Merge remote-tracking branch 'upstream/main' into merge_upstream
MatthewBonanni Feb 11, 2026
2 changes: 1 addition & 1 deletion .github/workflows/_build.yml
@@ -165,7 +165,7 @@ jobs:
# Limit MAX_JOBS otherwise the github runner goes OOM
# nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM

export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2)
export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] || [ "$MATRIX_CUDA_VERSION" == "130" ] && echo 1 || echo 2)
export NVCC_THREADS=2
export FLASH_ATTENTION_FORCE_BUILD="TRUE"
export FLASH_ATTENTION_FORCE_CXX11_ABI=${{ inputs.cxx11_abi }}
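The changed line above extends the existing `MAX_JOBS` conditional so CUDA 13.0 also gets a single build job. The selection logic can be exercised standalone (a sketch assuming bash, since the workflow's `[ … == … ]` test is a bashism):

```shell
# Same MAX_JOBS selection as the workflow, looped over a few CUDA versions.
for MATRIX_CUDA_VERSION in 124 129 130; do
  MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] || [ "$MATRIX_CUDA_VERSION" == "130" ] && echo 1 || echo 2)
  echo "cuda=$MATRIX_CUDA_VERSION jobs=$MAX_JOBS"
done
# cuda=124 jobs=2
# cuda=129 jobs=1
# cuda=130 jobs=1
```

Note the operator precedence: `||` and `&&` associate left-to-right, so either bracket test succeeding yields `echo 1`, and only both failing falls through to `echo 2`.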
71 changes: 25 additions & 46 deletions README.md
@@ -67,6 +67,7 @@ flash_attn_interface.flash_attn_func()
- CUDA toolkit or ROCm toolkit
- PyTorch 2.2 and above.
- `packaging` Python package (`pip install packaging`)
- `psutil` Python package (`pip install psutil`)
- `ninja` Python package (`pip install ninja`) *
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.

@@ -128,74 +129,52 @@ FlashAttention-2 ROCm CK backend currently supports:
3. Both forward's and backward's head dimensions up to 256.

#### Triton Backend
The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.
The Triton implementation of [Flash Attention](https://tridao.me/publications/flash2/flash2.pdf) supports AMD's CDNA (MI200, MI300) and RDNA GPUs using fp16, bf16, and fp32 datatypes. It provides forward and backward passes with causal masking, variable sequence lengths, arbitrary Q/KV sequence lengths and head sizes, MQA/GQA, dropout, rotary embeddings, ALiBi, paged attention, and FP8 (via the Flash Attention v3 interface). Sliding window attention is currently a work in progress.

It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes.

These features are supported in Fwd and Bwd
1) Fwd and Bwd with causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes
5) Multi and grouped query attention
6) Dropout
7) Rotary embeddings
8) ALiBi

We are working on the following things
1) Paged Attention
2) Sliding Window
3) FP8
4) Performance Improvements

##### Getting Started
To get started with the triton backend for AMD, follow the steps below.

First install the torch for ROCm from https://pytorch.org/get-started/locally/ if it is not installed. The torch and triton will be installed.

Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.

```
To install, first get PyTorch for ROCm from https://pytorch.org/get-started/locally/, then install Triton and Flash Attention:
```sh
pip install triton==3.5.1
cd flash-attention
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
```

To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing.
```
To run the tests (note: full suite takes hours):
```sh
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py
```

You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`
```
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE
```
For better performance, enable autotune with `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"`.

###### Docker
You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image.
Alternatively, if _not_ autotuning, `FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON` may be used to set a single triton config overriding the hardcoded defaults for `attn_fwd`. E.g.
```sh
FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON='{"BLOCK_M":128,"BLOCK_N":64,"waves_per_eu":1,"PRE_LOAD_V":false,"num_stages":1,"num_warps":8}'
```
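The override is a plain JSON object, so it round-trips through the environment as a string. A hypothetical illustration of how a consumer might read it back (the actual parsing code is not part of this diff):

```python
import json
import os

# Set the override exactly as in the README example above.
os.environ["FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON"] = (
    '{"BLOCK_M":128,"BLOCK_N":64,"waves_per_eu":1,'
    '"PRE_LOAD_V":false,"num_stages":1,"num_warps":8}'
)

# Parse it back into a dict; JSON's false becomes Python's False.
cfg = json.loads(os.environ["FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON"])
print(cfg["BLOCK_M"], cfg["PRE_LOAD_V"])  # 128 False
```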

For a quick start with Docker:
```dockerfile
FROM rocm/pytorch:latest

WORKDIR /workspace

# install flash attention
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
# install triton
RUN pip install triton==3.5.1

RUN git clone https://github.com/ROCm/flash-attention.git &&\
# build flash attention with triton backend
RUN git clone https://github.com/Dao-AILab/flash-attention &&\
cd flash-attention &&\
python setup.py install
FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install

# set working dir
WORKDIR /workspace/flash-attention
```

To build the docker file
```
docker build -t fa_triton .
# set env variable to use triton backend
ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
```

To run the docker image
```
docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton
Build and run:
```sh
docker build -t flash-attn-triton .
docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri flash-attn-triton
```

## How to use FlashAttention
2 changes: 1 addition & 1 deletion csrc/cutlass
Submodule cutlass updated 1240 files
26 changes: 26 additions & 0 deletions flash_attn/cute/README.md
@@ -0,0 +1,26 @@
# Flash Attention CUTE

## Development Installation

1. Clone the repository (if you haven't already):
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/cute
```

2. Install in editable mode with dev dependencies:
```bash
pip install -e "./cute[dev]"
```

## Running Tests

```bash
pytest tests/cute/
```

## Linting

```bash
ruff check flash_attn/cute/
```
41 changes: 13 additions & 28 deletions flash_attn/cute/block_sparse_utils.py
@@ -14,7 +14,6 @@

# Import data structures from block_sparsity
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute import utils
from flash_attn.cute import copy_utils
from flash_attn.cute.named_barrier import NamedBarrierBwd

@@ -698,14 +697,14 @@ def handle_block_sparse_empty_tile_correction_sm100(
row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
row_sum_value = Float32(1.0)
else:
row_sum_value = row_sum_value + utils.exp2f(
sink_val * LOG2_E - row_max_value * softmax_scale_log2
row_sum_value = row_sum_value + cute.math.exp2(
sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True
)
if tidx < m_block_size:
scale_row_idx = tidx + stage * m_block_size
sScale[scale_row_idx] = row_sum_value
if const_expr(mLSE is not None or learnable_sink is not None):
sScale[scale_row_idx + m_block_size * 2] = row_max_value
sScale[scale_row_idx + q_stage * m_block_size] = row_max_value
acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value
stats[stage] = (row_sum_value, row_max_value, acc_flag)

@@ -1123,8 +1122,7 @@ def _load_q_do_block_sm90(
else:
pipeline_Q.producer_acquire(producer_state_Q)
load_Q(m_block, producer_state=producer_state_Q)
with cute.arch.elect_one():
load_LSE(m_block, producer_state=producer_state_Q)
load_LSE(m_block, producer_state=producer_state_Q)

producer_state_dO_cur = (
producer_state_dO if const_expr(not Q_stage_eq_dO_stage) else producer_state_Q
@@ -1135,8 +1133,7 @@
else:
pipeline_dO.producer_acquire(producer_state_dO_cur)
load_dO(m_block, producer_state=producer_state_dO_cur)
with cute.arch.elect_one():
load_dPsum(m_block, producer_state=producer_state_dO_cur)
load_dPsum(m_block, producer_state=producer_state_dO_cur)

producer_state_Q.advance()
producer_state_dO.advance()
@@ -1253,10 +1250,10 @@ def consume_block_sparse_mma_bwd_sm90(
is_causal: cutlass.Constexpr,
is_local: cutlass.Constexpr,
thr_mma_SdP,
softmax_scale,
seqlen,
subtile_factor: cutlass.Constexpr,
m_block_max: int,
score_mod_fn=None,
score_mod_bwd_fn=None,
subtile_factor: cutlass.Constexpr = 1,
m_block_max: int = 0,
aux_tensors=None,
fastdiv_mods=(None, None),
):
@@ -1318,15 +1315,9 @@
consumer_state_Q,
consumer_state_dO,
mask_fn=mask_fn_partial,
score_mod_fn=score_mod_fn,
score_mod_bwd_fn=score_mod_bwd_fn,
dKV_accumulate=dKV_accumulate,
thr_mma_SdP=thr_mma_SdP,
batch_idx=batch_idx,
head_idx=head_idx,
n_block=n_block,
softmax_scale=softmax_scale,
seqlen=seqlen,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
dKV_accumulate = True

@@ -1342,15 +1333,9 @@
consumer_state_Q,
consumer_state_dO,
mask_fn=mask_fn_full,
score_mod_fn=score_mod_fn,
score_mod_bwd_fn=score_mod_bwd_fn,
dKV_accumulate=dKV_accumulate,
thr_mma_SdP=thr_mma_SdP,
batch_idx=batch_idx,
head_idx=head_idx,
n_block=n_block,
softmax_scale=softmax_scale,
seqlen=seqlen,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
dKV_accumulate = True

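The signature change in `consume_block_sparse_mma_bwd_sm90` replaces explicit per-call context arguments (`batch_idx`, `head_idx`, `softmax_scale`, `seqlen`) with pre-bound `score_mod_fn` / `score_mod_bwd_fn` callables. A minimal sketch of that pattern using `functools.partial` (the score function here is purely illustrative, not the kernel's actual math):

```python
from functools import partial

# Hypothetical score mod: the caller binds its context once, so the inner
# loop receives a single-argument callable instead of five parameters.
def score_mod(score, batch_idx, head_idx, softmax_scale):
    return score * softmax_scale + head_idx  # illustrative transform only

score_mod_fn = partial(score_mod, batch_idx=0, head_idx=2, softmax_scale=0.5)
print(score_mod_fn(10.0))  # 7.0
```

Threading one partial through the pipeline keeps the consumer-loop signature stable when new score-mod parameters are added.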
151 changes: 151 additions & 0 deletions flash_attn/cute/cute_dsl_ptxas.py
@@ -0,0 +1,151 @@
"""
System ptxas replacement for CUTLASS DSL.
Environment variables:
CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas)
CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output
"""

import os
import sys
import re
import ctypes
import subprocess
from pathlib import Path

import cutlass


CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None)
VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1"

_original_load_cuda_library = None
_user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1


def _log(msg):
if VERBOSE:
print(f"[ptxas] {msg}", file=sys.stderr)


def _get_ptx(compiled_func) -> tuple[str, Path] | None:
"""Find and read PTX file, stripping null bytes."""
func_name = getattr(compiled_func, "function_name", None)
if not func_name:
return None

dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd())
for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"):
content = ptx_path.read_text().rstrip("\x00")
if ".entry " in content and content.rstrip().endswith("}"):
_log(f"Found PTX: {ptx_path}")
return content, ptx_path
return None


def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes:
"""Compile PTX to cubin using system ptxas."""
# Extract arch from PTX
match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx_content)
arch = match.group(1) if match else "sm_90a"

# Write stripped content back if needed
if ptx_path.read_text() != ptx_content:
ptx_path.write_text(ptx_content)

# Compile
cubin_tmp = ptx_path.with_suffix(".cubin.tmp")
try:
assert CUTE_DSL_PTXAS_PATH is not None
result = subprocess.run(
[CUTE_DSL_PTXAS_PATH, f"-arch={arch}", "-O3", "-o", str(cubin_tmp), str(ptx_path)],
capture_output=True,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"ptxas failed: {result.stderr}")

cubin_data = cubin_tmp.read_bytes()
_log(f"Compiled {ptx_path.name} -> {len(cubin_data)} bytes ({arch})")

# Save cubin if CUTE_DSL_KEEP_CUBIN is set
if os.environ.get("CUTE_DSL_KEEP_CUBIN", "0") == "1":
cubin_out = ptx_path.with_suffix(".cubin")
cubin_out.write_bytes(cubin_data)
_log(f"Saved: {cubin_out}")

return cubin_data
finally:
cubin_tmp.unlink(missing_ok=True)


def _patched_load_cuda_library(self):
"""Replacement for _load_cuda_library that uses system ptxas."""

result = _get_ptx(self)
if not result:
_log("PTX not found, falling back to embedded ptxas")
return _original_load_cuda_library(self)

ptx_content, ptx_path = result

try:
cubin = _compile_ptx(ptx_path, ptx_content)
except Exception as e:
_log(f"Compilation failed ({e}), falling back to embedded ptxas")
return _original_load_cuda_library(self)

# Load cubin
import cuda.bindings.runtime as cuda_runtime

err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0)
if err != cuda_runtime.cudaError_t.cudaSuccess:
_log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas")
return _original_load_cuda_library(self)

# Register kernels on all devices
_, cuda_load_to_device = self._get_cuda_init_and_load()
lib_ptr = ctypes.c_void_p(int(library))
dev_id = ctypes.c_int32(0)
err_val = ctypes.c_int32(0)
args = (ctypes.c_void_p * 3)(
ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p),
ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p),
ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p),
)

for dev in range(self.num_devices):
dev_id.value = dev
cuda_load_to_device(args)
if err_val.value != 0:
_log("cuda_load_to_device failed, falling back to embedded ptxas")
return _original_load_cuda_library(self)

_log(f"Loaded kernel from {ptx_path.name}")

# Delete PTX if user didn't originally want it kept
if not _user_wanted_ptx:
ptx_path.unlink(missing_ok=True)

return [cuda_runtime.cudaLibrary_t(lib_ptr.value)]


def patch():
"""Install system ptxas hook. Call before importing cutlass."""
global _original_load_cuda_library, _user_wanted_ptx

assert CUTE_DSL_PTXAS_PATH is not None
if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK):
raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}")

# Track if user originally wanted PTX kept
_user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1"
# os.environ['CUTE_DSL_KEEP_PTX'] = '1'
assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", (
"Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas"
)

cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction
_original_load_cuda_library = cls._load_cuda_library
cls._load_cuda_library = _patched_load_cuda_library
_log("Patch applied")
return
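The arch-detection step in `_compile_ptx` above is a small self-contained piece: it pulls the `sm_XX` target out of the PTX text and falls back to a default when no `.target` directive is present. A standalone sketch using the same regex (the helper name `extract_arch` is introduced here for illustration only):

```python
import re

# Same pattern as _compile_ptx: capture e.g. sm_90a, sm_100a from a
# ".target" directive; fall back to a default if none is found.
def extract_arch(ptx: str, default: str = "sm_90a") -> str:
    m = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx)
    return m.group(1) if m else default

print(extract_arch(".version 8.3\n.target sm_100a\n.address_size 64"))  # sm_100a
print(extract_arch("no target directive here"))  # sm_90a
```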