5656)
5757
5858
59- # class NamedBarrierFwd(enum.IntEnum):
60- # Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads()
59+ class NamedBarrierFwd (enum .IntEnum ):
60+ Epilogue = enum .auto () # starts from 1 as barrier 0 is reserved for sync_threads()
6161# WarpSchedulerWG1 = enum.auto()
6262# WarpSchedulerWG2 = enum.auto()
6363# WarpSchedulerWG3 = enum.auto()
@@ -85,6 +85,7 @@ def __init__(
8585 mask_mod : cutlass .Constexpr | None = None ,
8686 has_aux_tensors : cutlass .Constexpr = False ,
8787 paged_kv_non_tma : bool = False ,
88+ is_varlen_q : bool = False ,
8889 ):
8990 self .use_tma_KV = not paged_kv_non_tma
9091 # self.dtype = dtype
@@ -112,6 +113,8 @@ def __init__(
112113 self .is_persistent = is_persistent
113114 self .is_causal = is_causal
114115 self .is_local = is_local
116+ self .is_varlen_q = is_varlen_q
117+ self .use_correction_warps_for_epi = is_varlen_q
115118 self .qhead_per_kvhead = qhead_per_kvhead
116119 self .is_split_kv = is_split_kv
117120 self .pack_gqa = pack_gqa
@@ -146,8 +149,8 @@ def __init__(
146149 self .softmax1_warp_ids = (4 , 5 , 6 , 7 )
147150 self .correction_warp_ids = (8 , 9 , 10 , 11 )
148151 self .mma_warp_id = 12
149- self .load_warp_ids = (13 ,)
150- self .epilogue_warp_ids = (14 ,)
152+ self .epilogue_warp_ids = (13 ,)
153+ self .load_warp_ids = (14 ,)
151154 self .empty_warp_ids = (15 ,)
152155 SM100_TMEM_CAPACITY_COLUMNS = 512
153156 self .tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS
@@ -164,6 +167,15 @@ def __init__(
164167 )
165168 )
166169
170+ if not self .use_tma_KV :
171+ self .load_warp_ids = (14 , 15 )
172+ self .empty_warp_ids = ()
173+ if self .use_correction_warps_for_epi :
174+ self .empty_warp_ids = self .empty_warp_ids + self .epilogue_warp_ids
175+ self .epilogue_warp_ids = self .correction_warp_ids
176+ elif self .is_varlen_q : # fallback
177+ self .epilogue_warp_ids = (13 , 14 )
178+
167179 self .tmem_s_offset = [0 , self .n_block_size ] # e.g., 0, 128
168180 self .tmem_o_offset = [
169181 self .tmem_s_offset [- 1 ] + self .n_block_size + i * self .head_dim_v_padded
@@ -506,19 +518,11 @@ def __call__(
506518 self .cluster_layout_vmnk .shape ,
507519 )
508520 else :
509- assert self .use_tma_O , "Loading O and K/V will contend for the empty warp."
510- self .epilogue_warp_ids = (13 ,)
511- self .load_warp_ids = (14 , 15 )
512- self .empty_warp_ids = ()
513521 tma_atom_K = None
514522 tma_atom_V = None
515523
516524 o_cta_v_layout = cute .composition (cute .make_identity_layout (mO .shape ), self .epi_tile )
517525
518- # print(sO_layout.outer)
519- if const_expr (not self .use_tma_O ):
520- self .epilogue_warp_ids = (14 , 15 )
521- self .empty_warp_ids = ()
522526 self .num_epilogue_threads = cute .arch .WARP_SIZE * len (self .epilogue_warp_ids )
523527 if const_expr (self .use_tma_O ):
524528 tma_atom_O , mO = cpasync .make_tiled_tma_atom (
@@ -546,7 +550,6 @@ def __call__(
546550 assert self .m_block_size % tO_layout .shape [0 ] == 0
547551 vO_layout = cute .make_layout ((1 , async_copy_elems ))
548552 gmem_tiled_copy_O = cute .make_tiled_copy_tv (atom_universal_copy , tO_layout , vO_layout )
549- print ("gmem_tiled_copy_O: " , gmem_tiled_copy_O )
550553
551554 if const_expr (mCuSeqlensQ is not None or mSeqUsedQ is not None ):
552555 TileScheduler = SingleTileVarlenScheduler
@@ -799,7 +802,7 @@ def kernel(
799802 cute .arch .mbarrier_init (
800803 mbar_ptr + self .mbar_s0_s1_sequence_offset + i , cute .arch .WARP_SIZE
801804 )
802- if warp_idx == 4 :
805+ if const_expr ( not self . use_correction_warps_for_epi ) and warp_idx == 4 :
803806 for i in cutlass .range_constexpr (self .q_stage ):
804807 cute .arch .mbarrier_init (
805808 mbar_ptr + self .mbar_corr_epi_full_offset + i ,
@@ -931,6 +934,12 @@ def kernel(
931934 if warp_idx == self .empty_warp_ids [0 ]:
932935 cute .arch .warpgroup_reg_dealloc (self .num_regs_empty )
933936
937+ if const_expr (len (self .empty_warp_ids ) > 1 ):
938+ if warp_idx == self .empty_warp_ids [1 ]:
939+ cute .arch .warpgroup_reg_dealloc (self .num_regs_empty )
940+
941+ assert len (self .empty_warp_ids ) <= 2
942+
934943 # ///////////////////////////////////////////////////////////////////////////////
935944 # LOAD
936945 # ///////////////////////////////////////////////////////////////////////////////
@@ -1004,19 +1013,20 @@ def kernel(
10041013 # ///////////////////////////////////////////////////////////////////////////////
10051014 # Epilogue
10061015 # ///////////////////////////////////////////////////////////////////////////////
1007- if warp_idx >= self .epilogue_warp_ids [0 ] and warp_idx <= self .epilogue_warp_ids [- 1 ]:
1008- cute .arch .warpgroup_reg_dealloc (self .num_regs_other )
1009- self .epilogue_s2g (
1010- mO ,
1011- sO ,
1012- gmem_tiled_copy_O ,
1013- tma_atom_O ,
1014- mbar_ptr ,
1015- block_info ,
1016- num_splits ,
1017- SeqlenInfoCls ,
1018- TileSchedulerCls ,
1019- )
1016+ if const_expr (not self .use_correction_warps_for_epi ):
1017+ if warp_idx >= self .epilogue_warp_ids [0 ] and warp_idx <= self .epilogue_warp_ids [- 1 ]:
1018+ cute .arch .warpgroup_reg_dealloc (self .num_regs_other )
1019+ self .epilogue_s2g (
1020+ mO ,
1021+ sO ,
1022+ gmem_tiled_copy_O ,
1023+ tma_atom_O ,
1024+ mbar_ptr ,
1025+ block_info ,
1026+ num_splits ,
1027+ SeqlenInfoCls ,
1028+ TileSchedulerCls ,
1029+ )
10201030
10211031 # ///////////////////////////////////////////////////////////////////////////////
10221032 # Softmax
@@ -1080,6 +1090,7 @@ def kernel(
10801090 mLSE ,
10811091 sO ,
10821092 learnable_sink ,
1093+ gmem_tiled_copy_O ,
10831094 tma_atom_O ,
10841095 mbar_ptr ,
10851096 softmax_scale_log2 ,
@@ -1931,6 +1942,7 @@ def correction_loop(
19311942 mLSE : cute .Tensor ,
19321943 sO : cute .Tensor ,
19331944 learnable_sink : Optional [cute .Tensor ],
1945+ gmem_tiled_copy_O : cute .TiledCopy ,
19341946 tma_atom_O : cute .CopyAtom ,
19351947 mbar_ptr : cute .Pointer ,
19361948 softmax_scale_log2 : Float32 ,
@@ -1972,6 +1984,12 @@ def correction_loop(
19721984 seqlen = SeqlenInfoCls (batch_idx )
19731985 n_block_min , n_block_max = block_info .get_n_block_min_max (seqlen , m_block , split_idx , num_splits )
19741986
1987+ if const_expr (self .is_split_kv ):
1988+ mO_cur = seqlen .offset_batch_Q (mO , batch_idx , dim = 3 )[None , None , head_idx , split_idx ]
1989+ else :
1990+ mO_cur = seqlen .offset_batch_Q (mO , batch_idx , dim = 3 )[None , None , head_idx ]
1991+ gO = cute .local_tile (mO_cur , (self .m_block_size , self .head_dim_v_padded ), (None , 0 ))
1992+
19751993 # Default LSE to -inf for invalid split_idx tiles
19761994 stats = [(0.0 , - Float32 .inf if const_expr (mLSE is not None or learnable_sink is not None ) else None , True )] * self .q_stage
19771995
@@ -2070,17 +2088,25 @@ def correction_loop(
20702088 cute .arch .mbarrier_wait (
20712089 mbar_ptr + self .mbar_O_full_offset + stage , o_corr_consumer_phase
20722090 )
2073- cute .arch .mbarrier_wait (
2074- mbar_ptr + self .mbar_corr_epi_empty_offset + stage , corr_epi_producer_phase
2075- )
2091+ if const_expr (not self .use_correction_warps_for_epi ):
2092+ cute .arch .mbarrier_wait (
2093+ mbar_ptr + self .mbar_corr_epi_empty_offset + stage , corr_epi_producer_phase
2094+ )
20762095 self .correction_epilogue (
20772096 thr_mma_pv ,
20782097 tOtOs [stage ],
20792098 tidx ,
2099+ stage ,
2100+ m_block ,
2101+ seqlen .seqlen_q ,
20802102 scale ,
20812103 sO [None , None , stage ],
2104+ mO_cur ,
2105+ gO ,
2106+ gmem_tiled_copy_O ,
20822107 )
2083- cute .arch .mbarrier_arrive (mbar_ptr + self .mbar_corr_epi_full_offset + stage )
2108+ if const_expr (not self .use_correction_warps_for_epi ):
2109+ cute .arch .mbarrier_arrive (mbar_ptr + self .mbar_corr_epi_full_offset + stage )
20842110 # Signal for the next work tile that O buffers in tmem are already read, so
20852111 # mma warp can write to them
20862112 cute .arch .mbarrier_arrive (mbar_ptr + self .mbar_P_full_O_rescaled_offset + stage )
@@ -2090,6 +2116,11 @@ def correction_loop(
20902116 softmax_corr_consumer_phase ^= 1
20912117 corr_epi_producer_phase ^= 1
20922118 else :
2119+ # WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781
2120+ if const_expr (self .use_correction_warps_for_epi ):
2121+ gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O
2122+ else :
2123+ gmem_tiled_copy_O_for_empty_tile = None
20932124 if const_expr (self .use_block_sparsity ):
20942125 (
20952126 softmax_corr_consumer_phase ,
@@ -2126,6 +2157,9 @@ def correction_loop(
21262157 o_corr_consumer_phase ,
21272158 corr_epi_producer_phase ,
21282159 softmax_scale_log2 ,
2160+ mO_cur ,
2161+ gO ,
2162+ gmem_tiled_copy_O_for_empty_tile ,
21292163 )
21302164
21312165 if const_expr (mLSE is not None ):
@@ -2228,8 +2262,14 @@ def correction_epilogue(
22282262 thr_mma : cute .core .ThrMma ,
22292263 tOtO : cute .Tensor ,
22302264 tidx : Int32 ,
2265+ stage : Int32 ,
2266+ m_block : Int32 ,
2267+ seqlen_q : Int32 ,
22312268 scale : Float32 ,
22322269 sO : cute .Tensor ,
2270+ mO_cur : Optional [cute .Tensor ] = None ,
2271+ gO : Optional [cute .Tensor ] = None ,
2272+ gmem_tiled_copy_O : Optional [cute .TiledCopy ] = None ,
22332273 ):
22342274 """Apply final scaling and transformation to attention output before writing to global memory.
22352275
@@ -2302,6 +2342,57 @@ def correction_epilogue(
23022342 space = cute .arch .SharedSpace .shared_cta ,
23032343 )
23042344
2345+ if const_expr (self .use_correction_warps_for_epi ):
2346+ assert (not self .use_tma_O )
2347+ assert (gmem_tiled_copy_O is not None )
2348+ cute .arch .barrier (barrier_id = int (NamedBarrierFwd .Epilogue ),
2349+ number_of_threads = len (self .epilogue_warp_ids ) * cute .arch .WARP_SIZE )
2350+ gmem_thr_copy_O = gmem_tiled_copy_O .get_slice (tidx )
2351+ tOsO = gmem_thr_copy_O .partition_S (sO )
2352+ cO = cute .make_identity_tensor ((self .m_block_size , self .head_dim_v_padded ))
2353+ tOgO = gmem_thr_copy_O .partition_D (gO )
2354+ tOcO = gmem_thr_copy_O .partition_S (cO )
2355+ t0OcO = gmem_tiled_copy_O .get_slice (0 ).partition_S (cO )
2356+ tOpO = utils .predicate_k (tOcO , limit = mO_cur .shape [1 ])
2357+ # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it
2358+ assert not self .pack_gqa
2359+ pack_gqa = PackGQA (
2360+ self .m_block_size ,
2361+ self .head_dim_v_padded ,
2362+ self .check_hdim_v_oob ,
2363+ self .qhead_per_kvhead ,
2364+ )
2365+
2366+ # load acc O from smem to rmem for wider vectorization
2367+ tOrO = cute .make_fragment_like (tOsO , self .o_dtype )
2368+ cute .autovec_copy (tOsO , tOrO )
2369+ # copy acc O from rmem to gmem
2370+ if const_expr (not self .pack_gqa ):
2371+ for rest_m in cutlass .range_constexpr (cute .size (tOrO .shape [1 ])):
2372+ if (
2373+ t0OcO [0 , rest_m , 0 ][0 ]
2374+ < seqlen_q
2375+ - (self .q_stage * m_block + stage ) * self .m_block_size
2376+ - tOcO [0 ][0 ]
2377+ ):
2378+ cute .copy (
2379+ gmem_tiled_copy_O ,
2380+ tOrO [None , rest_m , None ],
2381+ tOgO [None , rest_m , None , self .q_stage * m_block + stage ],
2382+ pred = tOpO [None , rest_m , None ]
2383+ if const_expr (self .check_hdim_v_oob )
2384+ else None ,
2385+ )
2386+ else :
2387+ pack_gqa .store_O (
2388+ mO_cur ,
2389+ tOrO ,
2390+ gmem_tiled_copy_O ,
2391+ tidx ,
2392+ self .q_stage * m_block + stage ,
2393+ seqlen_q ,
2394+ )
2395+
23052396 @cute .jit
23062397 def epilogue_s2g (
23072398 self ,
@@ -2389,7 +2480,7 @@ def epilogue_s2g(
23892480 tOrO [None , rest_m , None ],
23902481 tOgO [None , rest_m , None , self .q_stage * m_block + stage ],
23912482 pred = tOpO [None , rest_m , None ]
2392- if self .check_hdim_v_oob
2483+ if const_expr ( self .check_hdim_v_oob )
23932484 else None ,
23942485 )
23952486 else :
0 commit comments