 from flash_attn.cute import hopper_helpers as sm90_utils
 from flash_attn.cute import utils
 from flash_attn.cute.mask import AttentionMask
-from flash_attn.cute.softmax import Softmax
+from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
 from flash_attn.cute.seqlen_info import SeqlenInfoQK
 from flash_attn.cute.block_info import BlockInfo
 from flash_attn.cute import pipeline
@@ -51,6 +51,7 @@ def __init__(
         num_threads: int = 128,
         Q_in_regs: bool = False,
         score_mod: cutlass.Constexpr | None = None,
+        has_buffers: bool = False,
     ):
         """Initializes the configuration for a flash attention kernel.
 
@@ -67,7 +68,7 @@ def __init__(
         :type num_threads: int
         :param is_causal: is causal
         :param score_mod: A callable that takes the attention scores and applies a modification.
-            Callable signature:
+            Callable signature: ``score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, buffers) -> Any``
         """
         self.dtype = dtype
         # padding head_dim to a multiple of 16 as k_block_size
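For readers new to this API: as the loop removed from `apply_score_mod` further down shows, every argument reaches the callable as a vectorized SSA value, so the mod must be expressible as elementwise arithmetic over the lanes. A minimal sketch of such a callable (the name, the ALiBi-style bias, and the `.to()` conversion are illustrative assumptions, not part of this PR):

```python
import cutlass

# Hypothetical score_mod: subtract the query/key distance from each
# (already softmax-scaled) score. `scores`, `q_idx`, and `kv_idx` are
# vectorized SSA values, so the arithmetic is elementwise per lane;
# `buffers` is the (possibly empty) list of user buffers, unused here.
def rel_bias_score_mod(scores, batch_idx, head_idx, q_idx, kv_idx, buffers):
    return scores - (q_idx - kv_idx).to(cutlass.Float32)
```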
@@ -89,6 +90,11 @@ def __init__(
         self.num_stages = num_stages
         self.Q_in_regs = Q_in_regs
         self.score_mod = score_mod
+        self.qk_acc_dtype = Float32
+        if cutlass.const_expr(has_buffers):
+            self.vec_size: cutlass.Constexpr = 1
+        else:
+            self.vec_size: cutlass.Constexpr = 2
 
     @staticmethod
     def can_implement(
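Why the split on `has_buffers`: `vec_size` is the step of the unrolled loop over the accumulator elements, i.e. how many score lanes each `score_mod` invocation receives, and the fallback to 1 when buffers are present presumably keeps per-element buffer indexing scalar. A plain-Python illustration of what the width means (hypothetical helper, no CUTE DSL):

```python
# With vec_size=2 the unrolled loop steps two accumulator elements at a
# time, halving the number of score_mod invocations versus vec_size=1.
def count_score_mod_calls(n_vals: int, vec_size: int) -> int:
    assert n_vals % vec_size == 0
    return n_vals // vec_size

print(count_score_mod_calls(32, vec_size=2))  # 16 calls (no buffers)
print(count_score_mod_calls(32, vec_size=1))  # 32 calls (has_buffers=True)
```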
@@ -934,11 +940,15 @@ def load_V_next():
             )
             if cutlass.const_expr(score_mod is not None):
                 self.apply_score_mod(
-                    acc_S, mma_params.thr_mma_qk, score_mod,
-                    batch_idx, head_idx, m_block, n_block,
-                    softmax_scale=softmax.softmax_scale,
+                    acc_S,
+                    mma_params.thr_mma_qk,
+                    batch_idx,
+                    head_idx,
+                    m_block,
+                    n_block,
+                    softmax=softmax,
                     buffers=buffers,
-                    fastdiv_mods=fastdiv_mods
+                    fastdiv_mods=fastdiv_mods,
                 )
 
             smem_pipe_write = self.advance_pipeline(smem_pipe_write)
@@ -1692,11 +1702,15 @@ def mma(
         # Use vectorized score modification
         if cutlass.const_expr(score_mod is not None):
             self.apply_score_mod(
-                acc_S, thr_mma_qk, score_mod,
-                batch_idx, head_idx, m_block, n_block_max - 1,
-                softmax_scale=softmax.softmax_scale,
+                acc_S,
+                thr_mma_qk,
+                batch_idx,
+                head_idx,
+                m_block,
+                n_block_max - 1,
+                softmax=softmax,
                 buffers=buffers,
-                fastdiv_mods=fastdiv_mods
+                fastdiv_mods=fastdiv_mods,
             )
         # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
         mask_fn(acc_S, n_block=n_block_max - 1, mask_seqlen=True)
@@ -1843,11 +1857,15 @@ def mma_one_n_block(
         pipeline_k.consumer_release(smem_pipe_read)
         if cutlass.const_expr(score_mod is not None):
             self.apply_score_mod(
-                acc_S, thr_mma_qk, score_mod,
-                batch_idx, head_idx, m_block, n_block,
-                softmax_scale=softmax.softmax_scale,
+                acc_S,
+                thr_mma_qk,
+                batch_idx,
+                head_idx,
+                m_block,
+                n_block,
+                softmax=softmax,
                 buffers=buffers,
-                fastdiv_mods=fastdiv_mods
+                fastdiv_mods=fastdiv_mods,
             )
         if const_expr(mask_fn is not None):
             mask_fn(acc_S, n_block=n_block)
@@ -1923,11 +1941,15 @@ def mma_one_n_block_intrawg_overlap(
         pipeline_k.consumer_release(smem_pipe_read)
         if cutlass.const_expr(score_mod is not None):
             self.apply_score_mod(
-                acc_S, thr_mma_qk, score_mod,
-                batch_idx, head_idx, m_block, n_block,
-                softmax_scale=softmax.softmax_scale,
+                acc_S,
+                thr_mma_qk,
+                batch_idx,
+                head_idx,
+                m_block,
+                n_block,
+                softmax=softmax,
                 buffers=buffers,
-                fastdiv_mods=fastdiv_mods
+                fastdiv_mods=fastdiv_mods,
             )
         # if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
         if const_expr(mask_fn is not None):
@@ -1965,73 +1987,32 @@ def apply_score_mod(
         self,
         acc_S,
         thr_mma_qk,
-        score_mod,
         batch_idx,
         head_idx,
         m_block,
         n_block,
-        softmax_scale=None,
+        softmax,
         buffers=None,
         fastdiv_mods=None,
-        VEC_SIZE: cutlass.Constexpr[int] = 1,
     ):
-        # Get index tensors with proper domain offset
+        # Prepare index tensor
         cS = cute.make_identity_tensor((self.m_block_size, self.n_block_size))
        cS = cute.domain_offset((m_block * self.m_block_size, n_block * self.n_block_size), cS)
         tScS = thr_mma_qk.partition_C(cS)
-
-        # Vectorized processing similar to SM100
-        n_vals = cutlass.const_expr(cute.size(acc_S.shape))
-        score_vec = cute.make_fragment(VEC_SIZE, Float32)
-        kv_idx_vec = cute.make_fragment(VEC_SIZE, cutlass.Int32)
-
-        # Create broadcasted fragments for constant values
-        batch_idx_vec = utils.broadcast_scalar_to_vec(batch_idx, kv_idx_vec, cutlass.Int32)
-        head_idx_vec = utils.broadcast_scalar_to_vec(head_idx, kv_idx_vec, cutlass.Int32)
-        q_idx_vec = cute.make_fragment(VEC_SIZE, cutlass.Int32)
-
-        # Load SSA values once
-        batch_idx_ssa = batch_idx_vec.load()
-        head_idx_ssa = head_idx_vec.load()
-
-        # Build SSA slices and call into scoremod / writeback
-        for i in cutlass.range(0, n_vals, VEC_SIZE, unroll_full=True):
-            for j in cutlass.range(VEC_SIZE, unroll_full=True):
-                score_vec[j] = acc_S[i + j]
-                if softmax_scale is not None:
-                    score_vec[j] = score_vec[j] * softmax_scale
-
-                # Use FastDivmod for bounds checking if buffers and divmods are present
-                if cutlass.const_expr(buffers is not None and fastdiv_mods is not None):
-                    seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
-                    _, q_idx_wrapped = seqlen_q_divmod.divmod(tScS[i + j][0])
-                    _, kv_idx_wrapped = seqlen_k_divmod.divmod(tScS[i + j][1])
-                    q_idx_vec[j] = q_idx_wrapped
-                    kv_idx_vec[j] = kv_idx_wrapped
-                else:
-                    # No bounds checking needed - direct indexing
-                    q_idx_vec[j] = tScS[i + j][0]
-                    kv_idx_vec[j] = tScS[i + j][1]
-            score_ssa = score_vec.load()
-            q_idx_ssa = q_idx_vec.load()
-            kv_idx_ssa = kv_idx_vec.load()
-
-            buffer_args = []
-            if cutlass.const_expr(buffers is not None):
-                buffer_args = buffers
-
-            post_mod_scores = score_mod(
-                score_ssa,
-                batch_idx_ssa,
-                head_idx_ssa,
-                q_idx=q_idx_ssa,
-                kv_idx=kv_idx_ssa,
-                buffers=buffer_args
-            )
 
-            score_vec.store(post_mod_scores)
-            for j in cutlass.range(VEC_SIZE, unroll_full=True):
-                acc_S[i + j] = score_vec[j]
+        apply_score_mod_inner(
+            acc_S,
+            tScS,
+            self.score_mod,
+            batch_idx,
+            head_idx,
+            softmax.softmax_scale,
+            self.vec_size,
+            self.qk_acc_dtype,
+            buffers,
+            fastdiv_mods,
+            constant_q_idx=None
+        )
 
     def warp_scheduler_barrier_sync(self):
         if const_expr(self.use_scheduler_barrier):
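The loop deleted above is what moved into `apply_score_mod_inner` in `flash_attn/cute/softmax.py`, shared with the other call sites. As a reference for the new call, here is a sketch of the helper reconstructed from the removed code; it is not the authoritative softmax.py implementation, and the `constant_q_idx` handling (presumably a pinned query row for decode-style kernels) is a guess:

```python
import cutlass
import cutlass.cute as cute
from flash_attn.cute import utils


def apply_score_mod_inner(
    acc_S, tScS, score_mod, batch_idx, head_idx, softmax_scale,
    vec_size: cutlass.Constexpr, qk_acc_dtype,
    buffers=None, fastdiv_mods=None, constant_q_idx=None,
):
    # Reconstructed from the code removed above -- a sketch, not the
    # actual implementation in softmax.py.
    n_vals = cutlass.const_expr(cute.size(acc_S.shape))
    score_vec = cute.make_fragment(vec_size, qk_acc_dtype)
    q_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)
    kv_idx_vec = cute.make_fragment(vec_size, cutlass.Int32)

    # Broadcast the per-block constants into vector lanes once.
    batch_idx_ssa = utils.broadcast_scalar_to_vec(batch_idx, kv_idx_vec, cutlass.Int32).load()
    head_idx_ssa = utils.broadcast_scalar_to_vec(head_idx, kv_idx_vec, cutlass.Int32).load()

    for i in cutlass.range(0, n_vals, vec_size, unroll_full=True):
        for j in cutlass.range(vec_size, unroll_full=True):
            # The user mod sees pre-scaled scores (softmax.softmax_scale).
            score_vec[j] = acc_S[i + j] * softmax_scale
            if cutlass.const_expr(buffers is not None and fastdiv_mods is not None):
                # Wrap indices with FastDivmod so buffer reads stay in bounds.
                seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
                _, q_idx_wrapped = seqlen_q_divmod.divmod(tScS[i + j][0])
                _, kv_idx_wrapped = seqlen_k_divmod.divmod(tScS[i + j][1])
                q_idx_vec[j] = q_idx_wrapped
                kv_idx_vec[j] = kv_idx_wrapped
            else:
                q_idx_vec[j] = tScS[i + j][0]
                kv_idx_vec[j] = tScS[i + j][1]
            if cutlass.const_expr(constant_q_idx is not None):
                q_idx_vec[j] = constant_q_idx  # guess: fixed query row

        buffer_args = []
        if cutlass.const_expr(buffers is not None):
            buffer_args = buffers

        post_mod_scores = score_mod(
            score_vec.load(),
            batch_idx_ssa,
            head_idx_ssa,
            q_idx=q_idx_vec.load(),
            kv_idx=kv_idx_vec.load(),
            buffers=buffer_args,
        )
        score_vec.store(post_mod_scores)
        for j in cutlass.range(vec_size, unroll_full=True):
            acc_S[i + j] = score_vec[j]
```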