Skip to content

fix(model): convert bool mask_cache to float additive mask for softcapping#2235

Open
nuthalapativarun wants to merge 3 commits into
Lightning-AI:mainfrom
nuthalapativarun:fix/attention-mask-softcapping-kv-cache
Open

fix(model): convert bool mask_cache to float additive mask for softcapping#2235
nuthalapativarun wants to merge 3 commits into
Lightning-AI:mainfrom
nuthalapativarun:fix/attention-mask-softcapping-kv-cache

Conversation

@nuthalapativarun
Copy link
Copy Markdown

What does this PR do?

When the KV cache is active, build_mask_cache() returns a torch.bool tensor where True indicates a position that should be attended to (lower triangle). In scaled_dot_product_attention, for models that use attention_logit_softcapping (e.g. Gemma 2), this boolean mask was added directly to the softcapped scores:

scores = scores + mask  # mask is torch.bool → adds 0 or 1, NOT 0 or -inf

This breaks causal masking: future positions received a score boost of +1 instead of -inf, so softmax assigned non-zero attention weight to tokens that should be completely masked out.

Fix

Add an elif branch that converts the incoming boolean mask to a float additive mask before the addition (True → 0.0, False → -inf). The same fix is applied to both CausalSelfAttention and MultiheadLatentAttention.

elif mask.dtype == torch.bool:
    # build_mask_cache returns a boolean mask (True=keep); convert to additive float mask
    mask = torch.zeros_like(mask, dtype=q.dtype).masked_fill_(~mask, torch.finfo(q.dtype).min)
scores = scores + mask

Testing

Added test_attention_mask_bool_to_float_with_softcapping which:

  1. Verifies mask_cache is indeed torch.bool (pre-condition of the bug)
  2. Runs a prefill forward pass with KV cache enabled
  3. Runs the same forward pass without KV cache
  4. Asserts the two outputs are numerically close — they diverge without the fix because the bool mask corrupts the attention distribution

Fixes #1672

…pping

When KV cache is active, build_mask_cache() returns a torch.bool tensor
(True=keep). In scaled_dot_product_attention the bool mask was added
directly to scores, contributing 0 or 1 instead of 0 or -inf, which
breaks causal masking for models that use attention_logit_softcapping
(e.g. Gemma 2).

Add an elif branch that converts the boolean mask to an additive float
mask (True→0.0, False→-inf) before the scores addition. The fix is
applied to both CausalSelfAttention and MultiheadLatentAttention.

Fixes Lightning-AI#1672
@azure-pipelines
Copy link
Copy Markdown

Azure Pipelines:
4 pipeline(s) require an authorized user to comment /azp run to run.

@azure-pipelines
Copy link
Copy Markdown

Azure Pipelines:
4 pipeline(s) require an authorized user to comment /azp run to run.

@nuthalapativarun
Copy link
Copy Markdown
Author

Hi! Just checking in — CI appears to be waiting on an authorized /azp run trigger. Happy to make any changes needed to move this forward. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

attention mask is incorrect when generate with softcapping

1 participant