# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from torchtitan.models.common.linear import Linear
from torchtitan.models.common.moe.moe import (
    GroupedExperts,
    MoE,
    TokenChoiceTopKRouter,
)
from torchtitan.protocols.module import Module


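# These tests exercise the Config/build pattern: static architectural choices
# (e.g. use_grouped_mm, top_k, score_func) live on a module's nested Config,
# while assembly-time sizes (dim, hidden_dim, num_experts) are passed to
# build(). A minimal usage sketch of the pattern, as exercised below:
#
#   config = GroupedExperts.Config(use_grouped_mm=False)
#   experts = config.build(dim=16, hidden_dim=32, num_experts=4)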
class TestGroupedExperts(unittest.TestCase):
    """Tests for GroupedExperts Config/build pattern."""

    def test_config_build(self):
        """GroupedExperts.Config.build() creates a working instance."""
        config = GroupedExperts.Config()
        experts = config.build(dim=32, hidden_dim=64, num_experts=4)
        self.assertIsInstance(experts, GroupedExperts)
        self.assertIsInstance(experts, Module)
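        # Per-expert weights are stacked along dim 0, per the shapes asserted
        # here: w1/w3: (num_experts, hidden_dim, dim), w2: (num_experts, dim, hidden_dim).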
        self.assertEqual(experts.w1.shape, torch.Size([4, 64, 32]))
        self.assertEqual(experts.w2.shape, torch.Size([4, 32, 64]))
        self.assertEqual(experts.w3.shape, torch.Size([4, 64, 32]))
        self.assertEqual(experts.num_experts, 4)
        self.assertTrue(experts.use_grouped_mm)

    def test_config_build_no_grouped_mm(self):
        """GroupedExperts.Config(use_grouped_mm=False) is respected."""
        config = GroupedExperts.Config(use_grouped_mm=False)
        experts = config.build(dim=16, hidden_dim=32, num_experts=2)
        self.assertFalse(experts.use_grouped_mm)

    def test_config_build_without_fields_raises(self):
        """build() raises when required fields are not provided."""
        config = GroupedExperts.Config()
        with self.assertRaises(TypeError):
            config.build()

    def test_config_build_partial_fields_raises(self):
        """build() raises when only some required fields are provided."""
        config = GroupedExperts.Config()
        with self.assertRaises(TypeError):
            config.build(dim=32)

    def test_init_weights(self):
        """init_weights re-initializes weight tensors."""
        config = GroupedExperts.Config()
        experts = config.build(dim=16, hidden_dim=32, num_experts=2)

        with torch.no_grad():
            torch.nn.init.zeros_(experts.w1)
            torch.nn.init.zeros_(experts.w2)
            torch.nn.init.zeros_(experts.w3)
        self.assertTrue(torch.all(experts.w1 == 0))
        experts.init_weights(init_std=0.02)
        self.assertFalse(torch.all(experts.w1 == 0))
        self.assertFalse(torch.all(experts.w2 == 0))
        self.assertFalse(torch.all(experts.w3 == 0))

    def test_init_weights_requires_init_std(self):
        """init_weights raises when init_std is not provided."""
        config = GroupedExperts.Config()
        experts = config.build(dim=16, hidden_dim=32, num_experts=2)
        with self.assertRaises(AssertionError):
            experts.init_weights()

    def test_forward_for_loop(self):
        """Forward pass works with the for-loop implementation."""
        config = GroupedExperts.Config(use_grouped_mm=False)
        experts = config.build(dim=16, hidden_dim=32, num_experts=4)
        experts.init_weights(init_std=0.02)

        num_tokens_per_expert = torch.tensor([3, 2, 4, 1])
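        # x is assumed to be laid out contiguously by expert (3 tokens for
        # expert 0, then 2 for expert 1, ...) to match num_tokens_per_expert;
        # the test only asserts the output shape, not the routing semantics.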
        total_tokens = num_tokens_per_expert.sum().item()
        x = torch.randn(total_tokens, 16)
        out = experts(x, num_tokens_per_expert)
        self.assertEqual(out.shape, torch.Size([total_tokens, 16]))

    def test_shared_config_builds_independent_instances(self):
        """A single Config can build multiple independent instances."""
        config = GroupedExperts.Config()
        e1 = config.build(dim=16, hidden_dim=32, num_experts=4)
        e2 = config.build(dim=32, hidden_dim=64, num_experts=8)
        self.assertIsNot(e1, e2)
        self.assertEqual(e1.w1.shape, torch.Size([4, 32, 16]))
        self.assertEqual(e2.w1.shape, torch.Size([8, 64, 32]))

    def test_default_use_grouped_mm_true(self):
        """GroupedExperts.Config defaults to use_grouped_mm=True."""
        config = GroupedExperts.Config()
        self.assertTrue(config.use_grouped_mm)


class TestTokenChoiceTopKRouter(unittest.TestCase):
    """Tests for TokenChoiceTopKRouter Config/build pattern."""
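
    # A rough sketch of what token-choice routing computes, based on the
    # config fields exercised below (the conventional formulation, not lifted
    # from the source):
    #   scores = score_func(gate(x))                   # (tokens, num_experts)
    #   top_scores, selected_experts = scores.topk(top_k, dim=-1)
    #   if route_norm: top_scores /= top_scores.sum(-1, keepdim=True)
    #   top_scores *= route_scale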

    def test_config_build(self):
        """TokenChoiceTopKRouter.Config.build() creates a working instance."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=32, num_experts=8)
        self.assertIsInstance(router, TokenChoiceTopKRouter)
        self.assertIsInstance(router, Module)
        self.assertEqual(router.num_experts, 8)
        self.assertEqual(router.top_k, 1)
        self.assertEqual(router.score_func, "sigmoid")
        self.assertFalse(router.route_norm)
        self.assertEqual(router.route_scale, 1.0)

    def test_config_build_with_custom_params(self):
        """Config respects custom routing parameters."""
        config = TokenChoiceTopKRouter.Config(
            top_k=4,
            score_func="softmax",
            route_norm=True,
            route_scale=2.5,
        )
        router = config.build(dim=64, num_experts=16)
        self.assertEqual(router.top_k, 4)
        self.assertEqual(router.score_func, "softmax")
        self.assertTrue(router.route_norm)
        self.assertEqual(router.route_scale, 2.5)

    def test_config_build_with_gate_bias(self):
        """gate=Linear.Config(bias=True) creates a gate with bias."""
        config = TokenChoiceTopKRouter.Config(
            gate=Linear.Config(bias=True),
        )
        router = config.build(dim=32, num_experts=8)
        self.assertIsNotNone(router.gate.bias)
        self.assertEqual(router.gate.bias.shape, torch.Size([8]))

    def test_config_build_default_gate_no_bias(self):
        """Default gate config has no bias."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=32, num_experts=8)
        self.assertIsNone(router.gate.bias)

    def test_config_build_without_fields_raises(self):
        """build() raises when required fields are not provided."""
        config = TokenChoiceTopKRouter.Config()
        with self.assertRaises(TypeError):
            config.build()

    def test_gate_shape(self):
        """Gate linear layer has correct shape (dim -> num_experts)."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=64, num_experts=16)
        self.assertIsInstance(router.gate, Linear)
        self.assertEqual(router.gate.weight.shape, torch.Size([16, 64]))

    def test_node_limited_routing_config(self):
        """Config correctly passes node-limited routing parameters."""
        config = TokenChoiceTopKRouter.Config(
            num_expert_groups=4,
            num_limited_groups=2,
        )
        router = config.build(dim=32, num_experts=16)
        self.assertEqual(router.num_expert_groups, 4)
        self.assertEqual(router.num_limited_groups, 2)

    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
    def test_forward(self):
        """Forward pass returns correct shapes."""
        config = TokenChoiceTopKRouter.Config(top_k=2, score_func="softmax")
        router = config.build(dim=32, num_experts=8).cuda()
        router.init_weights(init_std=0.02)

        x = torch.randn(10, 32, device="cuda")
        top_scores, selected_experts, num_tokens_per_expert = router(x)
        self.assertEqual(top_scores.shape, torch.Size([10, 2]))
        self.assertEqual(selected_experts.shape, torch.Size([10, 2]))
        self.assertEqual(num_tokens_per_expert.shape, torch.Size([8]))

    def test_init_weights(self):
        """init_weights delegates to gate.init_weights."""
        config = TokenChoiceTopKRouter.Config()
        router = config.build(dim=16, num_experts=4)

        with torch.no_grad():
            torch.nn.init.zeros_(router.gate.weight)
        self.assertTrue(torch.all(router.gate.weight == 0))
        router.init_weights(init_std=0.02)
        self.assertFalse(torch.all(router.gate.weight == 0))

    def test_shared_config_builds_independent_instances(self):
        """A single Config can build multiple independent routers."""
        config = TokenChoiceTopKRouter.Config(top_k=2)
        r1 = config.build(dim=32, num_experts=8)
        r2 = config.build(dim=64, num_experts=16)
        self.assertIsNot(r1, r2)
        self.assertEqual(r1.gate.weight.shape, torch.Size([8, 32]))
        self.assertEqual(r2.gate.weight.shape, torch.Size([16, 64]))


class TestMoEConfig(unittest.TestCase):
    """Tests for MoE with nested GroupedExperts and TokenChoiceTopKRouter configs."""
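
    # Sketch of the flow these tests cover: the router scores each token, the
    # tokens are dispatched to GroupedExperts, expert outputs are combined,
    # and an optional shared expert path (num_shared_experts > 0) contributes
    # as well. The exact dispatch/combine details are assumed, not asserted.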

    def test_config_build_defaults(self):
        """MoE.Config.build() with defaults creates a working MoE."""
        config = MoE.Config(hidden_dim=64)
        moe = config.build(dim=32)
        self.assertIsInstance(moe, MoE)
        self.assertIsInstance(moe, Module)
        self.assertIsInstance(moe.experts, GroupedExperts)
        self.assertIsInstance(moe.router, TokenChoiceTopKRouter)

    def test_config_with_nested_router(self):
        """MoE.Config with custom router config is respected."""
        config = MoE.Config(
            hidden_dim=64,
            num_experts=16,
            router=TokenChoiceTopKRouter.Config(
                top_k=4,
                score_func="softmax",
                route_norm=True,
            ),
        )
        moe = config.build(dim=32)
        self.assertEqual(moe.router.top_k, 4)
        self.assertEqual(moe.router.score_func, "softmax")
        self.assertTrue(moe.router.route_norm)
        self.assertEqual(moe.experts.num_experts, 16)

    def test_config_with_nested_experts(self):
        """MoE.Config with custom experts config is respected."""
        config = MoE.Config(
            hidden_dim=64,
            experts=GroupedExperts.Config(use_grouped_mm=False),
        )
        moe = config.build(dim=32)
        self.assertFalse(moe.experts.use_grouped_mm)

    def test_config_with_gate_bias(self):
        """MoE.Config with gate bias via nested router config."""
        config = MoE.Config(
            hidden_dim=64,
            router=TokenChoiceTopKRouter.Config(
                gate=Linear.Config(bias=True),
            ),
        )
        moe = config.build(dim=32)
        self.assertIsNotNone(moe.router.gate.bias)

    def test_num_experts_propagated(self):
        """num_experts from MoE.Config propagates to experts and router."""
        config = MoE.Config(hidden_dim=64, num_experts=16)
        moe = config.build(dim=32)
        self.assertEqual(moe.experts.num_experts, 16)
        self.assertEqual(moe.router.num_experts, 16)
        self.assertEqual(moe.experts.w1.shape[0], 16)

    def test_hidden_dim_propagated(self):
        """hidden_dim from MoE.Config propagates to experts."""
        config = MoE.Config(hidden_dim=128, num_experts=4)
        moe = config.build(dim=32)
        self.assertEqual(moe.experts.w1.shape, torch.Size([4, 128, 32]))
        self.assertEqual(moe.experts.w2.shape, torch.Size([4, 32, 128]))

    def test_init_weights(self):
        """MoE.init_weights initializes experts and router weights."""
        config = MoE.Config(hidden_dim=64, num_experts=4)
        moe = config.build(dim=32)

        with torch.no_grad():
            torch.nn.init.zeros_(moe.experts.w1)
            torch.nn.init.zeros_(moe.router.gate.weight)
        self.assertTrue(torch.all(moe.experts.w1 == 0))
        self.assertTrue(torch.all(moe.router.gate.weight == 0))

        moe.init_weights(init_std=0.02, buffer_device=torch.device("cpu"))
        self.assertFalse(torch.all(moe.experts.w1 == 0))
        self.assertFalse(torch.all(moe.router.gate.weight == 0))

    @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA")
    def test_forward(self):
        """MoE forward pass produces correct output shape."""
        config = MoE.Config(
            hidden_dim=64,
            num_experts=4,
            num_shared_experts=0,
            experts=GroupedExperts.Config(use_grouped_mm=False),
        )
        moe = config.build(dim=32).cuda()
        moe.init_weights(init_std=0.02, buffer_device=torch.device("cuda"))

        x = torch.randn(2, 8, 32, device="cuda")
        out = moe(x)
        self.assertEqual(out.shape, torch.Size([2, 8, 32]))

    def test_shared_experts(self):
        """MoE with shared experts creates FeedForward."""
        config = MoE.Config(hidden_dim=64, num_shared_experts=2)
        moe = config.build(dim=32)
        self.assertIsNotNone(moe.shared_experts)

    def test_no_shared_experts(self):
        """MoE with num_shared_experts=0 has no shared experts."""
        config = MoE.Config(hidden_dim=64, num_shared_experts=0)
        moe = config.build(dim=32)
        self.assertIsNone(moe.shared_experts)


if __name__ == "__main__":
    unittest.main()