Commit d714bd0

first attempt

ghstack-source-id: 00a0a5d
Pull Request resolved: #2571

1 parent 383b467 commit d714bd0
10 files changed: 587 additions & 175 deletions

Lines changed: 356 additions & 0 deletions

@@ -0,0 +1,356 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from torchtitan.models.common.linear import Linear
from torchtitan.models.common.moe.moe import (
    GroupedExperts,
    MoE,
    TokenChoiceTopKRouter,
    TokenReorderer,
)
from torchtitan.protocols.module import Module


class TestGroupedExperts(unittest.TestCase):
    """Tests for GroupedExperts Config/build pattern."""

    def test_config_build(self):
        """GroupedExperts.Config.build() creates a working instance."""
        config = GroupedExperts.Config()
        experts = config.build(dim=32, hidden_dim=64, num_experts=4)
        self.assertIsInstance(experts, GroupedExperts)
        self.assertIsInstance(experts, Module)
        self.assertEqual(experts.w1.shape, torch.Size([4, 64, 32]))
        self.assertEqual(experts.w2.shape, torch.Size([4, 32, 64]))
        self.assertEqual(experts.w3.shape, torch.Size([4, 64, 32]))
        self.assertEqual(experts.num_experts, 4)
        self.assertTrue(experts.use_grouped_mm)

    def test_config_build_no_grouped_mm(self):
        """GroupedExperts.Config(use_grouped_mm=False) is respected."""
        config = GroupedExperts.Config(use_grouped_mm=False)
        experts = config.build(dim=16, hidden_dim=32, num_experts=2)
        self.assertFalse(experts.use_grouped_mm)

    def test_config_build_without_fields_raises(self):
        """build() raises when required fields are not provided."""
        config = GroupedExperts.Config()
        with self.assertRaises(TypeError):
            config.build()

    def test_config_build_partial_fields_raises(self):
        """build() raises when only some required fields are provided."""
        config = GroupedExperts.Config()
        with self.assertRaises(TypeError):
            config.build(dim=32)

    def test_init_weights(self):
        """init_weights re-initializes weight tensors."""
        config = GroupedExperts.Config()
        experts = config.build(dim=16, hidden_dim=32, num_experts=2)

        with torch.no_grad():
            torch.nn.init.zeros_(experts.w1)
            torch.nn.init.zeros_(experts.w2)
            torch.nn.init.zeros_(experts.w3)
        self.assertTrue(torch.all(experts.w1 == 0))
        experts.init_weights(init_std=0.02)
        self.assertFalse(torch.all(experts.w1 == 0))
        self.assertFalse(torch.all(experts.w2 == 0))
        self.assertFalse(torch.all(experts.w3 == 0))

    def test_init_weights_requires_init_std(self):
        """init_weights raises when init_std is not provided."""
        config = GroupedExperts.Config()
        experts = config.build(dim=16, hidden_dim=32, num_experts=2)
        with self.assertRaises(AssertionError):
            experts.init_weights()

    def test_forward_for_loop(self):
        """Forward pass works with for-loop implementation."""
        config = GroupedExperts.Config(use_grouped_mm=False)
        experts = config.build(dim=16, hidden_dim=32, num_experts=4)
        experts.init_weights(init_std=0.02)

        num_tokens_per_expert = torch.tensor([3, 2, 4, 1])
        total_tokens = num_tokens_per_expert.sum().item()
        x = torch.randn(total_tokens, 16)
        out = experts(x, num_tokens_per_expert)
        self.assertEqual(out.shape, torch.Size([total_tokens, 16]))

    def test_shared_config_builds_independent_instances(self):
        """A single Config can build multiple independent instances."""
        config = GroupedExperts.Config()
        e1 = config.build(dim=16, hidden_dim=32, num_experts=4)
        e2 = config.build(dim=32, hidden_dim=64, num_experts=8)
        self.assertIsNot(e1, e2)
        self.assertEqual(e1.w1.shape, torch.Size([4, 32, 16]))
        self.assertEqual(e2.w1.shape, torch.Size([8, 64, 32]))

    def test_default_use_grouped_mm_true(self):
        """GroupedExperts.Config defaults to use_grouped_mm=True."""
        config = GroupedExperts.Config()
        self.assertTrue(config.use_grouped_mm)


class TestTokenChoiceTopKRouter(unittest.TestCase):
    """Tests for TokenChoiceTopKRouter Config/build pattern."""

    def test_config_build(self):
        """TokenChoiceTopKRouter.Config.build() creates a working instance."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=32, num_experts=8)
        self.assertIsInstance(router, TokenChoiceTopKRouter)
        self.assertIsInstance(router, Module)
        self.assertEqual(router.num_experts, 8)
        self.assertEqual(router.top_k, 1)
        self.assertEqual(router.score_func, "sigmoid")
        self.assertFalse(router.route_norm)
        self.assertEqual(router.route_scale, 1.0)

    def test_config_build_with_custom_params(self):
        """Config respects custom routing parameters."""
        config = TokenChoiceTopKRouter.Config(
            top_k=4,
            score_func="softmax",
            route_norm=True,
            route_scale=2.5,
        )
        router = config.build(dim=64, num_experts=16)
        self.assertEqual(router.top_k, 4)
        self.assertEqual(router.score_func, "softmax")
        self.assertTrue(router.route_norm)
        self.assertEqual(router.route_scale, 2.5)

    def test_config_build_with_gate_bias(self):
        """gate=Linear.Config(bias=True) creates a gate with bias."""
        config = TokenChoiceTopKRouter.Config(
            gate=Linear.Config(bias=True),
        )
        router = config.build(dim=32, num_experts=8)
        self.assertIsNotNone(router.gate.bias)
        self.assertEqual(router.gate.bias.shape, torch.Size([8]))

    def test_config_build_default_gate_no_bias(self):
        """Default gate config has no bias."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=32, num_experts=8)
        self.assertIsNone(router.gate.bias)

    def test_config_build_without_fields_raises(self):
        """build() raises when required fields are not provided."""
        config = TokenChoiceTopKRouter.Config()
        with self.assertRaises(TypeError):
            config.build()

    def test_gate_shape(self):
        """Gate linear layer has correct shape (dim -> num_experts)."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=64, num_experts=16)
        self.assertIsInstance(router.gate, Linear)
        self.assertEqual(router.gate.weight.shape, torch.Size([16, 64]))

    def test_node_limited_routing_config(self):
        """Config correctly passes node-limited routing parameters."""
        config = TokenChoiceTopKRouter.Config(
            num_expert_groups=4,
            num_limited_groups=2,
        )
        router = config.build(dim=32, num_experts=16)
        self.assertEqual(router.num_expert_groups, 4)
        self.assertEqual(router.num_limited_groups, 2)

    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
    def test_forward(self):
        """Forward pass returns correct shapes."""
        config = TokenChoiceTopKRouter.Config(top_k=2, score_func="softmax")
        router = config.build(dim=32, num_experts=8).cuda()
        router.init_weights(init_std=0.02)

        x = torch.randn(10, 32, device="cuda")
        top_scores, selected_experts, num_tokens_per_expert = router(x)
        self.assertEqual(top_scores.shape, torch.Size([10, 2]))
        self.assertEqual(selected_experts.shape, torch.Size([10, 2]))
        self.assertEqual(num_tokens_per_expert.shape, torch.Size([8]))

    def test_init_weights(self):
        """init_weights delegates to gate.init_weights."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=16, num_experts=4)

        with torch.no_grad():
            torch.nn.init.zeros_(router.gate.weight)
        self.assertTrue(torch.all(router.gate.weight == 0))
        router.init_weights(init_std=0.02)
        self.assertFalse(torch.all(router.gate.weight == 0))

    def test_shared_config_builds_independent_instances(self):
        """A single Config can build multiple independent routers."""
        config = TokenChoiceTopKRouter.Config(top_k=2)
        r1 = config.build(dim=32, num_experts=8)
        r2 = config.build(dim=64, num_experts=16)
        self.assertIsNot(r1, r2)
        self.assertEqual(r1.gate.weight.shape, torch.Size([8, 32]))
        self.assertEqual(r2.gate.weight.shape, torch.Size([16, 64]))


class TestTokenReorderer(unittest.TestCase):
    """Tests for TokenReorderer Config/build pattern."""

    def test_config_build(self):
        """TokenReorderer.Config.build() creates a working instance."""
        config = TokenReorderer.Config()
        reorderer = config.build(num_experts=8, top_k=2)
        self.assertIsInstance(reorderer, TokenReorderer)
        self.assertIsInstance(reorderer, Module)
        self.assertEqual(reorderer.num_experts, 8)
        self.assertEqual(reorderer.top_k, 2)

    def test_config_build_without_fields_raises(self):
        """build() raises when required fields are not provided."""
        config = TokenReorderer.Config()
        with self.assertRaises(TypeError):
            config.build()

    def test_config_build_partial_fields_raises(self):
        """build() raises when only some required fields are provided."""
        config = TokenReorderer.Config()
        with self.assertRaises(TypeError):
            config.build(num_experts=8)

    def test_shared_config_builds_independent_instances(self):
        """A single Config can build multiple independent instances."""
        config = TokenReorderer.Config()
        r1 = config.build(num_experts=8, top_k=2)
        r2 = config.build(num_experts=16, top_k=4)
        self.assertIsNot(r1, r2)
        self.assertEqual(r1.num_experts, 8)
        self.assertEqual(r2.num_experts, 16)
        self.assertEqual(r1.top_k, 2)
        self.assertEqual(r2.top_k, 4)

    def test_init_weights_noop(self):
        """init_weights is a no-op (TokenReorderer has no learnable parameters)."""
        config = TokenReorderer.Config()
        reorderer = config.build(num_experts=4, top_k=1)
        # Should not raise
        reorderer.init_weights()


class TestMoEConfig(unittest.TestCase):
    """Tests for MoE with nested GroupedExperts and TokenChoiceTopKRouter configs."""

    def test_config_build_defaults(self):
        """MoE.Config.build() with defaults creates a working MoE."""
        config = MoE.Config(hidden_dim=64)
        moe = config.build(dim=32)
        self.assertIsInstance(moe, MoE)
        self.assertIsInstance(moe, Module)
        self.assertIsInstance(moe.experts, GroupedExperts)
        self.assertIsInstance(moe.router, TokenChoiceTopKRouter)

    def test_config_with_nested_router(self):
        """MoE.Config with custom router config is respected."""
        config = MoE.Config(
            hidden_dim=64,
            num_experts=16,
            router=TokenChoiceTopKRouter.Config(
                top_k=4,
                score_func="softmax",
                route_norm=True,
            ),
        )
        moe = config.build(dim=32)
        self.assertEqual(moe.router.top_k, 4)
        self.assertEqual(moe.router.score_func, "softmax")
        self.assertTrue(moe.router.route_norm)
        self.assertEqual(moe.experts.num_experts, 16)

    def test_config_with_nested_experts(self):
        """MoE.Config with custom experts config is respected."""
        config = MoE.Config(
            hidden_dim=64,
            experts=GroupedExperts.Config(use_grouped_mm=False),
        )
        moe = config.build(dim=32)
        self.assertFalse(moe.experts.use_grouped_mm)

    def test_config_with_gate_bias(self):
        """MoE.Config with gate bias via nested router config."""
        config = MoE.Config(
            hidden_dim=64,
            router=TokenChoiceTopKRouter.Config(
                gate=Linear.Config(bias=True),
            ),
        )
        moe = config.build(dim=32)
        self.assertIsNotNone(moe.router.gate.bias)

    def test_num_experts_propagated(self):
        """num_experts from MoE.Config propagates to experts and router."""
        config = MoE.Config(hidden_dim=64, num_experts=16)
        moe = config.build(dim=32)
        self.assertEqual(moe.experts.num_experts, 16)
        self.assertEqual(moe.router.num_experts, 16)
        self.assertEqual(moe.experts.w1.shape[0], 16)

    def test_hidden_dim_propagated(self):
        """hidden_dim from MoE.Config propagates to experts."""
        config = MoE.Config(hidden_dim=128, num_experts=4)
        moe = config.build(dim=32)
        self.assertEqual(moe.experts.w1.shape, torch.Size([4, 128, 32]))
        self.assertEqual(moe.experts.w2.shape, torch.Size([4, 32, 128]))

    def test_init_weights(self):
        """MoE.init_weights initializes experts and router weights."""
        config = MoE.Config(hidden_dim=64, num_experts=4)
        moe = config.build(dim=32)

        with torch.no_grad():
            torch.nn.init.zeros_(moe.experts.w1)
            torch.nn.init.zeros_(moe.router.gate.weight)
        self.assertTrue(torch.all(moe.experts.w1 == 0))
        self.assertTrue(torch.all(moe.router.gate.weight == 0))

        moe.init_weights(init_std=0.02, buffer_device=torch.device("cpu"))
        self.assertFalse(torch.all(moe.experts.w1 == 0))
        self.assertFalse(torch.all(moe.router.gate.weight == 0))

    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
    def test_forward(self):
        """MoE forward pass produces correct output shape."""
        config = MoE.Config(
            hidden_dim=64,
            num_experts=4,
            num_shared_experts=0,
            experts=GroupedExperts.Config(use_grouped_mm=False),
        )
        moe = config.build(dim=32).cuda()
        moe.init_weights(init_std=0.02, buffer_device=torch.device("cuda"))

        x = torch.randn(2, 8, 32, device="cuda")
        out = moe(x)
        self.assertEqual(out.shape, torch.Size([2, 8, 32]))

    def test_shared_experts(self):
        """MoE with shared experts creates FeedForward."""
        config = MoE.Config(hidden_dim=64, num_shared_experts=2)
        moe = config.build(dim=32)
        self.assertIsNotNone(moe.shared_experts)

    def test_no_shared_experts(self):
        """MoE with num_shared_experts=0 has no shared experts."""
        config = MoE.Config(hidden_dim=64, num_shared_experts=0)
        moe = config.build(dim=32)
        self.assertIsNone(moe.shared_experts)


if __name__ == "__main__":
    unittest.main()
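
For readers skimming the diff, here is a minimal sketch of the Config/build pattern this test file exercises, using only the API surface shown above (nested sub-configs, required fields deferred to build(), explicit weight initialization). The dimensions below are illustrative values, not library defaults:

import torch

from torchtitan.models.common.moe.moe import MoE, TokenChoiceTopKRouter

# MoE.Config composes sub-configs; per the tests, num_experts and
# hidden_dim propagate to the experts and router it builds.
config = MoE.Config(
    hidden_dim=128,
    num_experts=8,
    router=TokenChoiceTopKRouter.Config(top_k=2, score_func="softmax"),
)

# Fields not fixed in the Config (here: dim) are supplied at build time,
# and a single Config can build any number of independent modules.
moe = config.build(dim=64)

# Weights are only initialized once init_weights is called explicitly.
moe.init_weights(init_std=0.02, buffer_device=torch.device("cpu"))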
