@@ -988,22 +988,31 @@ def _revert_weight_conversion_noop(model: Any, state_dict: dict) -> dict:
988988 return state_dict
989989
990990
991- def _patch_revert_weight_conversion ( ) -> list [ tuple [Any , Any ]] :
992- """Patch revert_weight_conversion in transformers to avoid IndexError on scalar tensors ."""
991+ def _try_patch_module ( mod_path : str ) -> tuple [Any , Any ] | None :
992+ """Try to patch revert_weight_conversion in a single module ."""
993993 import importlib
994994
995+ try :
996+ mod = importlib .import_module (mod_path )
997+ if hasattr (mod , "revert_weight_conversion" ):
998+ original = getattr (mod , "revert_weight_conversion" )
999+ setattr (mod , "revert_weight_conversion" , _revert_weight_conversion_noop )
1000+ return (mod , original )
1001+ except (ImportError , AttributeError ):
1002+ pass
1003+ return None
1004+
1005+
1006+ def _patch_revert_weight_conversion () -> list [tuple [Any , Any ]]:
1007+ """Patch revert_weight_conversion in transformers to avoid IndexError on scalar tensors."""
9951008 patches : list [tuple [Any , Any ]] = []
9961009 for mod_path in [
9971010 "transformers.core_model_loading" ,
9981011 "transformers.modeling_utils" ,
9991012 ]:
1000- try :
1001- mod = importlib .import_module (mod_path )
1002- if hasattr (mod , "revert_weight_conversion" ):
1003- patches .append ((mod , getattr (mod , "revert_weight_conversion" )))
1004- setattr (mod , "revert_weight_conversion" , _revert_weight_conversion_noop )
1005- except (ImportError , AttributeError ):
1006- pass
1013+ result = _try_patch_module (mod_path )
1014+ if result is not None :
1015+ patches .append (result )
10071016 return patches
10081017
10091018
0 commit comments