…ed attention

The FLOPs calculation incorrectly assumed full causal attention for all models, leading to significant overestimation for specialized attention patterns:

- Sliding Window Attention (e.g., Gemma 3): 4× overestimate
- Chunked Attention (e.g., Llama 4): 8× overestimate

Changes:

- Add get_effective_seq_length() helper to dynamically calculate effective sequence length based on attention pattern
- Update attn_layer_flops() to use attention-aware calculation
- Update MHA/GQA branch in transformer_flops()
- Update MLA branch in transformer_flops()
- Add comprehensive test suite with 12 test cases

The fix is fully backward compatible: models without window_size or chunk_attention_size continue to use the original seq_length / 2 behavior.

Impact:

- Gemma 3 (window=512, seq=2048): 75% FLOPs reduction
- Llama 4 (chunk=256, seq=2048): 87.5% FLOPs reduction
- Standard GPT (full causal): no change (backward compatible)

Fixes NVIDIA#1725

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
What does this PR do?
Fixes inaccurate FLOPs calculation for models using sliding window attention (Gemma 3) and chunked attention
(Llama 4). The previous implementation assumed full causal attention for all models, leading to significant
overestimation of computational costs.
Fixes #1725
Problem
The FLOPs calculation in `megatron/training/training.py` used a hardcoded `seq_length / 2` formula that incorrectly assumed full causal attention for all attention types.
Impact on Specialized Attention Patterns
1. Sliding Window Attention (e.g., Gemma 3): effective length is `min(seq_length, window_size)` instead of `seq_length`
2. Chunked Attention (e.g., Llama 4): effective length is `chunk_size` instead of `seq_length`

Example
```
Gemma 3: seq_length=2048, window_size=512
Before: effective_seq = 2048 / 2 = 1024  ❌ 4× overestimate
After:  effective_seq =  512 / 2 =  256  ✅ Accurate
```
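Since the attention-score and context matmuls scale linearly with the effective sequence length, the 4× figure can be sanity-checked with a few lines of arithmetic (assuming only that linear scaling, per the numbers above):

```python
seq_length, window_size = 2048, 512

# Full-causal assumption: each token attends to seq_length / 2 keys on average.
effective_before = seq_length / 2                      # 1024
# Sliding-window-aware: the attention span is capped at window_size.
effective_after = min(seq_length, window_size) / 2     # 256

overestimate = effective_before / effective_after
reduction = 1 - effective_after / effective_before

print(overestimate)  # 4.0  -> 4x overestimate
print(reduction)     # 0.75 -> 75% FLOPs reduction
```

The same arithmetic with `chunk_size=256` gives the 8× / 87.5% figures quoted for Llama 4.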
Solution
Implemented an attention-pattern-aware FLOPs calculation system:
Added `get_effective_seq_length(seq_len)` that dynamically calculates the effective sequence length based on the model's attention pattern.
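The PR's exact signature isn't reproduced here; a minimal sketch of how such a helper could work, assuming the window and chunk sizes are passed in explicitly (the real helper presumably reads them from the model config, and the causal 1/2 factor is applied by the caller as in the example above):

```python
def get_effective_seq_length(seq_len, window_size=None, chunk_attention_size=None):
    """Hypothetical sketch: cap the attention span by the attention pattern."""
    if window_size is not None:
        # Sliding window: each token attends to at most window_size keys.
        return min(seq_len, window_size)
    if chunk_attention_size is not None:
        # Chunked attention: attention never crosses a chunk boundary.
        return min(seq_len, chunk_attention_size)
    # Full causal attention: unchanged behavior, backward compatible.
    return seq_len
```

For short sequences the `min()` keeps the result correct: a 256-token sequence with a 512-token window still yields an effective length of 256.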
Added `TestFLOPsCalculation` class with 12 test cases covering sliding window, chunked, and full causal attention.
Changes
Core Implementation (megatron/training/training.py)
Added `get_effective_seq_length()` helper (lines 235-266)
Updated 3 calculation locations: `attn_layer_flops()`, the MHA/GQA branch of `transformer_flops()`, and the MLA branch of `transformer_flops()`
Test Suite (tests/unit_tests/test_training.py)
Added comprehensive `TestFLOPsCalculation` class with 12 test cases.
Run tests:

```shell
pytest tests/unit_tests/test_training.py::TestFLOPsCalculation -v
```
Testing
Local Testing
```shell
# Run specific test class
pytest tests/unit_tests/test_training.py::TestFLOPsCalculation -v

# Run all training tests
pytest tests/unit_tests/test_training.py -v
```
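As an illustration of what one of the cases might look like (hypothetical — the actual suite lives in `tests/unit_tests/test_training.py`; the inline helper below mirrors the effective-length rule described above rather than importing Megatron-LM):

```python
def effective_seq_length(seq_len, window_size=None):
    # Stand-in for the PR's helper: cap the span at the window, if any.
    return min(seq_len, window_size) if window_size is not None else seq_len


class TestFLOPsCalculation:
    def test_sliding_window_caps_effective_length(self):
        # Gemma 3-style config: window shorter than the sequence.
        assert effective_seq_length(2048, window_size=512) == 512

    def test_full_causal_is_unchanged(self):
        # Backward compatibility: no window -> original behavior.
        assert effective_seq_length(2048) == 2048
```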
Test Coverage
Breaking Changes
None. This change is fully backward compatible: models without `window_size` or `chunk_attention_size` continue to use the original `seq_length / 2` behavior.
Parameters
Existing parameters (no changes needed):
Future optional enhancement:
For Reviewers
Key Review Areas
Related Issues