Skip to content

Commit 9fde31b

Browse files
authored
[Cute,Sm100,Fwd] use correction warps for epi when not using TMA (Dao-AILab#2014)
* use correction warps for epi when varlen (non-TMA O)
* properly enable fallback epilogue for varlen q
* fix rebase errors
* update tests
1 parent bf3fc8a commit 9fde31b

4 files changed

Lines changed: 158 additions & 52 deletions

File tree

flash_attn/cute/block_sparse_utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads.
66
"""
77

8-
from typing import Callable
8+
from typing import Callable, Optional
99
from functools import partial
1010
import math
1111
import cutlass
@@ -606,6 +606,9 @@ def handle_block_sparse_empty_tile_correction_sm100(
606606
o_corr_consumer_phase: Int32,
607607
corr_epi_producer_phase: Int32,
608608
softmax_scale_log2: Float32,
609+
mO_cur: Optional[cute.Tensor] = None,
610+
gO: Optional[cute.Tensor] = None,
611+
gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
609612
):
610613
"""Handle the block-sparse case where a tile is fully masked:
611614
* zero staged results
@@ -650,18 +653,26 @@ def handle_block_sparse_empty_tile_correction_sm100(
650653
)
651654
cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage)
652655

653-
cute.arch.mbarrier_wait(
654-
mbar_ptr + mbar_corr_epi_empty_offset + stage,
655-
corr_epi_producer_phase,
656-
)
656+
if const_expr(gmem_tiled_copy_O is None):
657+
cute.arch.mbarrier_wait(
658+
mbar_ptr + mbar_corr_epi_empty_offset + stage,
659+
corr_epi_producer_phase,
660+
)
657661
correction_epilogue(
658662
thr_mma_pv,
659663
tOtOs[stage],
660664
tidx,
665+
stage,
666+
m_block,
667+
seqlen.seqlen_q,
661668
Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs
662669
sO[None, None, stage],
670+
mO_cur,
671+
gO,
672+
gmem_tiled_copy_O,
663673
)
664-
cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage)
674+
if const_expr(gmem_tiled_copy_O is None):
675+
cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage)
665676
cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage)
666677
cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage)
667678

flash_attn/cute/flash_fwd_sm100.py

Lines changed: 123 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@
5656
)
5757

5858

59-
# class NamedBarrierFwd(enum.IntEnum):
60-
# Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
59+
class NamedBarrierFwd(enum.IntEnum):
60+
Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
6161
# WarpSchedulerWG1 = enum.auto()
6262
# WarpSchedulerWG2 = enum.auto()
6363
# WarpSchedulerWG3 = enum.auto()
@@ -85,6 +85,7 @@ def __init__(
8585
mask_mod: cutlass.Constexpr | None = None,
8686
has_aux_tensors: cutlass.Constexpr = False,
8787
paged_kv_non_tma: bool = False,
88+
is_varlen_q: bool = False,
8889
):
8990
self.use_tma_KV = not paged_kv_non_tma
9091
# self.dtype = dtype
@@ -112,6 +113,8 @@ def __init__(
112113
self.is_persistent = is_persistent
113114
self.is_causal = is_causal
114115
self.is_local = is_local
116+
self.is_varlen_q = is_varlen_q
117+
self.use_correction_warps_for_epi = is_varlen_q
115118
self.qhead_per_kvhead = qhead_per_kvhead
116119
self.is_split_kv = is_split_kv
117120
self.pack_gqa = pack_gqa
@@ -146,8 +149,8 @@ def __init__(
146149
self.softmax1_warp_ids = (4, 5, 6, 7)
147150
self.correction_warp_ids = (8, 9, 10, 11)
148151
self.mma_warp_id = 12
149-
self.load_warp_ids = (13,)
150-
self.epilogue_warp_ids = (14,)
152+
self.epilogue_warp_ids = (13,)
153+
self.load_warp_ids = (14,)
151154
self.empty_warp_ids = (15,)
152155
SM100_TMEM_CAPACITY_COLUMNS = 512
153156
self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
@@ -164,6 +167,15 @@ def __init__(
164167
)
165168
)
166169

170+
if not self.use_tma_KV:
171+
self.load_warp_ids = (14, 15)
172+
self.empty_warp_ids = ()
173+
if self.use_correction_warps_for_epi:
174+
self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids
175+
self.epilogue_warp_ids = self.correction_warp_ids
176+
elif self.is_varlen_q: # fallback
177+
self.epilogue_warp_ids = (13, 14)
178+
167179
self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128
168180
self.tmem_o_offset = [
169181
self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded
@@ -506,19 +518,11 @@ def __call__(
506518
self.cluster_layout_vmnk.shape,
507519
)
508520
else:
509-
assert self.use_tma_O, "Loading O and K/V will contend for the empty warp."
510-
self.epilogue_warp_ids = (13,)
511-
self.load_warp_ids = (14, 15)
512-
self.empty_warp_ids = ()
513521
tma_atom_K = None
514522
tma_atom_V = None
515523

516524
o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile)
517525

518-
# print(sO_layout.outer)
519-
if const_expr(not self.use_tma_O):
520-
self.epilogue_warp_ids = (14, 15)
521-
self.empty_warp_ids = ()
522526
self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids)
523527
if const_expr(self.use_tma_O):
524528
tma_atom_O, mO = cpasync.make_tiled_tma_atom(
@@ -546,7 +550,6 @@ def __call__(
546550
assert self.m_block_size % tO_layout.shape[0] == 0
547551
vO_layout = cute.make_layout((1, async_copy_elems))
548552
gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout)
549-
print("gmem_tiled_copy_O: ", gmem_tiled_copy_O)
550553

551554
if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None):
552555
TileScheduler = SingleTileVarlenScheduler
@@ -799,7 +802,7 @@ def kernel(
799802
cute.arch.mbarrier_init(
800803
mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE
801804
)
802-
if warp_idx == 4:
805+
if const_expr(not self.use_correction_warps_for_epi) and warp_idx == 4:
803806
for i in cutlass.range_constexpr(self.q_stage):
804807
cute.arch.mbarrier_init(
805808
mbar_ptr + self.mbar_corr_epi_full_offset + i,
@@ -931,6 +934,12 @@ def kernel(
931934
if warp_idx == self.empty_warp_ids[0]:
932935
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
933936

937+
if const_expr(len(self.empty_warp_ids) > 1):
938+
if warp_idx == self.empty_warp_ids[1]:
939+
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
940+
941+
assert len(self.empty_warp_ids) <= 2
942+
934943
# ///////////////////////////////////////////////////////////////////////////////
935944
# LOAD
936945
# ///////////////////////////////////////////////////////////////////////////////
@@ -1004,19 +1013,20 @@ def kernel(
10041013
# ///////////////////////////////////////////////////////////////////////////////
10051014
# Epilogue
10061015
# ///////////////////////////////////////////////////////////////////////////////
1007-
if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
1008-
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
1009-
self.epilogue_s2g(
1010-
mO,
1011-
sO,
1012-
gmem_tiled_copy_O,
1013-
tma_atom_O,
1014-
mbar_ptr,
1015-
block_info,
1016-
num_splits,
1017-
SeqlenInfoCls,
1018-
TileSchedulerCls,
1019-
)
1016+
if const_expr(not self.use_correction_warps_for_epi):
1017+
if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
1018+
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
1019+
self.epilogue_s2g(
1020+
mO,
1021+
sO,
1022+
gmem_tiled_copy_O,
1023+
tma_atom_O,
1024+
mbar_ptr,
1025+
block_info,
1026+
num_splits,
1027+
SeqlenInfoCls,
1028+
TileSchedulerCls,
1029+
)
10201030

10211031
# ///////////////////////////////////////////////////////////////////////////////
10221032
# Softmax
@@ -1080,6 +1090,7 @@ def kernel(
10801090
mLSE,
10811091
sO,
10821092
learnable_sink,
1093+
gmem_tiled_copy_O,
10831094
tma_atom_O,
10841095
mbar_ptr,
10851096
softmax_scale_log2,
@@ -1931,6 +1942,7 @@ def correction_loop(
19311942
mLSE: cute.Tensor,
19321943
sO: cute.Tensor,
19331944
learnable_sink: Optional[cute.Tensor],
1945+
gmem_tiled_copy_O: cute.TiledCopy,
19341946
tma_atom_O: cute.CopyAtom,
19351947
mbar_ptr: cute.Pointer,
19361948
softmax_scale_log2: Float32,
@@ -1972,6 +1984,12 @@ def correction_loop(
19721984
seqlen = SeqlenInfoCls(batch_idx)
19731985
n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits)
19741986

1987+
if const_expr(self.is_split_kv):
1988+
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
1989+
else:
1990+
mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx]
1991+
gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0))
1992+
19751993
# Default LSE to -inf for invalid split_idx tiles
19761994
stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage
19771995

@@ -2070,17 +2088,25 @@ def correction_loop(
20702088
cute.arch.mbarrier_wait(
20712089
mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase
20722090
)
2073-
cute.arch.mbarrier_wait(
2074-
mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase
2075-
)
2091+
if const_expr(not self.use_correction_warps_for_epi):
2092+
cute.arch.mbarrier_wait(
2093+
mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase
2094+
)
20762095
self.correction_epilogue(
20772096
thr_mma_pv,
20782097
tOtOs[stage],
20792098
tidx,
2099+
stage,
2100+
m_block,
2101+
seqlen.seqlen_q,
20802102
scale,
20812103
sO[None, None, stage],
2104+
mO_cur,
2105+
gO,
2106+
gmem_tiled_copy_O,
20822107
)
2083-
cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage)
2108+
if const_expr(not self.use_correction_warps_for_epi):
2109+
cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage)
20842110
# Signal for the next work tile that O buffers in tmem are already read, so
20852111
# mma warp can write to them
20862112
cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage)
@@ -2090,6 +2116,11 @@ def correction_loop(
20902116
softmax_corr_consumer_phase ^= 1
20912117
corr_epi_producer_phase ^= 1
20922118
else:
2119+
# WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781
2120+
if const_expr(self.use_correction_warps_for_epi):
2121+
gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O
2122+
else:
2123+
gmem_tiled_copy_O_for_empty_tile = None
20932124
if const_expr(self.use_block_sparsity):
20942125
(
20952126
softmax_corr_consumer_phase,
@@ -2126,6 +2157,9 @@ def correction_loop(
21262157
o_corr_consumer_phase,
21272158
corr_epi_producer_phase,
21282159
softmax_scale_log2,
2160+
mO_cur,
2161+
gO,
2162+
gmem_tiled_copy_O_for_empty_tile,
21292163
)
21302164

21312165
if const_expr(mLSE is not None):
@@ -2228,8 +2262,14 @@ def correction_epilogue(
22282262
thr_mma: cute.core.ThrMma,
22292263
tOtO: cute.Tensor,
22302264
tidx: Int32,
2265+
stage: Int32,
2266+
m_block: Int32,
2267+
seqlen_q: Int32,
22312268
scale: Float32,
22322269
sO: cute.Tensor,
2270+
mO_cur: Optional[cute.Tensor] = None,
2271+
gO: Optional[cute.Tensor] = None,
2272+
gmem_tiled_copy_O: Optional[cute.TiledCopy] = None,
22332273
):
22342274
"""Apply final scaling and transformation to attention output before writing to global memory.
22352275
@@ -2302,6 +2342,57 @@ def correction_epilogue(
23022342
space=cute.arch.SharedSpace.shared_cta,
23032343
)
23042344

2345+
if const_expr(self.use_correction_warps_for_epi):
2346+
assert(not self.use_tma_O)
2347+
assert(gmem_tiled_copy_O is not None)
2348+
cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue),
2349+
number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE)
2350+
gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
2351+
tOsO = gmem_thr_copy_O.partition_S(sO)
2352+
cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded))
2353+
tOgO = gmem_thr_copy_O.partition_D(gO)
2354+
tOcO = gmem_thr_copy_O.partition_S(cO)
2355+
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
2356+
tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1])
2357+
# TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it
2358+
assert not self.pack_gqa
2359+
pack_gqa = PackGQA(
2360+
self.m_block_size,
2361+
self.head_dim_v_padded,
2362+
self.check_hdim_v_oob,
2363+
self.qhead_per_kvhead,
2364+
)
2365+
2366+
# load acc O from smem to rmem for wider vectorization
2367+
tOrO = cute.make_fragment_like(tOsO, self.o_dtype)
2368+
cute.autovec_copy(tOsO, tOrO)
2369+
# copy acc O from rmem to gmem
2370+
if const_expr(not self.pack_gqa):
2371+
for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
2372+
if (
2373+
t0OcO[0, rest_m, 0][0]
2374+
< seqlen_q
2375+
- (self.q_stage * m_block + stage) * self.m_block_size
2376+
- tOcO[0][0]
2377+
):
2378+
cute.copy(
2379+
gmem_tiled_copy_O,
2380+
tOrO[None, rest_m, None],
2381+
tOgO[None, rest_m, None, self.q_stage * m_block + stage],
2382+
pred=tOpO[None, rest_m, None]
2383+
if const_expr(self.check_hdim_v_oob)
2384+
else None,
2385+
)
2386+
else:
2387+
pack_gqa.store_O(
2388+
mO_cur,
2389+
tOrO,
2390+
gmem_tiled_copy_O,
2391+
tidx,
2392+
self.q_stage * m_block + stage,
2393+
seqlen_q,
2394+
)
2395+
23052396
@cute.jit
23062397
def epilogue_s2g(
23072398
self,
@@ -2389,7 +2480,7 @@ def epilogue_s2g(
23892480
tOrO[None, rest_m, None],
23902481
tOgO[None, rest_m, None, self.q_stage * m_block + stage],
23912482
pred=tOpO[None, rest_m, None]
2392-
if self.check_hdim_v_oob
2483+
if const_expr(self.check_hdim_v_oob)
23932484
else None,
23942485
)
23952486
else:

flash_attn/cute/interface.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,14 +464,16 @@ def _flash_attn_fwd(
464464
m_block_size=m_block_size,
465465
n_block_size=n_block_size,
466466
is_persistent=not causal
467-
and not local
468-
and cu_seqlens_q is None
469-
and seqused_q is None
470-
and not is_split_kv,
467+
and not local
468+
and cu_seqlens_q is None
469+
and seqused_q is None
470+
and not is_split_kv,
471471
score_mod=score_mod,
472472
mask_mod=mask_mod,
473473
has_aux_tensors=aux_tensors is not None,
474474
paged_kv_non_tma=page_size not in [None, 128],
475+
is_varlen_q=cu_seqlens_q is not None
476+
or seqused_q is not None,
475477
)
476478
else:
477479
raise ValueError(

0 commit comments

Comments (0)