Fix flops sliding window attention #3280

Open
zhenghax wants to merge 10 commits into NVIDIA:main from zhenghax:fix-flops-sliding-window-attention

Conversation


zhenghax commented Feb 6, 2026


What does this PR do?

Fixes inaccurate FLOPs calculation for models using sliding window attention (Gemma 3) and chunked attention
(Llama 4). The previous implementation assumed full causal attention for all models, leading to significant
overestimation of computational costs.

Fixes #1725

Problem

The FLOPs calculation in megatron/training/training.py used a hardcoded seq_length / 2 formula that
incorrectly assumes full causal attention for all attention types.

Impact on Specialized Attention Patterns

1. Sliding Window Attention (e.g., Gemma 3)

  • Only attends to a fixed window of previous tokens
  • Should use min(seq_length, window_size) instead of seq_length
  • Previous calculation overestimated by 4× (window=512, seq=2048)

2. Chunked Attention (e.g., Llama 4)

  • Attention computed only within fixed-size chunks
  • Should use chunk_size instead of seq_length
  • Previous calculation overestimated by 8× (chunk=256, seq=2048)

Example

Gemma 3: seq_length=2048, window_size=512
Before: effective_seq = 2048 / 2 = 1024 ❌ 4× overestimate
After: effective_seq = 512 / 2 = 256 ✅ Accurate
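
The arithmetic above can be checked with a small standalone sketch (an illustrative re-implementation of the rule described in this PR, not the Megatron helper itself):

```python
def effective_seq_length(seq_len, window_size=None, chunk_size=None):
    """Illustrative re-implementation of the effective-length rule."""
    if chunk_size is not None:        # chunked attention (e.g. Llama 4)
        effective = chunk_size
    elif window_size is not None:     # sliding window (e.g. Gemma 3)
        effective = min(seq_len, window_size)
    else:                             # full causal attention
        effective = seq_len
    return effective / 2              # causal mask factor

full  = effective_seq_length(2048)                    # 1024.0 (old behavior for all models)
gemma = effective_seq_length(2048, window_size=512)   # 256.0
llama = effective_seq_length(2048, chunk_size=256)    # 128.0
print(full / gemma, full / llama)                     # 4.0 8.0
```

The ratios reproduce the 4× and 8× overestimates quoted above.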


Solution

Implemented an attention-pattern-aware FLOPs calculation system:

  1. New Helper Function

Added get_effective_seq_length(seq_len) that dynamically calculates the effective sequence length:

  • Priority order: chunk_attention_size > window_size > full_causal (default)
  • Edge cases: handles infinite windows (-1) and missing attributes
  • Backward compatible: falls back to seq_length / 2 when no specialized attention is set

  2. Updated All FLOPs Calculation Points

  • attn_layer_flops() - used by hybrid models
  • MHA/GQA branch in transformer_flops() - standard transformers
  • MLA branch in transformer_flops() - DeepSeek-style attention

  3. Comprehensive Test Suite

Added a TestFLOPsCalculation class with 12 test cases covering:

  • All attention patterns (full causal, sliding window, chunked)
  • All attention types (MHA, GQA, MLA)
  • Edge cases and scaling properties

Changes

Core Implementation (megatron/training/training.py)

Added get_effective_seq_length() helper (lines 235-266)

  def get_effective_seq_length(seq_len):
      """Effective sequence length for attention FLOPs, accounting for
      sliding-window and chunked attention. `args` comes from the
      enclosing scope in training.py."""
      # Priority: chunk > window > full_causal
      if hasattr(args, 'chunk_attention_size') and args.chunk_attention_size is not None:
          effective_len = args.chunk_attention_size
      elif hasattr(args, 'window_size') and args.window_size is not None:
          # Handle tuple windows, filter out -1 (infinite)
          finite_windows = [w for w in args.window_size if w > 0]
          effective_len = min(seq_len, max(finite_windows)) if finite_windows else seq_len
      else:
          effective_len = seq_len  # Default: full causal

      return effective_len / 2  # Causal mask halves the average attended length
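
A quick usage sketch of the same priority logic, with `args` passed explicitly via a stand-in namespace (in training.py the helper reads `args` from the enclosing scope):

```python
from types import SimpleNamespace

def get_effective_seq_length(seq_len, args):
    """Same priority logic as the PR's helper, with args passed explicitly."""
    if hasattr(args, 'chunk_attention_size') and args.chunk_attention_size is not None:
        return args.chunk_attention_size / 2
    if hasattr(args, 'window_size') and args.window_size is not None:
        finite_windows = [w for w in args.window_size if w > 0]
        return (min(seq_len, max(finite_windows)) if finite_windows else seq_len) / 2
    return seq_len / 2

full    = SimpleNamespace(chunk_attention_size=None, window_size=None)
sliding = SimpleNamespace(chunk_attention_size=None, window_size=(512, 0))
chunked = SimpleNamespace(chunk_attention_size=256, window_size=(512, 0))

print(get_effective_seq_length(2048, full))     # 1024.0
print(get_effective_seq_length(2048, sliding))  # 256.0
print(get_effective_seq_length(2048, chunked))  # 128.0 (chunk takes precedence)
```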

Updated 3 calculation locations:

  1. Line 296: attn_layer_flops()
  2. Line 444: MLA branch in transformer_flops()
  3. Line 475: MHA/GQA branch in transformer_flops()
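
Schematically, each of these sites swaps seq_length / 2 for the helper's result inside the attention-score term. A toy formula (illustrative of the term's shape only, not Megatron's exact FLOPs expression) shows the effect:

```python
def attn_score_flops(batch, heads, seq_len, head_dim, effective_len):
    """Toy attention FLOPs: QK^T plus attention-times-V, each roughly
    2*b*h*s*s_eff*d multiply-adds (illustrative, not Megatron's formula)."""
    return 2 * 2 * batch * heads * seq_len * effective_len * head_dim

# Before the fix, every model used effective_len = seq_length / 2:
before = attn_score_flops(1, 8, 2048, 64, 2048 / 2)
# After, a Gemma 3-style sliding window uses min(seq, window) / 2:
after = attn_score_flops(1, 8, 2048, 64, 512 / 2)
print(before / after)  # 4.0
```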

Test Suite (tests/unit_tests/test_training.py)

Added comprehensive TestFLOPsCalculation class with 12 test cases:

  • ✅ test_full_causal_attention_baseline - Backward compatibility
  • ✅ test_sliding_window_attention_reduces_flops - Verify FLOPs reduction
  • ✅ test_sliding_window_with_infinite_window - Edge case handling
  • ✅ test_chunked_attention_reduces_flops - Chunked attention accuracy
  • ✅ test_gqa_with_sliding_window - GQA compatibility
  • ✅ test_mla_with_sliding_window - MLA compatibility
  • ✅ test_chunk_attention_takes_precedence - Priority rules
  • ✅ test_various_window_sizes - Parametrized testing
  • ✅ test_various_chunk_sizes - Parametrized testing
  • ✅ test_flops_scale_with_batch_size - Scaling verification

Run tests:
pytest tests/unit_tests/test_training.py::TestFLOPsCalculation -v
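
For illustration, a couple of the cases above might look roughly like this as plain asserts (a sketch using a standalone stand-in for the helper; the real suite lives in tests/unit_tests/test_training.py and runs under pytest):

```python
def effective_seq_length(seq_len, window_size=None, chunk_size=None):
    """Standalone stand-in for the helper under test (illustrative)."""
    if chunk_size is not None:
        return chunk_size / 2
    if window_size is not None:
        return min(seq_len, window_size) / 2
    return seq_len / 2

def test_various_window_sizes():
    # The real suite parametrizes this over several window sizes.
    for window, expected in [(128, 64.0), (512, 256.0), (4096, 1024.0)]:
        assert effective_seq_length(2048, window_size=window) == expected

def test_chunk_attention_takes_precedence():
    # chunk_attention_size wins when both chunk and window are set
    assert effective_seq_length(2048, window_size=512, chunk_size=256) == 128.0

test_various_window_sizes()
test_chunk_attention_takes_precedence()
```

Note that a window larger than the sequence (4096 > 2048) falls back to the full-causal value, matching the min(seq_length, window_size) rule.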


Testing

Local Testing

Run specific test class

pytest tests/unit_tests/test_training.py::TestFLOPsCalculation -v

Run all training tests

pytest tests/unit_tests/test_training.py -v

Test Coverage

  • ✅ All attention patterns tested
  • ✅ All attention types (MHA, GQA, MLA) tested
  • ✅ Edge cases covered (infinite windows, missing attrs)
  • ✅ Scaling properties verified

Breaking Changes

None. This change is fully backward compatible:

  • ✅ Models without window_size or chunk_attention_size use original behavior
  • ✅ All existing training scripts continue to work unchanged
  • ✅ FLOPs reported for standard models (GPT, BERT, etc.) remain the same
  • ✅ No API changes, no parameter changes required

Parameters

Existing parameters (no changes needed):

  • --window-size - Already exists for sliding window attention (tuple format)

Future optional enhancement:

  • --chunk-attention-size - Can be added as explicit CLI parameter
    • Currently works via hasattr() checks
    • Formal parameter in arguments.py could improve clarity

For Reviewers

Key Review Areas

  1. Core Logic Correctness (megatron/training/training.py)
  • Lines 235-266: get_effective_seq_length() logic
    • Priority order (chunk > window > full)
    • Edge case handling (infinite windows, tuples)
    • Causal mask division (÷ 2)
  2. Consistency Across Branches
  • Line 296: attn_layer_flops() updated ✅
  • Line 444: MLA branch updated ✅
  • Line 475: MHA/GQA branch updated ✅
  • All three use the same helper function ✅
  3. Test Coverage (tests/unit_tests/test_training.py)
  • Lines 139-325: Comprehensive test suite
  • Parametrized tests for various configurations
  • Backward compatibility verification

Related Issues


zhenghax and others added 5 commits January 28, 2026 13:53
…ed attention

The FLOPs calculation incorrectly assumed full causal attention for all
models, leading to significant overestimation for specialized attention
patterns:

- Sliding Window Attention (e.g., Gemma 3): 4× overestimate
- Chunked Attention (e.g., Llama 4): 8× overestimate

Changes:
- Add get_effective_seq_length() helper to dynamically calculate
  effective sequence length based on attention pattern
- Update attn_layer_flops() to use attention-aware calculation
- Update MHA/GQA branch in transformer_flops()
- Update MLA branch in transformer_flops()
- Add comprehensive test suite with 12 test cases

The fix is fully backward compatible - models without window_size or
chunk_attention_size continue to use the original seq_length / 2
behavior.

Impact:
- Gemma 3 (window=512, seq=2048): 75% FLOPs reduction
- Llama 4 (chunk=256, seq=2048): 87.5% FLOPs reduction
- Standard GPT (full causal): No change (backward compatible)

Fixes NVIDIA#1725

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
@zhenghax zhenghax requested a review from a team as a code owner February 6, 2026 05:43

copy-pr-bot bot commented Feb 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team February 6, 2026 05:44
@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Feb 8, 2026
Development

Successfully merging this pull request may close these issues.

[BUG] Inaccurate FLOPs Calculation for Models with Specialized Attention
