Skip to content

Commit 3305309

Browse files
committed
address reviews
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent b8f3f25 commit 3305309

File tree

2 files changed

+10
-45
lines changed

2 files changed

+10
-45
lines changed

modelopt/torch/export/layer_utils.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -327,21 +327,12 @@ def is_mlp(module: nn.Module) -> bool:
327327

328328
def is_moe(module: nn.Module) -> bool:
329329
"""Returns whether the module is an MOE layer."""
330-
return any(
331-
key in type(module).__name__.lower()
332-
for key in [
333-
"MixtralSparseMoeBlock".lower(),
334-
"ArcticMoE".lower(),
335-
"DbrxFFN".lower(),
336-
"MoELayer".lower(),
337-
"PhimoeSparseMoeBlock".lower(),
338-
"DeepseekMoE".lower(),
339-
"Qwen2MoeSparseMoeBlock".lower(),
340-
"Qwen3MoeSparseMoeBlock".lower(),
341-
"Qwen3NextSparseMoeBlock".lower(),
342-
"Qwen3_5MoeSparseMoeBlock".lower(),
343-
]
344-
)
330+
name = type(module).__name__.lower()
331+
# Auto-detect common MoE patterns
332+
if name.endswith("sparsemoeblock") or "moelayer" in name:
333+
return True
334+
# Explicit matches for non-standard naming
335+
return any(key in name for key in ["arcticmoe", "deepseekmoe", "dbrxffn"])
345336

346337

347338
def is_quantlinear(module: nn.Module) -> bool:

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -983,38 +983,12 @@ def unpack_weight(self):
983983
pass
984984

985985

986-
class _QuantQwen35MoeSparseMoeBlock(_QuantSparseMoe):
987-
"""Qwen3.5 MoE stores top_k/num_experts in the router (self.gate), not as direct attributes.
988-
989-
We override forward instead of just bridging attributes because the router (self.gate)
990-
uses its own top_k internally for routing decisions. We must modify self.gate.top_k
991-
directly so all experts see calibration data.
992-
"""
993-
994-
def _setup(self):
995-
self.num_experts = self.experts.num_experts
996-
997-
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
998-
if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
999-
# Force all tokens to all experts during calibration
1000-
original_top_k = self.gate.top_k
1001-
self.gate.top_k = self.num_experts
1002-
super(_QuantSparseMoe, self).forward(hidden_states)
1003-
self.gate.top_k = original_top_k
1004-
return super(_QuantSparseMoe, self).forward(hidden_states)
1005-
1006-
1007986
try:
1008-
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
1009-
Qwen3_5MoeExperts,
1010-
Qwen3_5MoeSparseMoeBlock,
1011-
)
1012-
1013-
if Qwen3_5MoeSparseMoeBlock not in QuantModuleRegistry:
1014-
QuantModuleRegistry.register({Qwen3_5MoeSparseMoeBlock: "hf.Qwen3_5MoeSparseMoeBlock"})(
1015-
_QuantQwen35MoeSparseMoeBlock
1016-
)
987+
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeExperts
1017988

989+
# Qwen3_5MoeSparseMoeBlock registration is handled by register_sparse_moe_on_the_fly
990+
# (auto-detected via gate.top_k + gate.num_experts + experts pattern).
991+
# Only the fused expert weights need explicit registration.
1018992
if Qwen3_5MoeExperts not in QuantModuleRegistry:
1019993
QuantModuleRegistry.register({Qwen3_5MoeExperts: "hf.Qwen3_5MoeExperts"})(
1020994
_QuantQwen35MoeExperts

0 commit comments

Comments (0)