Skip to content

Commit b8f3f25

Browse files
committed
adopt *experts.{id}.* naming pattern
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent 945e1c6 commit b8f3f25

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

modelopt/torch/export/unified_export_hf.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -988,22 +988,31 @@ def _revert_weight_conversion_noop(model: Any, state_dict: dict) -> dict:
988988
return state_dict
989989

990990

991-
def _patch_revert_weight_conversion() -> list[tuple[Any, Any]]:
992-
"""Patch revert_weight_conversion in transformers to avoid IndexError on scalar tensors."""
991+
def _try_patch_module(mod_path: str) -> tuple[Any, Any] | None:
992+
"""Try to patch revert_weight_conversion in a single module."""
993993
import importlib
994994

995+
try:
996+
mod = importlib.import_module(mod_path)
997+
if hasattr(mod, "revert_weight_conversion"):
998+
original = getattr(mod, "revert_weight_conversion")
999+
setattr(mod, "revert_weight_conversion", _revert_weight_conversion_noop)
1000+
return (mod, original)
1001+
except (ImportError, AttributeError):
1002+
pass
1003+
return None
1004+
1005+
1006+
def _patch_revert_weight_conversion() -> list[tuple[Any, Any]]:
1007+
"""Patch revert_weight_conversion in transformers to avoid IndexError on scalar tensors."""
9951008
patches: list[tuple[Any, Any]] = []
9961009
for mod_path in [
9971010
"transformers.core_model_loading",
9981011
"transformers.modeling_utils",
9991012
]:
1000-
try:
1001-
mod = importlib.import_module(mod_path)
1002-
if hasattr(mod, "revert_weight_conversion"):
1003-
patches.append((mod, getattr(mod, "revert_weight_conversion")))
1004-
setattr(mod, "revert_weight_conversion", _revert_weight_conversion_noop)
1005-
except (ImportError, AttributeError):
1006-
pass
1013+
result = _try_patch_module(mod_path)
1014+
if result is not None:
1015+
patches.append(result)
10071016
return patches
10081017

10091018

0 commit comments

Comments
 (0)