Add Magi Attention for varlen + Context Parallelism #2543

wangbinluo wants to merge 4 commits into pytorch:main
Conversation
Hi @wangbinluo! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Implements MagiAttention (arXiv:2505.13211) to enable varlen attention with Context Parallelism, resolving the existing NotImplementedError in context_parallel.py.

New module: torchtitan/distributed/varlen_cp/ (8 files)
- Dispatch solver: LPT greedy load balancing across CP ranks
- Group-cast: per-doc packed AllToAll-V (zero-redundancy communication)
- Attention kernel: PyTorch-native flex_attention with BlockMask
- Cross-layer metadata cache for dispatch plan and FFA ranges
- Backward K/V reuse: saves 1x AllToAll-V
- NVSHMEM transport: auto-detect with NCCL fallback

Integration:
- context_parallel.py: the varlen case sets _cp_mesh on attention modules
- attention.py: VarlenAttentionWrapper._forward_cp() dispatches to Magi Attention

Zero external dependencies: all attention uses PyTorch-native APIs (flex_attention, varlen_attn, scaled_dot_product_attention).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
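The "LPT greedy load balancing" named above is the classic longest-processing-time heuristic: process jobs in descending cost order and always hand the next job to the currently least-loaded rank. A simplified, hypothetical sketch follows — `lpt_assign` and its quadratic per-document cost model (`len² / 2`, the causal-attention FLOP area) are illustrative only; the PR's actual solver works on Q sub-chunks with trapezoidal cost estimates.

```python
import heapq

def lpt_assign(doc_lens, world_size):
    """Greedy LPT: assign each document to the least-loaded rank so far.

    Cost of a causal-attention document is modeled as len^2 / 2 (the
    triangular score-matrix area). Illustrative sketch, not the PR's solver.
    """
    # Jobs in descending cost order (LPT ordering).
    costs = sorted(((l * l / 2.0, i) for i, l in enumerate(doc_lens)), reverse=True)
    # Min-heap of (current load, rank) so the lightest rank pops first.
    heap = [(0.0, r) for r in range(world_size)]
    heapq.heapify(heap)
    assignment = {}
    for cost, doc in costs:
        load, rank = heapq.heappop(heap)
        assignment[doc] = rank
        heapq.heappush(heap, (load + cost, rank))
    return assignment
```

With documents of lengths [4, 4, 2, 2] on 2 ranks, each rank ends up with one large and one small document and identical total cost, which is the load balance the dispatch solver is after.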
Force-pushed from 377485e to 013a61f
Replace varlen_attn backward with flex_attention recompute + autograd.grad, making the entire attention path use only the PyTorch-native flex_attention API.

Key changes:
- flex_attn_backward: recompute the forward via _compiled_flex_attention and differentiate with torch.autograd.grad (no varlen_attn dependency)
- Cache the BlockMask from the forward in FlexAttnMeta.block_mask and reuse it in the backward to skip the expensive create_block_mask rebuild
- Save block_mask_cache per qi on the autograd ctx for the backward pass
- Raise torch._dynamo.config.cache_size_limit to 256 to prevent fallback to the unfused O(n²) backward with many different mask shapes
- Remove the torch.nn.attention.varlen import (no longer needed)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
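The recompute-and-differentiate pattern this commit describes can be sketched in a few lines. This is a hedged, self-contained illustration, not the PR's code: plain `scaled_dot_product_attention` stands in for the compiled `flex_attention`, and `attn_backward_by_recompute` is an illustrative name.

```python
import torch
import torch.nn.functional as F

def attn_backward_by_recompute(q, k, v, grad_out):
    """Recompute the attention forward under autograd, then differentiate
    with torch.autograd.grad instead of calling a dedicated backward kernel.
    (The PR recomputes via _compiled_flex_attention; SDPA stands in here.)"""
    q, k, v = (t.detach().requires_grad_(True) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    # grad_out is dL/d(out); autograd.grad returns (dL/dq, dL/dk, dL/dv).
    return torch.autograd.grad(out, (q, k, v), grad_out)
```

The gradients match what a fused backward kernel would produce; the trade is one extra forward's worth of compute for not depending on any private backward API.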
Remove try/except RuntimeError patterns and fallback code paths.

ring_attention.py:
- Remove the varlen_attn fallback in _compute_and_merge_step (direct flex_attn)
- Remove _backward_step_flash (used the private _flash_attention_backward API)
- Remove helper functions only used by the fallback: _extract_tokens_for_docs, _scatter_to_chunk, _extract_lse_for_docs, _scatter_grad_to_chunk
- Remove the varlen_attn and AuxRequest imports

magi_attention.py:
- Remove _forward_pipelined and its helpers (dead code, single-machine only)
- Remove the per-(qi, ki) varlen_attn fallback loops in forward and backward
- Remove the TORCHTITAN_OVERLAP_STAGES env var
- Remove unused imports (_alltoall_v_nccl, _return_lse, _backward_step, _compute_and_merge_step)

Net reduction: -913 lines. The entire attention path now uses only PyTorch-native flex_attention with zero fallback complexity.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
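The merge step referenced by names like `_compute_and_merge_step` and `_extract_lse_for_docs` is the standard log-sum-exp combination of partial attention outputs computed over disjoint K/V shards. A self-contained sketch of that math (not the PR's code; `merge_partial_attn` is an illustrative name):

```python
import torch

def merge_partial_attn(out1, lse1, out2, lse2):
    """Combine two partial attention outputs, each computed over a disjoint
    K/V shard, using their per-query log-sum-exp values. Subtracting the
    running max keeps the exponentials numerically stable."""
    m = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - m)
    w2 = torch.exp(lse2 - m)
    merged_lse = m + torch.log(w1 + w2)
    denom = (w1 + w2).unsqueeze(-1)
    merged_out = (out1 * w1.unsqueeze(-1) + out2 * w2.unsqueeze(-1)) / denom
    return merged_out, merged_lse
```

Merging over any split of K/V reproduces full-sequence softmax attention exactly; that identity is what makes ring-style and sharded attention correct.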
- Remove _NvshmemBufferPool and NVSHMEM detection (private API torch.distributed._symmetric_memory); _alltoall_v now uses NCCL only
- Remove unused functions: _alltoall_v_nccl, _alltoall_v_async, _return_lse, _scatter_dkv_to_owners, preinit_nvshmem_buffers
- Remove unused ring-pass functions: _compute_step_ffa, _compute_and_merge_step, _backward_step_ffa, _backward_step
- Fix test imports: remove references to deleted functions, move ring-pass test helpers into the test file
- Remove TestExtractTokensForDocs and TestScatterToChunk (they tested functions that no longer exist)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
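For intuition on what an NCCL-backed AllToAll-V involves: tokens are reordered by destination rank and per-rank split sizes are computed, which is the form `torch.distributed.all_to_all_single` accepts through its `input_split_sizes`/`output_split_sizes` arguments. Below is a hypothetical single-process sketch of that bookkeeping only (`pack_for_alltoall_v` and `doc_owner` are illustrative names, not the PR's API; no collective is actually issued):

```python
import torch

def pack_for_alltoall_v(kv, doc_lens, doc_owner, world_size):
    """Reorder packed K/V tokens by destination rank and compute the
    per-rank send sizes an AllToAll-V needs.

    kv:        (total_tokens, ...) packed tensor on this rank
    doc_owner: doc_owner[i] = rank that owns document i's Q tokens
    """
    # Token-offset boundaries of each document (a cu_seqlens, effectively).
    starts = [0]
    for l in doc_lens:
        starts.append(starts[-1] + l)
    order, split_sizes = [], [0] * world_size
    for rank in range(world_size):
        for i, owner in enumerate(doc_owner):
            if owner == rank:
                order.extend(range(starts[i], starts[i + 1]))
                split_sizes[rank] += doc_lens[i]
    return kv[torch.tensor(order)], split_sizes
```

Each rank thus sends a peer only the K/V tokens of documents that peer holds Q for — the "zero-redundancy" property, as opposed to AllGather's full-sequence broadcast.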
Summary
The existing `_ContextParallel` with FlexAttention uses AllGather to collect the full K/V sequence, then relies on `BlockMask` to skip cross-document computation. This works, but has limitations for varlen packed-sequence training (SFT, DPO):

1. **No varlen input support**: upstream CP requires `(B, H, S, D)` padded tensors. Real SFT/DPO pipelines produce packed tokens with `cu_seqlens` (the variable-length format used by Flash Attention). Using upstream CP requires converting back to padded format, wasting memory and compute on padding tokens.
2. **Per-document length constraints**: `_PerDocumentHeadTailLoadBalancer` requires every document length to be divisible by `2 × cp_world_size`. With CP=8, each document must be a multiple of 16 tokens. Real training data has arbitrary document lengths (137, 2041, 89, ...), making this impractical.
3. **Communication inefficiency for packed sequences**: AllGather sends the entire K/V sequence to every rank, including tokens from documents for which a rank has no Q tokens. With many short documents packed into a long sequence, most communicated K/V is unused.

This PR implements MagiAttention to address these limitations, enabling varlen packed sequences with Context Parallelism.
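To make the varlen format concrete, here is what packed input looks like (a sketch; the head/dim sizes are illustrative, not from the PR):

```python
import torch

# Three documents of lengths 137, 2041, and 89 packed into one sequence,
# in the cu_seqlens format Flash Attention's varlen APIs use.
doc_lens = torch.tensor([137, 2041, 89])
cu_seqlens = torch.zeros(len(doc_lens) + 1, dtype=torch.int64)
cu_seqlens[1:] = torch.cumsum(doc_lens, dim=0)
# cu_seqlens holds the token boundary of every document: [0, 137, 2178, 2267]

# Packed tokens: (total_tokens, n_heads, head_dim) with no padding anywhere,
# versus the (B, H, S, D) padded layout upstream CP requires.
n_heads, head_dim = 8, 64  # illustrative sizes
q = torch.randn(int(cu_seqlens[-1]), n_heads, head_dim)
```

Document i occupies tokens `cu_seqlens[i]:cu_seqlens[i+1]`, so arbitrary lengths like 137 or 89 need no rounding up to a multiple of the CP world size.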
Closes #2536
What changed
New module: `torchtitan/distributed/varlen_cp/` (~2900 lines)

- `magi_attention.py`: custom `autograd.Function` implementing the Magi Attention forward/backward. Accepts `cu_seqlens` + packed tokens directly. Uses AllToAll-V to communicate only the K/V needed per document (zero redundancy). LPT dispatch redistributes Q sub-chunks across ranks for load balancing.
- `magi_dispatch.py`: AllToAll-V communication helpers, Q redistribution, result gathering.
- `dispatch_solver.py`: LPT greedy solver that assigns Q sub-chunks to ranks based on actual attention workload (trapezoidal area estimation from `cu_seqlens`).
- `flex_attn_kernels.py`: range-based `flex_attention` wrappers. Converts (Q_range, K_range, attn_type) descriptors into a `BlockMask` for PyTorch-native `flex_attention`. Both forward and backward use `flex_attention` exclusively.
- `mask_primitives.py`: utilities for converting `cu_seqlens` to attention slices and ranges.
- `dispatch_ops.py`: dispatch/undispatch operations for Q redistribution.
- `ring_attention.py`: entry point (`varlen_ring_attention`) and LSE merge utilities.

Modified files:
- `context_parallel.py`: the `"varlen"` case sets `_cp_mesh` on `VarlenAttentionWrapper` modules.
- `attention.py`: `VarlenAttentionWrapper._forward_cp()` creates a dispatch plan from `cu_seqlens` and calls Magi Attention when CP is active.

Tests (`tests/unit_tests/test_varlen_cp/`):

Design decisions
- Uses only `torch.nn.attention.flex_attention` and standard `torch.distributed` collectives. No external kernel libraries.
- `flex_attention` for both forward and backward: no fallback paths. The compiled Triton kernel handles all attention computation.
- The dispatch plan and `BlockMask` are cached across transformer layers within one forward pass (same `cu_seqlens` → same plan for all 96+ layers).
- K/V tensors are saved via `save_for_backward`; the backward reuses them instead of re-gathering.

Comparison with FlexAttention + CP
| Aspect | FlexAttention + CP (upstream) | MagiAttention (this PR) |
| --- | --- | --- |
| Input format | `(B, H, S, D)` padded tensor | `(total_tokens, H, D)` + `cu_seqlens` |
| Document length constraint | multiple of `2 × cp_world_size` | multiple of `cp_world_size` |
| K/V communication volume | `O(seq_len)` | `O(needed_KV)` |
| Cross-document masking | `mask_mod` (computation-level) | `cu_seqlens` + per-doc packed K/V (communication + computation) |
| Attention kernel | `flex_attention` | `flex_attention` (same) |
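For reference, the upstream `mask_mod` row of the table corresponds to a predicate like the one below: document-causal masking derived from `cu_seqlens`, in the `(b, h, q_idx, kv_idx)` signature `flex_attention` expects. A hedged sketch — `make_doc_causal_mask_mod` is an illustrative name, and the PR builds its masks from (Q_range, K_range) descriptors rather than necessarily this way:

```python
import torch

def make_doc_causal_mask_mod(cu_seqlens):
    """Build a flex_attention-style mask_mod that allows a (q, kv) pair only
    when both tokens lie in the same document and q_idx >= kv_idx."""
    total = int(cu_seqlens[-1])
    doc_of = torch.zeros(total, dtype=torch.long)  # token index -> document id
    for i in range(len(cu_seqlens) - 1):
        doc_of[cu_seqlens[i]:cu_seqlens[i + 1]] = i

    def mask_mod(b, h, q_idx, kv_idx):
        return (doc_of[q_idx] == doc_of[kv_idx]) & (q_idx >= kv_idx)

    return mask_mod
```

Such a predicate can be handed to `torch.nn.attention.flex_attention.create_block_mask` to produce the `BlockMask` that lets the compiled kernel skip fully-masked blocks entirely — which is how cross-document work is avoided at the computation level.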
🤖 Generated with Claude Code