Skip to content

Commit e2ac5d2

Browse files
check asr models (#14989) (#15002)
* check asr models * Apply isort and black reformatting * update return --------- Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com> Signed-off-by: nithinraok <nithinraok@users.noreply.github.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com> Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com> Co-authored-by: nithinraok <nithinraok@users.noreply.github.com>
1 parent f082c92 commit e2ac5d2

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

nemo/core/connectors/save_restore_connector.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,8 +756,23 @@ def _save_state_dict_to_disk(state_dict, filepath):
756756
torch.save(state_dict, filepath)
757757

758758
@staticmethod
759-
def _load_state_dict_from_disk(model_weights, map_location=None):
760-
return torch.load(model_weights, map_location='cpu', weights_only=False)
759+
def _load_state_dict_from_disk(model_weights, map_location='cpu'):
760+
"""
761+
Load model state dict from disk.
762+
763+
Args:
764+
model_weights: Path to the checkpoint file
765+
map_location: Device to map tensors to
766+
767+
Returns:
768+
State dict loaded from checkpoint
769+
770+
"""
771+
try:
772+
return torch.load(model_weights, map_location=map_location, weights_only=True)
773+
except Exception as e:
774+
logging.error(f"Failed to load checkpoint with weights_only=True: {e}")
775+
raise e
761776

762777
@property
763778
def model_config_yaml(self) -> str:

0 commit comments

Comments
 (0)