
Add Magi Attention for varlen + Context Parallelism#2543

Open
wangbinluo wants to merge 4 commits into pytorch:main from wangbinluo:feature/magi-attention

Conversation


wangbinluo commented Mar 11, 2026

Summary

The existing _ContextParallel with FlexAttention uses AllGather to collect the full K/V sequence, then relies on BlockMask to skip cross-document computation. This works, but has limitations for varlen packed sequence training (SFT, DPO):

  1. No varlen input support — upstream CP requires (B, H, S, D) padded tensors. Real SFT/DPO pipelines produce packed tokens with cu_seqlens (variable-length format used by Flash Attention). Using upstream CP requires converting back to padded format, wasting memory and compute on padding tokens.

  2. Per-document length constraints — _PerDocumentHeadTailLoadBalancer requires every document length to be divisible by 2 × cp_world_size. With CP=8, each document must be a multiple of 16 tokens. Real training data has arbitrary document lengths (137, 2041, 89, ...), making this impractical.

  3. Communication inefficiency for packed sequences — AllGather sends the entire K/V sequence to every rank, including tokens from documents that a rank has no Q tokens for. With many short documents packed into a long sequence, most communicated K/V is unused.
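For concreteness, the packed varlen layout can be sketched without torch (values below are illustrative, reusing the document lengths mentioned above):

```python
# cu_seqlens is the exclusive prefix sum of document lengths: the
# boundaries of each document inside the packed token stream.
doc_lens = [137, 2041, 89]  # arbitrary per-document lengths

cu_seqlens = [0]
for n in doc_lens:
    cu_seqlens.append(cu_seqlens[-1] + n)

total_tokens = cu_seqlens[-1]                 # packed: 2267 real tokens, zero padding
padded_slots = len(doc_lens) * max(doc_lens)  # padded (B, S): 3 * 2041 = 6123 slots

print(cu_seqlens)  # [0, 137, 2178, 2267]
```

The padded layout spends 6123 − 2267 = 3856 token slots (63%) on padding for this mix of lengths, which is the memory/compute waste described above.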

This PR implements MagiAttention to address these limitations, enabling varlen packed sequences with Context Parallelism.

Closes #2536

What changed

New module: torchtitan/distributed/varlen_cp/ (~2900 lines)

  • magi_attention.py — Custom autograd.Function implementing the Magi Attention forward/backward. Accepts cu_seqlens + packed tokens directly. Uses AllToAll-V to communicate only needed K/V per document (zero redundancy). LPT dispatch redistributes Q sub-chunks across ranks for load balancing.
  • magi_dispatch.py — AllToAll-V communication helpers, Q redistribution, result gathering.
  • dispatch_solver.py — LPT greedy solver that assigns Q sub-chunks to ranks based on actual attention workload (trapezoidal area estimation from cu_seqlens).
  • flex_attn_kernels.py — Range-based flex_attention wrappers. Converts (Q_range, K_range, attn_type) descriptors into BlockMask for PyTorch-native flex_attention. Both forward and backward use flex_attention exclusively.
  • mask_primitives.py — Utilities for converting cu_seqlens to attention slices and ranges.
  • dispatch_ops.py — Dispatch/undispatch operations for Q redistribution.
  • ring_attention.py — Entry point (varlen_ring_attention) and LSE merge utilities.
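As a hedged, torch-free illustration of the LPT dispatch idea (function names and the exact workload model here are assumptions, not the PR's API): each Q sub-chunk's cost is estimated as the trapezoidal area of its causal attention region, and chunks are then greedily assigned, heaviest first, to the currently least-loaded rank:

```python
import heapq

def causal_area(q_start, q_end, doc_start):
    """Workload proxy for Q rows [q_start, q_end) of a causal document:
    row i attends to (i - doc_start + 1) keys, so the total is a trapezoid."""
    lo = q_start - doc_start + 1
    hi = q_end - doc_start
    return (lo + hi) * (q_end - q_start) // 2

def lpt_assign(chunks, world_size):
    """LPT greedy: sort chunks by workload descending, assign each to the
    rank with the smallest accumulated load so far."""
    heap = [(0, rank) for rank in range(world_size)]  # (load, rank)
    heapq.heapify(heap)
    plan = {}
    for chunk_id, work in sorted(chunks, key=lambda c: -c[1]):
        load, rank = heapq.heappop(heap)
        plan[chunk_id] = rank
        heapq.heappush(heap, (load + work, rank))
    return plan

# Four sub-chunks with uneven workloads, balanced across 2 CP ranks.
plan = lpt_assign([(0, 10), (1, 7), (2, 5), (3, 3)], world_size=2)
print(plan)  # {0: 0, 1: 1, 2: 1, 3: 0} → loads 13 vs 12
```

LPT is the classic greedy heuristic for makespan minimization; it never exceeds 4/3 of the optimal balance, which is usually good enough for attention scheduling.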

Modified files:

  • context_parallel.py — the "varlen" case sets _cp_mesh on VarlenAttentionWrapper modules.
  • attention.py — VarlenAttentionWrapper._forward_cp() creates a dispatch plan from cu_seqlens and calls Magi Attention when CP is active.

Tests (tests/unit_tests/test_varlen_cp/):

  • Forward/backward correctness against single-GPU reference (5 configs × fwd + bwd)
  • Ring-pass / all-gather reference equivalence
  • Dispatch solver, mask primitives, dispatch ops unit tests

Design decisions

  • Zero external dependencies — Uses only torch.nn.attention.flex_attention and standard torch.distributed collectives. No external kernel libraries.
  • flex_attention for both forward and backward — No fallback paths. The compiled Triton kernel handles all attention computation.
  • Cross-layer metadata cache — Dispatch plan, AllToAll-V split sizes, and BlockMask are cached across transformer layers within one forward pass (same cu_seqlens → same plan for all 96+ layers).
  • Backward K/V reuse — Forward saves packed K/V via save_for_backward; backward reuses them instead of re-gathering.
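The cross-layer cache can be sketched as a dictionary keyed by cu_seqlens (a minimal illustration with hypothetical names; the PR caches the dispatch plan, AllToAll-V split sizes, and BlockMask this way):

```python
# One entry per distinct (cu_seqlens, cp_world_size); every transformer
# layer in the same forward pass hits the cached plan after layer 0.
_plan_cache = {}

def get_dispatch_plan(cu_seqlens, cp_world_size, build_plan):
    key = (tuple(cu_seqlens), cp_world_size)
    if key not in _plan_cache:
        _plan_cache[key] = build_plan(cu_seqlens, cp_world_size)  # expensive solve
    return _plan_cache[key]
```

With 96 layers sharing one cu_seqlens per forward pass, build_plan runs once instead of 96 times.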

Comparison with FlexAttention + CP

| | FlexAttention + CP (upstream) | Varlen + Magi Attention (this PR) |
| --- | --- | --- |
| Input format | (B, H, S, D) padded tensor | (total_tokens, H, D) + cu_seqlens |
| Document length constraint | Each doc divisible by 2 × cp_world_size | Only total sequence divisible by cp_world_size |
| Communication | AllGather full K/V — O(seq_len) | Per-doc AllToAll-V — O(needed_KV) |
| Document isolation | BlockMask mask_mod (computation-level) | cu_seqlens + per-doc packed K/V (communication + computation) |
| Load balancing | PTRR index reordering | LPT Q sub-chunk redistribution across ranks |
| Attention kernel | flex_attention | flex_attention (same) |
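The document-isolation rule that upstream encodes as a BlockMask mask_mod reduces to a simple predicate over cu_seqlens; a plain-Python sketch (the real mask_mod evaluates this on index tensors inside flex_attention):

```python
import bisect

def doc_causal_mask(cu_seqlens, q_idx, kv_idx):
    """True iff kv_idx is visible to q_idx under per-document causal attention:
    both tokens in the same document, and the key not in the causal future."""
    q_doc = bisect.bisect_right(cu_seqlens, q_idx) - 1
    k_doc = bisect.bisect_right(cu_seqlens, kv_idx) - 1
    return q_doc == k_doc and kv_idx <= q_idx

cu = [0, 3, 7]  # two packed documents: tokens 0-2 and 3-6
assert doc_causal_mask(cu, 2, 0)      # same doc, past token
assert not doc_causal_mask(cu, 3, 2)  # crosses document boundary
assert not doc_causal_mask(cu, 4, 5)  # future token within a doc
```

Upstream enforces this predicate at computation time on the fully gathered K/V; this PR instead uses cu_seqlens at communication time, so cross-document K/V is never sent at all.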

Test plan

  • Unit tests: dispatch solver, mask primitives, dispatch ops
  • Unit tests: Magi Attention forward (5 configs) and backward (3 configs)
  • Ring-pass / all-gather reference equivalence
  • Integration test: varlen + CP + TP in training loop
  • Multi-node benchmark

🤖 Generated with Claude Code


meta-cla bot commented Mar 11, 2026

Hi @wangbinluo!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

Implements MagiAttention (arXiv:2505.13211) to enable varlen attention
with Context Parallelism, resolving the existing NotImplementedError
in context_parallel.py.

New module: torchtitan/distributed/varlen_cp/ (8 files)
- Dispatch Solver: LPT greedy load balancing across CP ranks
- Group-Cast: per-doc packed AllToAll-V (zero-redundancy communication)
- Attention kernel: PyTorch-native flex_attention with BlockMask
- Cross-layer metadata cache for dispatch plan and FFA ranges
- Backward K/V reuse: saves 1x AllToAll-V
- NVSHMEM transport: auto-detect with NCCL fallback

Integration:
- context_parallel.py: varlen case sets _cp_mesh on attention modules
- attention.py: VarlenAttentionWrapper._forward_cp() dispatches to Magi Attention

Zero external dependencies — all attention uses PyTorch native APIs
(flex_attention, varlen_attn, scaled_dot_product_attention).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
wangbinluo force-pushed the feature/magi-attention branch from 377485e to 013a61f on March 11, 2026 at 16:18
wangbinluo and others added 2 commits March 12, 2026 07:05
Replace varlen_attn backward with flex_attention recompute + autograd.grad,
making the entire attention path use only PyTorch-native flex_attention API.

Key changes:
- flex_attn_backward: recompute forward via _compiled_flex_attention and
  differentiate with torch.autograd.grad (no varlen_attn dependency)
- Cache BlockMask from forward in FlexAttnMeta.block_mask, reuse in backward
  to skip expensive create_block_mask rebuild
- Save block_mask_cache per qi on autograd ctx for backward pass
- Raise torch._dynamo.config.cache_size_limit to 256 to prevent fallback
  to unfused O(n²) backward with many different mask shapes
- Remove torch.nn.attention.varlen import (no longer needed)
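A hedged sketch of the recompute-backward pattern this commit describes (scaled_dot_product_attention stands in for the compiled flex_attention call; names are illustrative):

```python
import torch
import torch.nn.functional as F

def attn_backward_by_recompute(q, k, v, grad_out):
    """Recompute the forward under grad mode, then differentiate with
    torch.autograd.grad — no private backward API required."""
    q, k, v = (t.detach().requires_grad_(True) for t in (q, k, v))
    with torch.enable_grad():
        out = F.scaled_dot_product_attention(q, k, v)  # stand-in for flex_attention
    return torch.autograd.grad(out, (q, k, v), grad_out)

q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))
dq, dk, dv = attn_backward_by_recompute(q, k, v, torch.ones_like(q))
```

This trades one extra forward compute for independence from any private `_flash_attention_backward`-style API.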

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Remove try/except RuntimeError patterns and fallback code paths:

ring_attention.py:
- Remove varlen_attn fallback in _compute_and_merge_step (direct flex_attn)
- Remove _backward_step_flash (used private _flash_attention_backward API)
- Remove helper functions only used by fallback: _extract_tokens_for_docs,
  _scatter_to_chunk, _extract_lse_for_docs, _scatter_grad_to_chunk
- Remove varlen_attn and AuxRequest imports

magi_attention.py:
- Remove _forward_pipelined and helpers (dead code, single-machine only)
- Remove per-(qi,ki) varlen_attn fallback loops in forward and backward
- Remove TORCHTITAN_OVERLAP_STAGES env var
- Remove unused imports (_alltoall_v_nccl, _return_lse, _backward_step,
  _compute_and_merge_step)

Net reduction: -913 lines. The entire attention path now uses only
PyTorch-native flex_attention with zero fallback complexity.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
meta-cla bot added the CLA Signed label on Mar 12, 2026
- Remove _NvshmemBufferPool and NVSHMEM detection (private API
  torch.distributed._symmetric_memory). _alltoall_v now uses NCCL only.
- Remove unused functions: _alltoall_v_nccl, _alltoall_v_async,
  _return_lse, _scatter_dkv_to_owners, preinit_nvshmem_buffers
- Remove unused ring-pass functions: _compute_step_ffa,
  _compute_and_merge_step, _backward_step_ffa, _backward_step
- Fix test imports: remove references to deleted functions, move
  ring-pass test helpers into test file
- Remove TestExtractTokensForDocs and TestScatterToChunk (tested
  functions that no longer exist)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>


Development

Successfully merging this pull request may close these issues.

[RFC] Magi Attention: Varlen + Context Parallelism with Document-Aware Load Balancing
