Skip to content

Commit a246af3

Browse files
committed
comments
1 parent 807ea18 commit a246af3

6 files changed

Lines changed: 198 additions & 204 deletions

File tree

flash_attn/cute/flash_fwd.py

Lines changed: 55 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from flash_attn.cute import hopper_helpers as sm90_utils
2424
from flash_attn.cute import utils
2525
from flash_attn.cute.mask import AttentionMask
26-
from flash_attn.cute.softmax import Softmax
26+
from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
2727
from flash_attn.cute.seqlen_info import SeqlenInfoQK
2828
from flash_attn.cute.block_info import BlockInfo
2929
from flash_attn.cute import pipeline
@@ -51,6 +51,7 @@ def __init__(
5151
num_threads: int = 128,
5252
Q_in_regs: bool = False,
5353
score_mod: cutlass.Constexpr | None = None,
54+
has_buffers: bool = False,
5455
):
5556
"""Initializes the configuration for a flash attention kernel.
5657
@@ -67,7 +68,7 @@ def __init__(
6768
:type num_threads: int
6869
:param is_causal: is causal
6970
:param score_mod: A callable that takes the attention scores and applies a modification.
70-
Callable signature:
71+
Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, buffers) -> Any``
7172
"""
7273
self.dtype = dtype
7374
# padding head_dim to a multiple of 16 as k_block_size
@@ -89,6 +90,11 @@ def __init__(
8990
self.num_stages = num_stages
9091
self.Q_in_regs = Q_in_regs
9192
self.score_mod = score_mod
93+
self.qk_acc_dtype = Float32
94+
if cutlass.const_expr(has_buffers):
95+
self.vec_size: cutlass.Constexpr = 1
96+
else:
97+
self.vec_size: cutlass.Constexpr = 2
9298

9399
@staticmethod
94100
def can_implement(
@@ -934,11 +940,15 @@ def load_V_next():
934940
)
935941
if cutlass.const_expr(score_mod is not None):
936942
self.apply_score_mod(
937-
acc_S, mma_params.thr_mma_qk, score_mod,
938-
batch_idx, head_idx, m_block, n_block,
939-
softmax_scale=softmax.softmax_scale,
943+
acc_S,
944+
mma_params.thr_mma_qk,
945+
batch_idx,
946+
head_idx,
947+
m_block,
948+
n_block,
949+
softmax=softmax,
940950
buffers=buffers,
941-
fastdiv_mods=fastdiv_mods
951+
fastdiv_mods=fastdiv_mods,
942952
)
943953

944954
smem_pipe_write = self.advance_pipeline(smem_pipe_write)
@@ -1692,11 +1702,15 @@ def mma(
16921702
# Use vectorized score modification
16931703
if cutlass.const_expr(score_mod is not None):
16941704
self.apply_score_mod(
1695-
acc_S, thr_mma_qk, score_mod,
1696-
batch_idx, head_idx, m_block, n_block_max - 1,
1697-
softmax_scale=softmax.softmax_scale,
1705+
acc_S,
1706+
thr_mma_qk,
1707+
batch_idx,
1708+
head_idx,
1709+
m_block,
1710+
n_block_max - 1,
1711+
softmax=softmax,
16981712
buffers=buffers,
1699-
fastdiv_mods=fastdiv_mods
1713+
fastdiv_mods=fastdiv_mods,
17001714
)
17011715
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
17021716
mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True)
@@ -1843,11 +1857,15 @@ def mma_one_n_block(
18431857
pipeline_k.consumer_release(smem_pipe_read)
18441858
if cutlass.const_expr(score_mod is not None):
18451859
self.apply_score_mod(
1846-
acc_S, thr_mma_qk, score_mod,
1847-
batch_idx, head_idx, m_block, n_block,
1848-
softmax_scale=softmax.softmax_scale,
1860+
acc_S,
1861+
thr_mma_qk,
1862+
batch_idx,
1863+
head_idx,
1864+
m_block,
1865+
n_block,
1866+
softmax=softmax,
18491867
buffers=buffers,
1850-
fastdiv_mods=fastdiv_mods
1868+
fastdiv_mods=fastdiv_mods,
18511869
)
18521870
if const_expr(mask_fn is not None):
18531871
mask_fn(acc_S, n_block=n_block)
@@ -1923,11 +1941,15 @@ def mma_one_n_block_intrawg_overlap(
19231941
pipeline_k.consumer_release(smem_pipe_read)
19241942
if cutlass.const_expr(score_mod is not None):
19251943
self.apply_score_mod(
1926-
acc_S, thr_mma_qk, score_mod,
1927-
batch_idx, head_idx, m_block, n_block,
1928-
softmax_scale=softmax.softmax_scale,
1944+
acc_S,
1945+
thr_mma_qk,
1946+
batch_idx,
1947+
head_idx,
1948+
m_block,
1949+
n_block,
1950+
softmax=softmax,
19291951
buffers=buffers,
1930-
fastdiv_mods=fastdiv_mods
1952+
fastdiv_mods=fastdiv_mods,
19311953
)
19321954
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
19331955
if const_expr(mask_fn is not None):
@@ -1965,73 +1987,32 @@ def apply_score_mod(
19651987
self,
19661988
acc_S,
19671989
thr_mma_qk,
1968-
score_mod,
19691990
batch_idx,
19701991
head_idx,
19711992
m_block,
19721993
n_block,
1973-
softmax_scale=None,
1994+
softmax,
19741995
buffers=None,
19751996
fastdiv_mods=None,
1976-
VEC_SIZE: cutlass.Constexpr[int] = 1,
19771997
):
1978-
# Get index tensors with proper domain offset
1998+
# Prepare index tensor
19791999
cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size))
19802000
cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS)
19812001
tScS = thr_mma_qk.partition_C(cS)
1982-
1983-
# Vectorized processing similar to SM100
1984-
n_vals = cutlass.const_expr(cute.size(acc_S.shape))
1985-
score_vec = cute.make_fragment(VEC_SIZE, Float32)
1986-
kv_idx_vec = cute.make_fragment(VEC_SIZE, cutlass.Int32)
1987-
1988-
# Create broadcasted fragments for constant values
1989-
batch_idx_vec = utils.broadcast_scalar_to_vec(batch_idx, kv_idx_vec, cutlass.Int32)
1990-
head_idx_vec = utils.broadcast_scalar_to_vec(head_idx, kv_idx_vec, cutlass.Int32)
1991-
q_idx_vec = cute.make_fragment(VEC_SIZE, cutlass.Int32)
1992-
1993-
# Load SSA values once
1994-
batch_idx_ssa = batch_idx_vec.load()
1995-
head_idx_ssa = head_idx_vec.load()
1996-
1997-
# Build SSA slices and call into scoremod / writeback
1998-
for i in cutlass.range(0, n_vals, VEC_SIZE, unroll_full=True):
1999-
for j in cutlass.range(VEC_SIZE, unroll_full=True):
2000-
score_vec[j] = acc_S[i + j]
2001-
if softmax_scale is not None:
2002-
score_vec[j] = score_vec[j] * softmax_scale
2003-
2004-
# Use FastDivmod for bounds checking if buffers and divmods are present
2005-
if cutlass.const_expr(buffers is not None and fastdiv_mods is not None):
2006-
seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
2007-
_, q_idx_wrapped = seqlen_q_divmod.divmod(tScS[i + j][0])
2008-
_, kv_idx_wrapped = seqlen_k_divmod.divmod(tScS[i + j][1])
2009-
q_idx_vec[j] = q_idx_wrapped
2010-
kv_idx_vec[j] = kv_idx_wrapped
2011-
else:
2012-
# No bounds checking needed - direct indexing
2013-
q_idx_vec[j] = tScS[i + j][0]
2014-
kv_idx_vec[j] = tScS[i + j][1]
2015-
score_ssa = score_vec.load()
2016-
q_idx_ssa = q_idx_vec.load()
2017-
kv_idx_ssa = kv_idx_vec.load()
2018-
2019-
buffer_args = []
2020-
if cutlass.const_expr(buffers is not None):
2021-
buffer_args = buffers
2022-
2023-
post_mod_scores = score_mod(
2024-
score_ssa,
2025-
batch_idx_ssa,
2026-
head_idx_ssa,
2027-
q_idx=q_idx_ssa,
2028-
kv_idx=kv_idx_ssa,
2029-
buffers=buffer_args
2030-
)
20312002

2032-
score_vec.store(post_mod_scores)
2033-
for j in cutlass.range(VEC_SIZE, unroll_full=True):
2034-
acc_S[i + j] = score_vec[j]
2003+
apply_score_mod_inner(
2004+
acc_S,
2005+
tScS,
2006+
self.score_mod,
2007+
batch_idx,
2008+
head_idx,
2009+
softmax.softmax_scale,
2010+
self.vec_size,
2011+
self.qk_acc_dtype,
2012+
buffers,
2013+
fastdiv_mods,
2014+
constant_q_idx=None
2015+
)
20352016

20362017
def warp_scheduler_barrier_sync(self):
20372018
if const_expr(self.use_scheduler_barrier):

flash_attn/cute/flash_fwd_sm100.py

Lines changed: 29 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import cuda.bindings.driver as cuda
2121

2222
import cutlass
23-
from cutlass._mlir.dialects.llvm import intr_prefetch
2423
import cutlass.cute as cute
2524
from cutlass import Float32, Int32, const_expr
2625
from cutlass.cute.nvgpu import cpasync
@@ -30,7 +29,7 @@
3029
import flash_attn.cute.utils as utils
3130
# import flash_attn.cute.pipeline as pipeline
3231
from flash_attn.cute.mask import AttentionMask
33-
from flash_attn.cute.softmax import SoftmaxSm100
32+
from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner
3433
from flash_attn.cute.seqlen_info import SeqlenInfoQK
3534
from flash_attn.cute.block_info import BlockInfo
3635
from flash_attn.cute.pack_gqa import PackGQA
@@ -66,6 +65,7 @@ def __init__(
6665
n_block_size: int = 128,
6766
is_persistent: bool = True,
6867
score_mod: cutlass.Constexpr | None = None,
68+
has_buffers: cutlass.Constexpr = False,
6969
):
7070
# self.dtype = dtype
7171
# padding head_dim to a multiple of 16 as k_block_size
@@ -97,6 +97,10 @@ def __init__(
9797
if pack_gqa:
9898
assert m_block_size % self.qhead_per_kvhead == 0, "For PackGQA, m_block_size must be divisible by qhead_per_kvhead"
9999
self.score_mod = score_mod
100+
if cutlass.const_expr(has_buffers):
101+
self.vec_size: cutlass.Constexpr = 1
102+
else:
103+
self.vec_size: cutlass.Constexpr = 2
100104
# Does S1 need to wait for S0 to finish
101105
# self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
102106
self.s0_s1_barrier = False
@@ -485,7 +489,6 @@ class SharedStorage:
485489
window_size_right = Int32(window_size_right)
486490

487491
fastdiv_mods = None
488-
# Create FastDivmod objects only if buffers are present
489492
if cutlass.const_expr(buffers is not None):
490493
seqlen_q = cute.size(mQ.shape[0])
491494
seqlen_k = cute.size(mK.shape[0])
@@ -814,7 +817,10 @@ def kernel(
814817
stage = Int32(0 if warp_idx < self.softmax1_warp_ids[0] else 1)
815818
softmax_loop(
816819
stage=stage,
817-
tStSi=cute.make_tensor(tStS.iterator + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]), tStS.layout),
820+
tStSi=cute.make_tensor(
821+
tStS.iterator + (self.tmem_s_offset[0] if stage == 0 else self.tmem_s_offset[1]),
822+
tStS.layout
823+
),
818824
)
819825
cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset)
820826
else:
@@ -1970,101 +1976,32 @@ def apply_score_mod(
19701976
m_block,
19711977
n_block,
19721978
softmax,
1973-
buffers = None,
1974-
fastdiv_mods = (None, None),
1975-
VEC_SIZE: cutlass.Constexpr[int] = 1,
1979+
buffers=None,
1980+
fastdiv_mods=(None, None),
19761981
):
1977-
"""Apply score modification function to attention scores.
1978-
1979-
Args:
1980-
tSrS_t2r: Score tensor to modify
1981-
thr_tmem_load: Thread memory load partition
1982-
thr_mma_qk: Thread MMA QK partition
1983-
batch_idx: Batch index
1984-
head_idx: Head index
1985-
m_block: M block index
1986-
n_block: N block index
1987-
softmax: Softmax module containing scale
1988-
"""
1989-
# Get M, N index tensor + layout like accum
1982+
"""Apply score modification for SM100 (constant q_idx)."""
1983+
# Prepare index tensor with extra partition
19901984
cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size))
19911985
cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS)
19921986
tScS = thr_mma_qk.partition_C(cS)
19931987
tScS_t2r = thr_tmem_load.partition_D(tScS)
19941988

1995-
# Build index + score fragments
1996-
n_vals = cutlass.const_expr(cute.size(tSrS_t2r.shape))
1997-
score_vec = cute.make_fragment(VEC_SIZE, self.qk_acc_dtype)
1998-
kv_idx_vec = cute.make_fragment(VEC_SIZE, cutlass.Int32)
1999-
2000-
# Create broadcasted fragments for constant values
2001-
batch_idx_vec = utils.broadcast_scalar_to_vec(batch_idx, kv_idx_vec, cutlass.Int32)
2002-
head_idx_vec = utils.broadcast_scalar_to_vec(head_idx, kv_idx_vec, cutlass.Int32)
2003-
2004-
# Use FastDivmod for bounds checking if buffers and divmods are present
1989+
# Shared q_idx for all scores
1990+
q_idx_wrapped = tScS_t2r[0][0]
20051991
if cutlass.const_expr(buffers is not None):
20061992
seqlen_q_divmod, _ = fastdiv_mods
20071993
_, q_idx_wrapped = seqlen_q_divmod.divmod(tScS_t2r[0][0])
2008-
q_idx_vec = utils.broadcast_scalar_to_vec(q_idx_wrapped, kv_idx_vec, cutlass.Int32)
2009-
else:
2010-
# No bounds checking needed - direct indexing
2011-
q_idx_vec = utils.broadcast_scalar_to_vec(tScS_t2r[0][0], kv_idx_vec, cutlass.Int32)
2012-
2013-
# Load SSA values once
2014-
batch_idx_ssa = batch_idx_vec.load()
2015-
head_idx_ssa = head_idx_vec.load()
2016-
q_idx_ssa = q_idx_vec.load()
2017-
2018-
# Build SSA slices and call into scoremod / writeback
2019-
for i in cutlass.range(0, n_vals, VEC_SIZE, unroll_full=True):
2020-
for j in cutlass.range(VEC_SIZE, unroll_full=True):
2021-
score_vec[j] = tSrS_t2r[i + j] * softmax.softmax_scale
2022-
# Use FastDivmod for bounds checking if buffers and divmods are present
2023-
if cutlass.const_expr(buffers is not None and fastdiv_mods is not None):
2024-
_, seqlen_k_divmod = fastdiv_mods
2025-
_, kv_idx_wrapped = seqlen_k_divmod.divmod(tScS_t2r[i + j][1])
2026-
kv_idx_vec[j] = kv_idx_wrapped
2027-
else:
2028-
# No bounds checking needed - direct indexing
2029-
kv_idx_vec[j] = tScS_t2r[i + j][1]
2030-
score_ssa = score_vec.load()
2031-
kv_idx_ssa = kv_idx_vec.load()
2032-
2033-
buffer_args = []
2034-
if cutlass.const_expr(buffers is not None):
2035-
buffer_args = buffers
2036-
2037-
post_mod_scores = self.score_mod(
2038-
score_ssa,
2039-
batch_idx_ssa,
2040-
head_idx_ssa,
2041-
q_idx=q_idx_ssa,
2042-
kv_idx=kv_idx_ssa,
2043-
buffers=buffer_args
2044-
)
2045-
2046-
score_vec.store(post_mod_scores)
2047-
for j in cutlass.range(VEC_SIZE, unroll_full=True):
2048-
tSrS_t2r[i + j] = score_vec[j]
20491994

2050-
2051-
2052-
### Grave yard
2053-
2054-
# PRETTY WAY but uses too much rmem
2055-
2056-
# # TODO: We need to not materialize all of kv_idx
2057-
# tSrKV_idx = cute.make_fragment(tSrS_t2r.shape, cutlass.Int32)
2058-
# n_vals = cutlass.const_expr(cute.size(tSrS_t2r.shape))
2059-
# for i in cutlass.range(n_vals, unroll_full=True):
2060-
# tSrKV_idx[i] = tScS_t2r[i][1]
2061-
2062-
# # Create broadcasted q_idx
2063-
# tSrQ_idx = cute.make_fragment(1, cutlass.Int32)
2064-
# tSrQ_idx[0] = tScS_t2r[0][0]
2065-
# tSrQ_idx_broadcasted = utils.broadcast_to(tSrQ_idx, tSrKV_idx).load()
2066-
2067-
# tSrS_t2r_ssa = tSrS_t2r.load()
2068-
# tSrS_t2r_ssa = tSrS_t2r_ssa * softmax.softmax_scale
2069-
# post_mod_scores = self.score_mod(tSrS_t2r_ssa, batch_idx, head_idx, q_idx=tSrQ_idx_broadcasted, kv_idx=tSrKV_idx.load())
2070-
# tSrS_t2r.store(post_mod_scores)
1995+
apply_score_mod_inner(
1996+
tSrS_t2r,
1997+
tScS_t2r,
1998+
self.score_mod,
1999+
batch_idx,
2000+
head_idx,
2001+
softmax.softmax_scale,
2002+
self.vec_size,
2003+
self.qk_acc_dtype,
2004+
buffers,
2005+
fastdiv_mods,
2006+
constant_q_idx=q_idx_wrapped
2007+
)

0 commit comments

Comments
 (0)