feat(cp): add ulysses all-to-all CP path alongside ring #2418

Open
samsja wants to merge 3 commits into main from feat/ulysses-cp

Conversation


samsja (Member) commented May 5, 2026

Summary

Adds a Ulysses all-to-all CP path alongside the existing ring CP, gated by a new model.cp_style arg ("ring" (default) | "ulysses"). Ulysses redistributes Q/K/V via two seq↔head all-to-alls and runs vanilla flash-attn over the full sequence with H/cp heads per rank, so the kernel itself never needs to be CP-aware (it works out of the box with FA2/3/4, linear attention, Mamba). Constraint: cp_size must divide both num_attention_heads and num_key_value_heads.

This PR adds: the cp_style config, the new ulysses_attn.py module (a2a helpers + flash wrapper + custom/HF substitution), the dispatcher branch in setup_cp_params, and the SFT/RL trainer wiring.
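
For illustration, here is a minimal sketch of the seq↔head all-to-all at the core of the Ulysses path, assuming a `[batch, seq_local, heads, head_dim]` layout and a dedicated CP process group; the helper name and exact layout are illustrative, not the actual `ulysses_attn.py` API.

```python
import torch
import torch.distributed as dist

def seq_to_head_all_to_all(x: torch.Tensor, cp_group: dist.ProcessGroup) -> torch.Tensor:
    """Redistribute [B, S/cp, H, D] -> [B, S, H/cp, D] across the CP group (sketch)."""
    cp = dist.get_world_size(cp_group)
    b, s_local, h, d = x.shape
    assert h % cp == 0, "cp_size must divide the number of heads"
    # Chunk the head dim so rank r keeps head chunk r but receives every
    # rank's sequence shard for that chunk.
    x = x.reshape(b, s_local, cp, h // cp, d).permute(2, 0, 1, 3, 4).contiguous()
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=cp_group)
    # Concatenate the received sequence shards in rank order -> full sequence.
    return out.permute(1, 0, 2, 3, 4).reshape(b, cp * s_local, h // cp, d)
```

After attention, the second all-to-all is the symmetric inverse (head↔seq), restoring the local `[B, S/cp, H, D]` layout so the rest of the model sees an ordinary sequence shard.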

Slowdown vs cp=1 (GLM-4.5-Air, FA3, 8× H200, seq=64k)

cp=1 means no CP and serves as the reference throughput. The Δ vs cp=1 column is the cost of moving to cp>1 at the same sequence length. Throughput = batch × seq / step_time.

| cp | style | tps | Δ vs cp=1 | peak GiB |
|---|---|---|---|---|
| 1 | — | 84 k | — | 133.5 |
| 2 | ulysses | 77 k | −8% | 86.1 |
| 2 | ring | 60 k | −29% | 86.1 |
| 4 | ulysses | 85 k | +1% | 53.1 |
| 4 | ring | 57 k | −32% | 53.1 |

Ulysses pays almost nothing vs cp=1 (−8% at cp=2, +1% at cp=4); ring pays ~30%. Peak memory drops monotonically with cp; ring and ulysses use the same peak at every point.

Correctness

Nemotron Super, cp=1 vs cp=2 (losses are expected to differ since each run sees different data):

(loss curve image)

GLM-4.5-Air, cp=2, Ulysses vs ring: exact loss match, since both runs see the same data:

(loss curve image)

Full sweep — GLM-4.5-Air, debug.num_layers=4, ep=8

FA3 (flash_attention_3, requires #2417)

| cp | seq | style | step (s) | tok/s (raw) | MFU% (raw) | peak GiB |
|---|---|---|---|---|---|---|
| 1 | 32k | ring | 1.93 | 134 916 | 27.3 | 86.1 |
| 1 | 64k | ring | 6.23 | 88 022 | 24.8 | 133.5 |
| 1 | 128k | ring | OOM | — | — | — |
| 2 | 32k | ring | 1.23 | 53 396 | 10.8 | 53.1 |
| 2 | 32k | ulysses | 0.98 | 60 579 | 12.2 | 53.1 |
| 2 | 64k | ring | 4.34 | 31 647 | 8.9 | 86.1 |
| 2 | 64k | ulysses | 3.41 | 39 979 | 11.3 | 86.1 |
| 2 | 128k | ring | 17.99 | 15 469 | 6.8 | 133.5 |
| 2 | 128k | ulysses | 12.14 | 22 756 | 10.1 | 133.5 |
| 4 | 64k | ring | 2.28 | 14 300 | 4.0 | 53.1 |
| 4 | 64k | ulysses | 1.55 | 21 242 | 6.0 | 53.1 |
| 8 | 128k | ring | 3.04 | 5 211 | 2.3 | 53.1 |
| 8 | 128k | ulysses | 2.40 | 6 956 | 3.1 | 53.1 |

The tok/s (raw) and MFU% (raw) columns are bench-reported and divide by cp (a per-rank useful-work convention), so they understate cp>1 numbers by roughly cp×. The slowdown table at the top uses batch × seq / step_time directly.
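
As a quick illustration of the convention gap, using the cp=2 / 64k / ulysses FA3 row above (batch size is not shown; only the ratio matters):

```python
cp, raw_tps = 2, 39_979        # bench-reported tok/s (per-rank-useful-work convention)
corrected_tps = raw_tps * cp   # ~80k tok/s under the batch * seq / step_time convention
```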

FA2 (flash_attention_2)

| cp | seq | style | step (s) | tok/s (raw) | MFU% (raw) | peak GiB |
|---|---|---|---|---|---|---|
| 1 | 32k | ring | 2.91 | 94 311 | 19.1 | 86.1 |
| 1 | 64k | ring | 12.25 | 45 974 | 13.0 | 133.5 |
| 1 | 128k | ring | OOM | — | — | — |
| 2 | 32k | ring | 3.01 | 25 142 | 5.1 | 53.1 |
| 2 | 32k | ulysses | 1.74 | 37 691 | 7.6 | 53.1 |
| 2 | 64k | ring | 8.54 | 16 241 | 4.6 | 86.1 |
| 2 | 64k | ulysses | 5.56 | 24 414 | 6.9 | 86.1 |
| 2 | 128k | ring | 29.69 | 8 911 | 3.9 | 133.5 |
| 2 | 128k | ulysses | 19.73 | 13 633 | 6.0 | 133.5 |
| 4 | 64k | ring | 4.45 | 7 814 | 2.2 | 53.1 |
| 4 | 64k | ulysses | 2.43 | 12 787 | 3.6 | 53.1 |
| 8 | 128k | ring | 5.75 | 2 829 | 1.3 | 53.1 |
| 8 | 128k | ulysses | 4.31 | 3 933 | 1.7 | 53.1 |

Ulysses is faster than ring in every (cp, seq) cell; peak memory is identical.


Note

Medium Risk
Adds a new context-parallel attention path and dynamically patches FlashAttention/HF attention entrypoints, which is performance- and correctness-sensitive in distributed training. Risk is mitigated by keeping ring as default and gating unsupported model/attention combinations via validation.

Overview
Adds a new model.cp_style switch (ring default, ulysses optional) to choose between ring-attention CP and an all-to-all Ulysses CP strategy.

Implements Ulysses CP in ulysses_attn.py by redistributing Q/K/V via seq↔head all-to-all, running local FlashAttention (FA2/FA3/FA4), then all-to-all back; includes patching for both custom FlashAttention._compute_attention and HF _flash_attention_forward.

Wires the new style through the RL and SFT trainers (selecting the appropriate patching path and passing cp_style into setup_cp_params), and updates setup_cp_params to publish full-sequence cu_seqlens/max_seqlen to either the ring or the ulysses path. Also adds assert_cp_style_supports_model to reject ring CP for models containing linear-attn/SSM layers, and configures the hybrid/Mamba CP setup only when cp_style='ulysses'.
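
A rough sketch of the style dispatch and the Ulysses divisibility constraint described above; only cp_style, num_attention_heads, and num_key_value_heads come from the PR, while the function shape itself is an assumption.

```python
def resolve_cp_style(cp_style: str, cp_size: int,
                     num_attention_heads: int, num_key_value_heads: int) -> str:
    """Validate the requested CP style before wiring it into the trainer (sketch)."""
    if cp_style not in ("ring", "ulysses"):
        raise ValueError(f"unknown cp_style: {cp_style!r}")
    if cp_style == "ulysses":
        # Ulysses shards heads across the CP group, so both head counts must divide evenly.
        if num_attention_heads % cp_size or num_key_value_heads % cp_size:
            raise ValueError(
                "ulysses CP requires cp_size to divide both num_attention_heads "
                f"and num_key_value_heads (got cp_size={cp_size})"
            )
    return cp_style

# Example: cp=4 with 32 attention heads and 8 KV heads passes the check.
cp_style = resolve_cp_style("ulysses", cp_size=4, num_attention_heads=32, num_key_value_heads=8)
```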


samsja and others added 2 commits May 5, 2026 08:39
Adds a `model.cp_style` arg ("ring" default | "ulysses") that selects
between the existing ring CP and a new Ulysses all-to-all variant.

Ulysses redistributes Q/K/V via two seq <-> head all-to-alls and runs
vanilla flash attention on the full sequence with H/cp heads, so the
attention kernel does not need to be CP-aware. Works OOB with FA2/3/4,
linear attention, mamba, etc.

Constraint: cp_size must divide both num_attention_heads and num_kv_heads.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
samsja marked this pull request as ready for review May 5, 2026 04:14
Ring CP is a softmax-attention algorithm (sequence ring all-gather of K/V).
It cannot correctly drive non-softmax kernels (Qwen3.5 DeltaNet linear-attn,
NemotronH Mamba). Previously the linear-attn CP setup hooks were called
unconditionally under ring, which silently mixed two different CP layouts
in the same model.

Now:
- Refuse cp_style='ring' at startup if the model has any linear/SSM
  attention layer, with a clear error pointing at cp_style='ulysses'.
- Only call setup_hybrid_cp / setup_nemotron_h_cp under cp_style='ulysses'.
- setup_sparse_mla_cp (still softmax) keeps working under both.

Ulysses' all-to-all is purely on Q/K/V tensors, so the linear/SSM kernel
runs unchanged on a sequence shard — this is the correct CP path for
hybrid models.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
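
A hedged sketch of the refusal this commit describes; the real assert_cp_style_supports_model and the way linear-attn/SSM layers are detected are assumptions here.

```python
def assert_cp_style_supports_model(model, cp_style: str) -> None:
    """Refuse ring CP for models containing non-softmax (linear-attn / SSM) layers."""
    non_softmax = [
        name for name, module in model.named_modules()
        # Illustrative class names; the actual detection logic may differ.
        if module.__class__.__name__ in ("DeltaNetLinearAttention", "MambaMixer")
    ]
    if cp_style == "ring" and non_softmax:
        raise ValueError(
            f"cp_style='ring' cannot drive linear-attention/SSM layers ({non_softmax[:3]}); "
            "use cp_style='ulysses' instead."
        )
```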

cursor (bot) left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.


Reviewed by Cursor Bugbot for commit e577861.


```python
window_size = (-1, -1)
if sliding_window is not None and key_states.shape[1] > sliding_window:
    window_size = (sliding_window, sliding_window)
```


Sliding window check uses local sharded sequence length

Medium Severity

In _ulysses_flash_attention_forward, the sliding window gate key_states.shape[1] > sliding_window checks the local (sequence-sharded) key length. After the all-to-all inside ulysses_flash_attn_varlen_func, flash attention operates on the global sequence (S_local * cp_size). When S_local <= sliding_window < S_global, the window is incorrectly disabled and full causal attention runs on the entire global sequence. The custom ulysses path (_ulysses_compute_attention) correctly applies the window unconditionally without a length check.
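
A possible fix along the lines described here, comparing the window against the global key length rather than the local shard (the cp world-size variable name is illustrative):

```python
# key_states is still sequence-sharded here; the all-to-all inside
# ulysses_flash_attn_varlen_func expands it to the full sequence.
global_kv_len = key_states.shape[1] * cp_world_size
window_size = (-1, -1)
if sliding_window is not None and global_kv_len > sliding_window:
    window_size = (sliding_window, sliding_window)
```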


