feat(cp): add ulysses all-to-all CP path alongside ring #2418
Conversation
Adds a `model.cp_style` arg ("ring" default | "ulysses") that selects
between the existing ring CP and a new Ulysses all-to-all variant.
Ulysses redistributes Q/K/V via two seq <-> head all-to-alls and runs
vanilla flash attention on the full sequence with H/cp heads, so the
attention kernel does not need to be CP-aware. Works OOB with FA2/3/4,
linear attention, mamba, etc.
Constraint: cp_size must divide both num_attention_heads and num_kv_heads.
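For intuition, here is a minimal sketch of the seq↔head redistribution, assuming a `[batch, seq_local, heads, head_dim]` layout, contiguous rank-ordered sequence shards, and a `cp_group` process group; the function name and layout are illustrative, not the actual helpers in `ulysses_attn.py`.

```python
import torch
import torch.distributed as dist


def seq_to_head_all_to_all(x: torch.Tensor, cp_group) -> torch.Tensor:
    """[b, s_local, h, d] (all heads, seq shard) -> [b, s_global, h/cp, d] (head shard, full seq)."""
    cp = dist.get_world_size(cp_group)
    b, s_local, h, d = x.shape
    assert h % cp == 0, "cp_size must divide the number of heads"
    # Chunk heads so that head chunk i can be sent to rank i.
    x = x.reshape(b, s_local, cp, h // cp, d).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(x)
    # all_to_all_single scatters dim 0: rank r receives head chunk r from every rank.
    dist.all_to_all_single(out, x, group=cp_group)
    # out[j] is rank j's sequence shard of our head chunk; stitch the full sequence back together.
    return out.permute(1, 0, 2, 3, 4).reshape(b, cp * s_local, h // cp, d)
```

The inverse all-to-all (heads gathered back, sequence re-sharded) is applied to the attention output, which is why the kernel in the middle can be any off-the-shelf attention call.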
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Ring CP is a softmax-attention algorithm (sequence ring all-gather of K/V). It cannot correctly drive non-softmax kernels (Qwen3.5 DeltaNet linear-attn, NemotronH Mamba). Previously the linear-attn CP setup hooks were called unconditionally under ring, which silently mixed two different CP layouts in the same model. Now:
- Refuse cp_style='ring' at startup if the model has any linear/SSM attention layer, with a clear error pointing at cp_style='ulysses' (a sketch of this check follows below).
- Only call setup_hybrid_cp / setup_nemotron_h_cp under cp_style='ulysses'.
- setup_sparse_mla_cp (still softmax) keeps working under both.

Ulysses' all-to-all acts purely on the Q/K/V tensors, so the linear/SSM kernel runs unchanged on a sequence shard; this is the correct CP path for hybrid models.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
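A hedged sketch of that startup check, assuming a hypothetical per-layer flag for linear/SSM attention; the real `assert_cp_style_supports_model` may detect such layers differently.

```python
def assert_cp_style_supports_model(model, cp_style: str) -> None:
    # Hypothetical detection: any layer flagged as linear attention or SSM/Mamba.
    has_non_softmax = any(
        getattr(layer, "is_linear_attention", False) or getattr(layer, "is_mamba", False)
        for layer in model.layers
    )
    if cp_style == "ring" and has_non_softmax:
        raise ValueError(
            "cp_style='ring' ring-gathers K/V for softmax attention and cannot drive "
            "linear-attn/SSM kernels; use cp_style='ulysses', which only re-shards the "
            "sequence and leaves those kernels unchanged."
        )
```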
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit e577861.
    window_size = (-1, -1)
    if sliding_window is not None and key_states.shape[1] > sliding_window:
        window_size = (sliding_window, sliding_window)
Sliding window check uses local sharded sequence length
Medium Severity
In _ulysses_flash_attention_forward, the sliding window gate key_states.shape[1] > sliding_window checks the local (sequence-sharded) key length. After the all-to-all inside ulysses_flash_attn_varlen_func, flash attention operates on the global sequence (S_local * cp_size). When S_local <= sliding_window < S_global, the window is incorrectly disabled and full causal attention runs on the entire global sequence. The custom ulysses path (_ulysses_compute_attention) correctly applies the window unconditionally without a length check.
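One hedged way to fix the gate is to compare against the global key length that flash attention will actually see after the all-to-all; `cp_size` here is an illustrative name for the CP world size, not necessarily the variable used in the PR.

```python
window_size = (-1, -1)
if sliding_window is not None:
    # key_states is still sequence-sharded here; after the all-to-all the kernel
    # sees cp_size times as many keys, so gate on the global length instead.
    global_kv_len = key_states.shape[1] * cp_size
    if global_kv_len > sliding_window:
        window_size = (sliding_window, sliding_window)
```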
Reviewed by Cursor Bugbot for commit e577861.


Summary
Adds a Ulysses all-to-all CP path alongside the existing ring CP, gated by a new `model.cp_style` arg (`"ring"` (default) | `"ulysses"`). Ulysses redistributes Q/K/V via two seq↔head all-to-alls and runs vanilla flash-attn on the full sequence with H/cp heads, so the kernel itself doesn't need to be CP-aware (works OOB with FA2/3/4, linear attn, mamba). Constraint: `cp_size` must divide both `num_attention_heads` and `num_key_value_heads`.

This PR adds: the `cp_style` config, the new `ulysses_attn.py` module (a2a helpers + flash wrapper + custom/HF substitution), the dispatcher branch in `setup_cp_params`, and the SFT/RL trainer wiring.

Slowdown vs cp=1 (GLM-4.5-Air, FA3, 8× H200, seq=64k)
cp=1 is the no-CP baseline, the gold-standard throughput. The slowdown column is the cost of going to cp>1 at the same seq. Throughput = `batch × seq / step_time`.

Ulysses pays almost nothing vs cp=1 (8% at cp=2, +1% at cp=4); ring pays ~30%. Peak memory drops monotonically with cp; ring and ulysses show the same peak at every point.
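For reference, the slowdown figures above follow from this arithmetic (placeholder helper, not the benchmark script's code):

```python
def throughput(batch: int, seq: int, step_time_s: float) -> float:
    # Tokens processed per second at the global batch, independent of cp.
    return batch * seq / step_time_s


def slowdown_vs_cp1(tput_cp1: float, tput_cpN: float) -> float:
    # Fraction of throughput lost by going to cp>1 at the same batch and seq.
    return 1.0 - tput_cpN / tput_cp1
```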
Correctness

Nemotron Super, cp=1 vs cp=2 (expected to show different loss, since the two runs see different data).

GLM-4.5-Air, cp=2, Ulysses vs ring: exact loss match, since both see the same data.
Full sweep — GLM-4.5-Air, debug.num_layers=4, ep=8
FA3 (`flash_attention_3`, requires #2417)

The `tok/s (raw)` and `MFU% (raw)` columns are bench-reported and divide by `cp` (per-rank-useful-work convention), which understates cp>1 numbers by ~cp×. The slowdown table at the top uses `batch × seq / step_time` directly.

FA2 (`flash_attention_2`)

Ulysses is faster than ring in every (cp, seq) cell; peak memory is identical.
Note
Medium Risk
Adds a new context-parallel attention path and dynamically patches FlashAttention/HF attention entrypoints, which is performance- and correctness-sensitive in distributed training. Risk is mitigated by keeping `ring` as default and gating unsupported model/attention combinations via validation.

Overview
Adds a new `model.cp_style` switch (`ring` default, `ulysses` optional) to choose between ring-attention CP and an all-to-all Ulysses CP strategy.

Implements Ulysses CP in `ulysses_attn.py` by redistributing Q/K/V via seq↔head all-to-all, running local FlashAttention (FA2/FA3/FA4), then all-to-all back; includes patching for both the custom `FlashAttention._compute_attention` and the HF `_flash_attention_forward` (a hedged patching sketch follows below).

Wires the new style through the RL and SFT trainers (selecting the appropriate patching path and passing `cp_style` into `setup_cp_params`), and updates `setup_cp_params` to publish full-sequence `cu_seqlens`/`max_seqlen` to either ring or ulysses. Adds `assert_cp_style_supports_model` to reject `ring` CP for models containing linear-attn/SSM layers, only configuring hybrid/Mamba CP setup when `cp_style='ulysses'`.

Reviewed by Cursor Bugbot for commit e577861.
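A minimal sketch of the HF-side substitution, assuming the entrypoint lives at `transformers.modeling_flash_attention_utils._flash_attention_forward` and reusing the hypothetical helpers from the earlier sketch (`seq_to_head_all_to_all`, its inverse `head_to_seq_all_to_all`, and a `cp_group`); the actual patching in `ulysses_attn.py` may differ, and a real wrapper must also pass global `cu_seqlens`/`max_seqlen`/`query_length` and handle the sliding-window gate discussed above.

```python
import transformers.modeling_flash_attention_utils as hf_fa

_orig_flash_attention_forward = hf_fa._flash_attention_forward


def _ulysses_flash_attention_forward(query_states, key_states, value_states, *args, **kwargs):
    # seq -> head all-to-all: every rank now holds the full sequence for h/cp heads.
    # seq_to_head_all_to_all / head_to_seq_all_to_all / cp_group are the hypothetical
    # helpers and process group sketched earlier, not the PR's actual names.
    q = seq_to_head_all_to_all(query_states, cp_group)
    k = seq_to_head_all_to_all(key_states, cp_group)
    v = seq_to_head_all_to_all(value_states, cp_group)
    # Vanilla flash attention on the global sequence; the kernel is not CP-aware.
    # (Simplification: remaining args such as query_length would need to describe
    # the global sequence as well.)
    out = _orig_flash_attention_forward(q, k, v, *args, **kwargs)
    # head -> seq all-to-all: back to the local sequence shard with all heads.
    return head_to_seq_all_to_all(out, cp_group)


hf_fa._flash_attention_forward = _ulysses_flash_attention_forward
```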