Skip to content

Commit 013a61f

Browse files
wangbinluo authored and claude committed
Add Magi Attention for varlen + Context Parallelism
Implements MagiAttention (arXiv:2505.13211) to enable varlen attention with Context Parallelism, resolving the existing NotImplementedError in context_parallel.py. New module: torchtitan/distributed/varlen_cp/ (8 files) - Dispatch Solver: LPT greedy load balancing across CP ranks - Group-Cast: per-doc packed AllToAll-V (zero-redundancy communication) - Attention kernel: PyTorch-native flex_attention with BlockMask - Cross-layer metadata cache for dispatch plan and FFA ranges - Backward K/V reuse: saves 1x AllToAll-V - NVSHMEM transport: auto-detect with NCCL fallback Integration: - context_parallel.py: varlen case sets _cp_mesh on attention modules - attention.py: VarlenAttentionWrapper._forward_cp() dispatches to Magi Attention Zero external dependencies — all attention uses PyTorch native APIs (flex_attention, varlen_attn, scaled_dot_product_attention). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c6856e3 commit 013a61f

15 files changed

Lines changed: 6191 additions & 34 deletions
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from torchtitan.distributed.varlen_cp.dispatch_ops import (
12+
compute_local_cu_seqlens,
13+
shard_sequence,
14+
)
15+
16+
17+
class TestComputeLocalCuSeqlens(unittest.TestCase):
    """Tests for compute_local_cu_seqlens across chunk/document layouts."""

    def _check(self, boundaries, start, end, want_cu, want_max):
        # Shared driver: run the op on a chunk [start, end) and compare
        # both the local cu_seqlens and the max per-doc length.
        cu = torch.tensor(boundaries, dtype=torch.int32)
        got_cu, got_max = compute_local_cu_seqlens(cu, start, end)
        self.assertEqual(got_cu.tolist(), want_cu)
        self.assertEqual(got_max, want_max)

    def test_doc_within_chunk(self):
        """Document entirely within the chunk."""
        self._check([0, 128, 256], 0, 128, [0, 128], 128)

    def test_doc_spanning_chunk(self):
        """Document spans across the chunk boundary."""
        # Chunk [0, 256): contains part of doc 0 (0-256).
        self._check([0, 300, 512], 0, 256, [0, 256], 256)
        # Chunk [256, 512): rest of doc 0 (256-300) and all of doc 1
        # (300-512); doc 1 contributes 212 tokens — the max in this chunk.
        self._check([0, 300, 512], 256, 512, [0, 44, 256], 212)

    def test_multiple_docs_in_chunk(self):
        """Multiple documents fit within one chunk."""
        self._check([0, 64, 128, 192, 256], 0, 256, [0, 64, 128, 192, 256], 64)

    def test_chunk_with_no_doc_boundaries(self):
        """Chunk is entirely within a single document."""
        self._check([0, 512], 128, 384, [0, 256], 256)
51+
52+
53+
class TestShardSequence(unittest.TestCase):
    """Tests for shard_sequence, which splits a tensor into contiguous
    per-rank chunks along the given sequence dimension."""

    def test_basic_sharding(self):
        # 1-D: each of 2 ranks gets a contiguous half of the sequence.
        x = torch.arange(8).float()
        shard_0 = shard_sequence(x, cp_rank=0, cp_world_size=2, seq_dim=0)
        shard_1 = shard_sequence(x, cp_rank=1, cp_world_size=2, seq_dim=0)
        torch.testing.assert_close(shard_0, torch.tensor([0.0, 1, 2, 3]))
        torch.testing.assert_close(shard_1, torch.tensor([4.0, 5, 6, 7]))

    def test_2d_sharding(self):
        # 2-D: shard along dim 1; the batch dim (dim 0) stays untouched.
        x = torch.arange(16).float().reshape(2, 8)
        shard_0 = shard_sequence(x, cp_rank=0, cp_world_size=2, seq_dim=1)
        shard_1 = shard_sequence(x, cp_rank=1, cp_world_size=2, seq_dim=1)
        self.assertEqual(shard_0.shape, (2, 4))
        self.assertEqual(shard_1.shape, (2, 4))
        # Fix: the original test only checked shapes, so a wrong-content
        # shard would pass. Also verify the values are the contiguous
        # halves, matching the 1-D behavior pinned in test_basic_sharding.
        torch.testing.assert_close(shard_0, x[:, :4])
        torch.testing.assert_close(shard_1, x[:, 4:])
67+
68+
69+
# Allow running this test file directly: `python <file>.py`.
if __name__ == "__main__":
    unittest.main()
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from torchtitan.distributed.varlen_cp.dispatch_solver import solve_dispatch
10+
from torchtitan.distributed.varlen_cp.mask_primitives import (
11+
cu_seqlens_to_attn_slices,
12+
)
13+
14+
15+
class TestSolveDispatch(unittest.TestCase):
    """Tests for the LPT greedy dispatch solver used by varlen CP."""

    def test_uniform_docs(self):
        """Uniform document lengths should be balanced by default."""
        plan = solve_dispatch(
            cu_seqlens_to_attn_slices([0, 128, 256, 384, 512]),
            total_seqlen=512,
            chunk_size=256,
            cp_world_size=2,
        )

        # Plan metadata mirrors the solver inputs.
        self.assertEqual(plan.cp_world_size, 2)
        self.assertEqual(plan.chunk_size, 256)
        self.assertEqual(plan.total_seqlen, 512)
        self.assertEqual(plan.pad_size, 0)

        # Every rank should be assigned some work.
        for rank in (0, 1):
            self.assertGreater(plan.get_rank_work(rank), 0)

    def test_skewed_docs(self):
        """Skewed document lengths should still be distributed."""
        plan = solve_dispatch(
            cu_seqlens_to_attn_slices([0, 400, 450, 500, 512]),
            total_seqlen=512,
            chunk_size=256,
            cp_world_size=2,
        )
        # Neither rank is left idle even with one dominant document.
        self.assertGreater(plan.get_rank_work(0), 0)
        self.assertGreater(plan.get_rank_work(1), 0)

    def test_all_slices_covered(self):
        """All global slices should appear in some rank's assignment."""
        plan = solve_dispatch(
            cu_seqlens_to_attn_slices([0, 128, 300, 512]),
            total_seqlen=512,
            chunk_size=256,
            cp_world_size=2,
        )
        # Flatten the per-rank assignments and count sub-slices.
        total_sub_slices = 0
        for rank_assignments in plan.assignments:
            for entry in rank_assignments:
                total_sub_slices += len(entry.slices)
        self.assertGreater(total_sub_slices, 0)

    def test_single_rank(self):
        """Single CP rank should get all work."""
        plan = solve_dispatch(
            cu_seqlens_to_attn_slices([0, 256]),
            total_seqlen=256,
            chunk_size=256,
            cp_world_size=1,
        )
        self.assertEqual(len(plan.assignments), 1)
        self.assertGreater(plan.get_rank_work(0), 0)

    def test_minimax_property(self):
        """Greedy min-heap should produce reasonable load balance."""
        plan = solve_dispatch(
            cu_seqlens_to_attn_slices([0, 100, 200, 300, 400, 512]),
            total_seqlen=512,
            chunk_size=128,
            cp_world_size=4,
        )
        loads = [plan.get_rank_work(r) for r in range(4)]
        # Max work should be within 3x of min work (loose bound).
        if min(loads) > 0:
            self.assertLess(max(loads) / min(loads), 3.0)

    def test_num_chunks(self):
        # 512 tokens at chunk_size=128 -> 4 chunks.
        plan = solve_dispatch(
            cu_seqlens_to_attn_slices([0, 512]),
            total_seqlen=512,
            chunk_size=128,
            cp_world_size=4,
        )
        self.assertEqual(plan.num_chunks, 4)

    def test_pair_has_work(self):
        """pair_has_work returns True for valid pairs and False for above-diagonal pairs."""
        # Single doc of length 256 with chunk_size=128 gives 2 chunks.
        # Valid pairs: (0,0) and (1,1) on the diagonal, (1,0) below it.
        # (0,1) sits above the diagonal and must carry no work.
        plan = solve_dispatch(
            cu_seqlens_to_attn_slices([0, 256]),
            total_seqlen=256,
            chunk_size=128,
            cp_world_size=2,
        )

        self.assertTrue(plan.pair_has_work(0, 0))
        self.assertTrue(plan.pair_has_work(1, 1))
        self.assertTrue(plan.pair_has_work(1, 0))
        self.assertFalse(plan.pair_has_work(0, 1))

    def test_pair_has_work_no_spanning_docs(self):
        """pair_has_work returns False for off-diagonal when no doc spans chunks."""
        # Two docs, each exactly one chunk; no doc crosses a boundary.
        plan = solve_dispatch(
            cu_seqlens_to_attn_slices([0, 128, 256]),
            total_seqlen=256,
            chunk_size=128,
            cp_world_size=2,
        )

        # Diagonal pairs have work; off-diagonal pairs do not.
        self.assertTrue(plan.pair_has_work(0, 0))
        self.assertTrue(plan.pair_has_work(1, 1))
        self.assertFalse(plan.pair_has_work(1, 0))
        self.assertFalse(plan.pair_has_work(0, 1))
120+
121+
122+
# Allow running this test file directly: `python <file>.py`.
if __name__ == "__main__":
    unittest.main()
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
from torchtitan.distributed.varlen_cp.mask_primitives import (
12+
AttnSlice,
13+
cu_seqlens_to_attn_slices,
14+
make_slice_mask,
15+
MaskType,
16+
split_slice_at_chunk_boundary,
17+
)
18+
19+
20+
class TestMaskType(unittest.TestCase):
    def test_mask_type_values(self):
        # Each mask kind maps onto a stable integer code.
        for mask_type, code in (
            (MaskType.FULL, 0),
            (MaskType.CAUSAL, 1),
            (MaskType.INVCAUSAL, 2),
            (MaskType.BICAUSAL, 3),
        ):
            self.assertEqual(mask_type, code)
27+
28+
class TestAttnSlice(unittest.TestCase):
    """Tests for AttnSlice geometry and work estimation."""

    @staticmethod
    def _square(n, mask_type):
        # Build an n x n slice anchored at the origin.
        return AttnSlice(q_start=0, q_end=n, k_start=0, k_end=n, mask_type=mask_type)

    def test_basic_properties(self):
        s = self._square(10, MaskType.FULL)
        self.assertEqual(s.q_len, 10)
        self.assertEqual(s.k_len, 10)

    def test_work_estimate_full(self):
        # FULL covers the whole 100x100 block.
        self.assertAlmostEqual(
            self._square(100, MaskType.FULL).work_estimate, 10000.0
        )

    def test_work_estimate_causal(self):
        # CAUSAL keeps half of the square block.
        self.assertAlmostEqual(
            self._square(100, MaskType.CAUSAL).work_estimate, 5000.0
        )

    def test_work_estimate_invcausal(self):
        # INVCAUSAL keeps the complementary half.
        self.assertAlmostEqual(
            self._square(100, MaskType.INVCAUSAL).work_estimate, 5000.0
        )

    def test_work_estimate_bicausal(self):
        # BICAUSAL keeps a quarter of the block.
        self.assertAlmostEqual(
            self._square(100, MaskType.BICAUSAL).work_estimate, 2500.0
        )

    def test_work_estimate_minimum(self):
        """Tiny (1x1) slices should still report work_estimate >= 1.0."""
        self.assertGreaterEqual(
            self._square(1, MaskType.CAUSAL).work_estimate, 1.0
        )
58+
59+
60+
class TestCuSeqlensToAttnSlices(unittest.TestCase):
    """Tests for converting cumulative sequence lengths into attention slices."""

    def test_single_doc(self):
        result = cu_seqlens_to_attn_slices([0, 256], is_causal=True)
        self.assertEqual(len(result), 1)
        only = result[0]
        self.assertEqual(only.q_start, 0)
        self.assertEqual(only.q_end, 256)
        self.assertEqual(only.mask_type, MaskType.CAUSAL)

    def test_multi_doc(self):
        result = cu_seqlens_to_attn_slices([0, 128, 300, 512], is_causal=True)
        self.assertEqual(len(result), 3)
        # Each doc becomes a diagonal causal slice over its own token range.
        expected = [
            AttnSlice(0, 128, 0, 128, MaskType.CAUSAL),
            AttnSlice(128, 300, 128, 300, MaskType.CAUSAL),
            AttnSlice(300, 512, 300, 512, MaskType.CAUSAL),
        ]
        for got, want in zip(result, expected):
            self.assertEqual(got, want)

    def test_full_mask(self):
        result = cu_seqlens_to_attn_slices([0, 256], is_causal=False)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].mask_type, MaskType.FULL)

    def test_tensor_input(self):
        # cu_seqlens may arrive as a torch tensor instead of a Python list.
        result = cu_seqlens_to_attn_slices(
            torch.tensor([0, 128, 256]), is_causal=True
        )
        self.assertEqual(len(result), 2)

    def test_empty_doc(self):
        """Adjacent equal values in cu_seqlens create zero-length docs."""
        result = cu_seqlens_to_attn_slices([0, 128, 128, 256])
        # Zero-length docs should be skipped, leaving two real slices.
        self.assertEqual(len(result), 2)
94+
95+
96+
class TestSplitSliceAtChunkBoundary(unittest.TestCase):
    """Tests for splitting an attention slice along chunk boundaries."""

    def test_doc_within_one_chunk(self):
        """Document fits entirely within one chunk."""
        src = AttnSlice(
            q_start=10, q_end=50, k_start=10, k_end=50, mask_type=MaskType.CAUSAL
        )
        parts = split_slice_at_chunk_boundary(src, chunk_size=64, total_seqlen=128)
        # No boundary crossed -> the slice comes back intact.
        self.assertEqual(len(parts), 1)
        self.assertEqual(parts[0].mask_type, MaskType.CAUSAL)
        self.assertEqual(parts[0].q_start, 10)
        self.assertEqual(parts[0].q_end, 50)

    def test_doc_spanning_two_chunks(self):
        """Document spans two chunks: diagonal blocks are CAUSAL, below-diagonal are FULL."""
        src = AttnSlice(
            q_start=48, q_end=80, k_start=48, k_end=80, mask_type=MaskType.CAUSAL
        )
        parts = split_slice_at_chunk_boundary(src, chunk_size=64, total_seqlen=128)

        # Expected decomposition:
        #   (chunk 0, chunk 0): q=[48,64), k=[48,64), CAUSAL (diagonal)
        #   (chunk 1, chunk 0): q=[64,80), k=[48,64), FULL   (below diagonal)
        #   (chunk 1, chunk 1): q=[64,80), k=[64,80), CAUSAL (diagonal)
        self.assertEqual(len(parts), 3)

        # Index each sub-slice by its (q-chunk, k-chunk) coordinates.
        by_chunk = {(p.q_start // 64, p.k_start // 64): p.mask_type for p in parts}
        self.assertEqual(by_chunk[(0, 0)], MaskType.CAUSAL)
        self.assertEqual(by_chunk[(1, 0)], MaskType.FULL)
        self.assertEqual(by_chunk[(1, 1)], MaskType.CAUSAL)

    def test_full_mask_stays_full(self):
        """FULL mask type sub-blocks are all FULL."""
        src = AttnSlice(
            q_start=48, q_end=80, k_start=48, k_end=80, mask_type=MaskType.FULL
        )
        parts = split_slice_at_chunk_boundary(src, chunk_size=64, total_seqlen=128)
        self.assertTrue(all(p.mask_type == MaskType.FULL for p in parts))
129+
130+
131+
class TestMakeSliceMask(unittest.TestCase):
    """Tests for make_slice_mask boolean mask construction."""

    def _assert_mask(self, q_len, k_len, mask_type, expected_rows):
        # Shared driver: build the mask and compare against a row list.
        mask = make_slice_mask(q_len, k_len, mask_type)
        self.assertTrue(torch.equal(mask, torch.tensor(expected_rows)))

    def test_full_mask(self):
        # FULL attends everywhere.
        self.assertTrue(make_slice_mask(4, 4, MaskType.FULL).all())

    def test_causal_square(self):
        # Lower-triangular, inclusive of the diagonal.
        self._assert_mask(
            4,
            4,
            MaskType.CAUSAL,
            [
                [True, False, False, False],
                [True, True, False, False],
                [True, True, True, False],
                [True, True, True, True],
            ],
        )

    def test_causal_rectangular(self):
        """Bottom-right aligned causal for q_len < k_len."""
        # q_len=2, k_len=4 -> offset = k_len - q_len = 2:
        #   row 0 sees keys j <= 0+2 -> {0,1,2}
        #   row 1 sees keys j <= 1+2 -> {0,1,2,3}
        self._assert_mask(
            2,
            4,
            MaskType.CAUSAL,
            [
                [True, True, True, False],
                [True, True, True, True],
            ],
        )

    def test_invcausal_mask(self):
        # Upper-triangular, inclusive of the diagonal.
        self._assert_mask(
            4,
            4,
            MaskType.INVCAUSAL,
            [
                [True, True, True, True],
                [False, True, True, True],
                [False, False, True, True],
                [False, False, False, True],
            ],
        )

    def test_bicausal_mask(self):
        # Intersection of causal and inverse-causal: the diagonal only.
        self._assert_mask(
            4,
            4,
            MaskType.BICAUSAL,
            [
                [True, False, False, False],
                [False, True, False, False],
                [False, False, True, False],
                [False, False, False, True],
            ],
        )
185+
186+
187+
# Allow running this test file directly: `python <file>.py`.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)