Commit a5bb07e

[CUTE][SM100] Fix backward gqa on sm100 post mask-mod semantic change

stack-info: PR: #2146, branch: drisspg/stack/11
1 parent: e317aa4

1 file changed: 2 additions & 2 deletions

File changed: tests/cute/test_mask_mod.py
@@ -95,7 +95,7 @@ def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: tuple[int, i
         device=q.device,
         **block_mask_kwargs,
     )
-    out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale)
+    out_ref = flex_attention(q, k, v, block_mask=block_mask, scale=scale, enable_gqa=True)
     return out_ref.transpose(1, 2).contiguous()
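For context (not part of the diff): a minimal sketch of what the enable_gqa=True flag does in the forward reference call above. The shapes, head counts, dtype, and causal mask_mod below are illustrative assumptions, not values taken from the test file.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative GQA shapes (assumed): 8 query heads sharing 2 key/value heads.
B, H_Q, H_KV, S, D = 2, 8, 2, 128, 64

q = torch.randn(B, H_Q, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H_KV, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H_KV, S, D, device="cuda", dtype=torch.float16)

def causal_mask(b, h, q_idx, kv_idx):
    # Example mask_mod: keep only keys at or before the query position.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B, H_Q, S, S, device=q.device)

# enable_gqa=True lets flex_attention broadcast the H_KV key/value heads across
# the H_Q query heads; without it, mismatched head counts are rejected.
out = flex_attention(q, k, v, block_mask=block_mask, scale=D**-0.5, enable_gqa=True)

Without enable_gqa=True, flex_attention expects matching query and key/value head counts, so GQA-shaped reference inputs would be rejected rather than broadcast.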

@@ -809,7 +809,7 @@ def run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None):
 
     # Use flex_attention directly without torch.compile for backward tests
     # torch.compile can hang on certain mask patterns (e.g., mini_causal with float32)
-    out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask)
+    out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask, enable_gqa=True)
     dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref)
 
     # Transpose back to BSHD
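Likewise for the backward reference: a sketch of the eager autograd.grad pattern used above, with assumed shapes and a hypothetical causal mask_mod, and with enable_gqa=True so the GQA head layout is accepted.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H_Q, H_KV, S, D = 2, 8, 2, 128, 64  # assumed GQA shapes

q_ref = torch.randn(B, H_Q, S, D, device="cuda", dtype=torch.float32, requires_grad=True)
k_ref = torch.randn(B, H_KV, S, D, device="cuda", dtype=torch.float32, requires_grad=True)
v_ref = torch.randn(B, H_KV, S, D, device="cuda", dtype=torch.float32, requires_grad=True)
grad_out_ref = torch.randn(B, H_Q, S, D, device="cuda")

block_mask = create_block_mask(lambda b, h, q, kv: q >= kv, B, H_Q, S, S, device="cuda")

# Eager flex_attention (no torch.compile, matching the comment in the diff).
out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask, enable_gqa=True)

# Reference gradients come from autograd rather than the custom kernel under test.
dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref)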
