Skip to content

Commit ccfa065

Browse files
committed
fix
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent d7c544b commit ccfa065

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ def _load_hf_checkpoint_preserving_dtype(model_path: str) -> Optional[dict[str,
11741174
from safetensors import safe_open
11751175
except ImportError:
11761176
return None
1177-
if not os.path.isdir(model_path) and not (os.path.isfile(model_path) and model_path.endswith(".safetensors")):
1177+
if not _is_safetensors_checkpoint(model_path):
11781178
return None
11791179
out: dict[str, torch.Tensor] = {}
11801180
if os.path.isfile(model_path):

0 commit comments

Comments
 (0)