
Commit 072ab4a

first attempt
ghstack-source-id: d58ba64
Pull Request resolved: #2571
1 parent: 383b467

10 files changed: 526 additions & 172 deletions


This file: 316 additions & 0 deletions
@@ -0,0 +1,316 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from torchtitan.models.common.linear import Linear
from torchtitan.models.common.moe.moe import (
    GroupedExperts,
    MoE,
    TokenChoiceTopKRouter,
)
from torchtitan.protocols.module import Module
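
# The tests below exercise the shared Config/build pattern: each module
# exposes a nested Config, and config.build(...) supplies the remaining
# runtime arguments to construct an instance. A minimal sketch, using only
# calls that appear in the tests themselves:
#
#     config = GroupedExperts.Config(use_grouped_mm=False)
#     experts = config.build(dim=16, hidden_dim=32, num_experts=2)
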
class TestGroupedExperts(unittest.TestCase):
    """Tests for GroupedExperts Config/build pattern."""

    def test_config_build(self):
        """GroupedExperts.Config.build() creates a working instance."""
        config = GroupedExperts.Config()
        experts = config.build(dim=32, hidden_dim=64, num_experts=4)
        self.assertIsInstance(experts, GroupedExperts)
        self.assertIsInstance(experts, Module)
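        # The weights appear to follow a gated-FFN (SwiGLU-style) layout:
        # w1 and w3 map dim -> hidden_dim, w2 maps hidden_dim -> dim, with
        # num_experts as the leading dimension. (Inferred from the shapes
        # asserted below, not stated explicitly in the diff.)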
        self.assertEqual(experts.w1.shape, torch.Size([4, 64, 32]))
        self.assertEqual(experts.w2.shape, torch.Size([4, 32, 64]))
        self.assertEqual(experts.w3.shape, torch.Size([4, 64, 32]))
        self.assertEqual(experts.num_experts, 4)
        self.assertTrue(experts.use_grouped_mm)

    def test_config_build_no_grouped_mm(self):
        """GroupedExperts.Config(use_grouped_mm=False) is respected."""
        config = GroupedExperts.Config(use_grouped_mm=False)
        experts = config.build(dim=16, hidden_dim=32, num_experts=2)
        self.assertFalse(experts.use_grouped_mm)

    def test_config_build_without_fields_raises(self):
        """build() raises when required fields are not provided."""
        config = GroupedExperts.Config()
        with self.assertRaises(TypeError):
            config.build()

    def test_config_build_partial_fields_raises(self):
        """build() raises when only some required fields are provided."""
        config = GroupedExperts.Config()
        with self.assertRaises(TypeError):
            config.build(dim=32)

    def test_init_weights(self):
        """init_weights re-initializes weight tensors."""
        config = GroupedExperts.Config()
        experts = config.build(dim=16, hidden_dim=32, num_experts=2)

        with torch.no_grad():
            torch.nn.init.zeros_(experts.w1)
            torch.nn.init.zeros_(experts.w2)
            torch.nn.init.zeros_(experts.w3)
        self.assertTrue(torch.all(experts.w1 == 0))
        experts.init_weights(init_std=0.02)
        self.assertFalse(torch.all(experts.w1 == 0))
        self.assertFalse(torch.all(experts.w2 == 0))
        self.assertFalse(torch.all(experts.w3 == 0))

    def test_init_weights_requires_init_std(self):
        """init_weights raises when init_std is not provided."""
        config = GroupedExperts.Config()
        experts = config.build(dim=16, hidden_dim=32, num_experts=2)
        with self.assertRaises(AssertionError):
            experts.init_weights()

    def test_forward_for_loop(self):
        """Forward pass works with the for-loop implementation."""
        config = GroupedExperts.Config(use_grouped_mm=False)
        experts = config.build(dim=16, hidden_dim=32, num_experts=4)
        experts.init_weights(init_std=0.02)

        num_tokens_per_expert = torch.tensor([3, 2, 4, 1])
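        # The input is assumed to be pre-sorted by expert assignment: the
        # first 3 rows go to expert 0, the next 2 to expert 1, and so on,
        # so num_tokens_per_expert must sum to x.shape[0].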
        total_tokens = num_tokens_per_expert.sum().item()
        x = torch.randn(total_tokens, 16)
        out = experts(x, num_tokens_per_expert)
        self.assertEqual(out.shape, torch.Size([total_tokens, 16]))

    def test_shared_config_builds_independent_instances(self):
        """A single Config can build multiple independent instances."""
        config = GroupedExperts.Config()
        e1 = config.build(dim=16, hidden_dim=32, num_experts=4)
        e2 = config.build(dim=32, hidden_dim=64, num_experts=8)
        self.assertIsNot(e1, e2)
        self.assertEqual(e1.w1.shape, torch.Size([4, 32, 16]))
        self.assertEqual(e2.w1.shape, torch.Size([8, 64, 32]))

    def test_default_use_grouped_mm_true(self):
        """GroupedExperts.Config defaults to use_grouped_mm=True."""
        config = GroupedExperts.Config()
        self.assertTrue(config.use_grouped_mm)

class TestTokenChoiceTopKRouter(unittest.TestCase):
    """Tests for TokenChoiceTopKRouter Config/build pattern."""

    def test_config_build(self):
        """TokenChoiceTopKRouter.Config.build() creates a working instance."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=32, num_experts=8)
        self.assertIsInstance(router, TokenChoiceTopKRouter)
        self.assertIsInstance(router, Module)
        self.assertEqual(router.num_experts, 8)
        self.assertEqual(router.top_k, 1)
        self.assertEqual(router.score_func, "sigmoid")
        self.assertFalse(router.route_norm)
        self.assertEqual(router.route_scale, 1.0)

    def test_config_build_with_custom_params(self):
        """Config respects custom routing parameters."""
        config = TokenChoiceTopKRouter.Config(
            top_k=4,
            score_func="softmax",
            route_norm=True,
            route_scale=2.5,
        )
        router = config.build(dim=64, num_experts=16)
        self.assertEqual(router.top_k, 4)
        self.assertEqual(router.score_func, "softmax")
        self.assertTrue(router.route_norm)
        self.assertEqual(router.route_scale, 2.5)

    def test_config_build_with_gate_bias(self):
        """gate=Linear.Config(bias=True) creates a gate with bias."""
        config = TokenChoiceTopKRouter.Config(
            gate=Linear.Config(bias=True),
        )
        router = config.build(dim=32, num_experts=8)
        self.assertIsNotNone(router.gate.bias)
        self.assertEqual(router.gate.bias.shape, torch.Size([8]))

    def test_config_build_default_gate_no_bias(self):
        """Default gate config has no bias."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=32, num_experts=8)
        self.assertIsNone(router.gate.bias)

    def test_config_build_without_fields_raises(self):
        """build() raises when required fields are not provided."""
        config = TokenChoiceTopKRouter.Config()
        with self.assertRaises(TypeError):
            config.build()

    def test_gate_shape(self):
        """Gate linear layer has correct shape (dim -> num_experts)."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=64, num_experts=16)
        self.assertIsInstance(router.gate, Linear)
        self.assertEqual(router.gate.weight.shape, torch.Size([16, 64]))

    def test_node_limited_routing_config(self):
        """Config correctly passes node-limited routing parameters."""
        config = TokenChoiceTopKRouter.Config(
            num_expert_groups=4,
            num_limited_groups=2,
        )
        router = config.build(dim=32, num_experts=16)
        self.assertEqual(router.num_expert_groups, 4)
        self.assertEqual(router.num_limited_groups, 2)

    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
    def test_forward(self):
        """Forward pass returns correct shapes."""
        config = TokenChoiceTopKRouter.Config(top_k=2, score_func="softmax")
        router = config.build(dim=32, num_experts=8).cuda()
        router.init_weights(init_std=0.02)

        x = torch.randn(10, 32, device="cuda")
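        # With top_k=2, the router is expected to return per-token scores and
        # expert indices of shape [num_tokens, top_k], plus a per-expert
        # token-count histogram of shape [num_experts].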
        top_scores, selected_experts, num_tokens_per_expert = router(x)
        self.assertEqual(top_scores.shape, torch.Size([10, 2]))
        self.assertEqual(selected_experts.shape, torch.Size([10, 2]))
        self.assertEqual(num_tokens_per_expert.shape, torch.Size([8]))

    def test_init_weights(self):
        """init_weights delegates to gate.init_weights."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=16, num_experts=4)

        with torch.no_grad():
            torch.nn.init.zeros_(router.gate.weight)
        self.assertTrue(torch.all(router.gate.weight == 0))
        router.init_weights(init_std=0.02)
        self.assertFalse(torch.all(router.gate.weight == 0))

    def test_shared_config_builds_independent_instances(self):
        """A single Config can build multiple independent routers."""
        config = TokenChoiceTopKRouter.Config(top_k=2)
        r1 = config.build(dim=32, num_experts=8)
        r2 = config.build(dim=64, num_experts=16)
        self.assertIsNot(r1, r2)
        self.assertEqual(r1.gate.weight.shape, torch.Size([8, 32]))
        self.assertEqual(r2.gate.weight.shape, torch.Size([16, 64]))

class TestMoEConfig(unittest.TestCase):
    """Tests for MoE with nested GroupedExperts and TokenChoiceTopKRouter configs."""

    def test_config_build_defaults(self):
        """MoE.Config.build() with defaults creates a working MoE."""
        config = MoE.Config(hidden_dim=64)
        moe = config.build(dim=32)
        self.assertIsInstance(moe, MoE)
        self.assertIsInstance(moe, Module)
        self.assertIsInstance(moe.experts, GroupedExperts)
        self.assertIsInstance(moe.router, TokenChoiceTopKRouter)

    def test_config_with_nested_router(self):
        """MoE.Config with custom router config is respected."""
        config = MoE.Config(
            hidden_dim=64,
            num_experts=16,
            router=TokenChoiceTopKRouter.Config(
                top_k=4,
                score_func="softmax",
                route_norm=True,
            ),
        )
        moe = config.build(dim=32)
        self.assertEqual(moe.router.top_k, 4)
        self.assertEqual(moe.router.score_func, "softmax")
        self.assertTrue(moe.router.route_norm)
        self.assertEqual(moe.experts.num_experts, 16)

    def test_config_with_nested_experts(self):
        """MoE.Config with custom experts config is respected."""
        config = MoE.Config(
            hidden_dim=64,
            experts=GroupedExperts.Config(use_grouped_mm=False),
        )
        moe = config.build(dim=32)
        self.assertFalse(moe.experts.use_grouped_mm)

    def test_config_with_gate_bias(self):
        """MoE.Config with gate bias via nested router config."""
        config = MoE.Config(
            hidden_dim=64,
            router=TokenChoiceTopKRouter.Config(
                gate=Linear.Config(bias=True),
            ),
        )
        moe = config.build(dim=32)
        self.assertIsNotNone(moe.router.gate.bias)

    def test_num_experts_propagated(self):
        """num_experts from MoE.Config propagates to experts and router."""
        config = MoE.Config(hidden_dim=64, num_experts=16)
        moe = config.build(dim=32)
        self.assertEqual(moe.experts.num_experts, 16)
        self.assertEqual(moe.router.num_experts, 16)
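        # w1 is laid out as [num_experts, hidden_dim, dim] (per the shape
        # checks in TestGroupedExperts), so its leading dimension should
        # reflect the propagated num_experts.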
        self.assertEqual(moe.experts.w1.shape[0], 16)

    def test_hidden_dim_propagated(self):
        """hidden_dim from MoE.Config propagates to experts."""
        config = MoE.Config(hidden_dim=128, num_experts=4)
        moe = config.build(dim=32)
        self.assertEqual(moe.experts.w1.shape, torch.Size([4, 128, 32]))
        self.assertEqual(moe.experts.w2.shape, torch.Size([4, 32, 128]))

    def test_init_weights(self):
        """MoE.init_weights initializes experts and router weights."""
        config = MoE.Config(hidden_dim=64, num_experts=4)
        moe = config.build(dim=32)

        with torch.no_grad():
            torch.nn.init.zeros_(moe.experts.w1)
            torch.nn.init.zeros_(moe.router.gate.weight)
        self.assertTrue(torch.all(moe.experts.w1 == 0))
        self.assertTrue(torch.all(moe.router.gate.weight == 0))

        moe.init_weights(
            init_std=0.02, buffer_device=torch.device("cpu")
        )
        self.assertFalse(torch.all(moe.experts.w1 == 0))
        self.assertFalse(torch.all(moe.router.gate.weight == 0))

    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
    def test_forward(self):
        """MoE forward pass produces correct output shape."""
        config = MoE.Config(
            hidden_dim=64,
            num_experts=4,
            num_shared_experts=0,
            experts=GroupedExperts.Config(use_grouped_mm=False),
        )
        moe = config.build(dim=32).cuda()
        moe.init_weights(
            init_std=0.02, buffer_device=torch.device("cuda")
        )

        x = torch.randn(2, 8, 32, device="cuda")
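        # MoE forward should preserve the input's [batch, seq, dim] shape;
        # routing and expert mixing are presumably handled internally.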
        out = moe(x)
        self.assertEqual(out.shape, torch.Size([2, 8, 32]))

    def test_shared_experts(self):
        """MoE with shared experts creates a FeedForward."""
        config = MoE.Config(hidden_dim=64, num_shared_experts=2)
        moe = config.build(dim=32)
        self.assertIsNotNone(moe.shared_experts)

    def test_no_shared_experts(self):
        """MoE with num_shared_experts=0 has no shared experts."""
        config = MoE.Config(hidden_dim=64, num_shared_experts=0)
        moe = config.build(dim=32)
        self.assertIsNone(moe.shared_experts)


if __name__ == "__main__":
    unittest.main()
