
Commit 0663edf

fix(infra): keep model.to(device) on unsharded post-shard load
Persistent buffers initialized via torch.tensor()/torch.ones() inside init_empty_weights() (e.g. Gemma4's Gemma4ClippableLinear input_min/max, Gemma4TextDecoderLayer layer_scalar) stay on CPU because the context only patches register_parameter, not register_buffer. The post-shard load path then unconditionally skipped model.to(device), leaving those buffers stranded on CPU and tripping torch.clamp with a cuda:0 vs cpu device mismatch.

The skip exists for FSDP's reset_sharded_param issue with tied params under TP>1 (pytorch/pytorch#151085). Narrow it to its actual precondition, the presence of any DTensor in the model, so that single-GPU, DDP, and other unsharded configs still run model.to(device).

Add unit coverage for both the unsharded and DTensor-sharded checkpoint load paths.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
1 parent 8222a4f commit 0663edf
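
For context, a minimal standalone sketch of the failure mode described in the message above: a persistent buffer left on CPU while the parameters live on cuda:0 makes torch.clamp raise a device-mismatch error, and only a model.to(device) call moves both. The module and buffer names are illustrative stand-ins, not the actual Gemma4 classes.

import torch
import torch.nn as nn


class ClippableLinearSketch(nn.Module):
    """Illustrative stand-in for a clamp-bounded linear layer; not the real Gemma4 class."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Persistent buffers built from plain CPU tensors, mirroring the
        # torch.tensor()/torch.ones() initialization the commit message describes.
        self.register_buffer("input_min", torch.tensor([-6.0]))
        self.register_buffer("input_max", torch.tensor([6.0]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Raises a RuntimeError (cuda:0 vs cpu) when the buffers were never
        # moved off CPU but x and the weights live on the GPU.
        x = torch.clamp(x, min=self.input_min, max=self.input_max)
        return self.linear(x)


if torch.cuda.is_available():
    m = ClippableLinearSketch(4, 4)
    # Move only the parameters, leaving buffers behind on CPU; this mimics a
    # post-shard checkpoint load that skipped model.to(device).
    for p in m.parameters():
        p.data = p.data.cuda()
    try:
        m(torch.randn(2, 4, device="cuda"))
    except RuntimeError as err:
        print("device mismatch:", err)
    # model.to(device) moves parameters *and* buffers, which is why the skip
    # must not apply to unsharded models.
    m.to("cuda")
    m(torch.randn(2, 4, device="cuda"))  # now succeeds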

2 files changed

Lines changed: 52 additions & 9 deletions


nemo_automodel/_transformers/infrastructure.py

Lines changed: 13 additions & 4 deletions
@@ -565,11 +565,20 @@ def apply_model_infrastructure(
     if autopipeline is None:
         print_trainable_parameters(model)  # Once model's been sharded
     # Ensure model is on the correct device.
-    # Skip when checkpoint was loaded post-shard (params are already on the
-    # target device) to avoid triggering FSDP's reset_sharded_param which
-    # fails on tied parameters (e.g. lm_head/embed_tokens with TP>1).
+    # Skip only when params are actually sharded (any DTensor in the model)
+    # AND the checkpoint was loaded post-shard. Calling model.to(device) on
+    # sharded params triggers FSDP's reset_sharded_param, which fails on
+    # tied parameters (e.g. lm_head/embed_tokens with TP>1).
     # See: https://github.com/pytorch/pytorch/issues/151085
-    if not should_load_checkpoint:
+    # In unsharded cases (single-GPU, DDP, or any combination of TP/DP/CP/EP
+    # that left params as plain tensors), model.to(device) must still run so
+    # that persistent buffers not present in the checkpoint (e.g. Gemma4's
+    # Gemma4ClippableLinear input_min/max, Gemma4TextDecoderLayer
+    # layer_scalar) reach the GPU.
+    from torch.distributed.tensor import DTensor
+
+    has_sharded_params = any(isinstance(p, DTensor) for p in model.parameters())
+    if not (should_load_checkpoint and has_sharded_params):
         try:
             model.to(device, non_blocking=True)
         except NotImplementedError as e:
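
To make the narrowed precondition concrete, here is a small sketch (separate from the repo) showing that the DTensor check is False for any model whose parameters are plain tensors, so the unsharded post-shard load path still reaches model.to(device). The should_load_checkpoint variable is a placeholder for the real flag inside apply_model_infrastructure.

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor  # public import path in recent PyTorch 2.x

# Stand-in for an unsharded (single-GPU / DDP) model; no process group needed.
model = nn.Linear(8, 8)

# Same predicate as in the diff above: plain nn.Parameter tensors are not
# DTensors, so the post-shard load path still calls model.to(device).
should_load_checkpoint = True  # placeholder for the real flag
has_sharded_params = any(isinstance(p, DTensor) for p in model.parameters())
print(has_sharded_params)  # False -> params and buffers get moved below

if not (should_load_checkpoint and has_sharded_params):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device, non_blocking=True)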

tests/unit_tests/_transformers/test_infrastructure.py

Lines changed: 39 additions & 5 deletions
@@ -190,8 +190,8 @@ def test_megatron_fsdp_skips_post_shard_init(self):

         mock_ckpt.initialize_model_weights.assert_not_called()

-    def test_skips_model_to_device_when_checkpoint_loaded(self):
-        """model.to(device) should be skipped when should_load_checkpoint is True (tied params + FSDP fix)."""
+    def test_calls_model_to_device_when_checkpoint_loaded_without_dtensor(self):
+        """Unsharded post-shard checkpoint loads should still move buffers with model.to(device)."""
         from nemo_automodel._transformers.infrastructure import apply_model_infrastructure

         model = _DummyModel()
@@ -217,9 +217,43 @@ def test_skips_model_to_device_when_checkpoint_loaded(self):
                 pretrained_model_name_or_path="test/model",
             )

-        # model.to(device) should NOT have been called — checkpoint loading
-        # already placed params on device, and calling to() would trigger
-        # FSDP's reset_sharded_param failure on tied parameters.
+        mock_to.assert_called_once_with(torch.device("cpu"), non_blocking=True)
+
+    def test_skips_model_to_device_when_checkpoint_loaded_with_dtensor(self, monkeypatch):
+        """DTensor-sharded post-shard checkpoint loads should skip model.to(device)."""
+        from nemo_automodel._transformers.infrastructure import apply_model_infrastructure
+        import torch.distributed.tensor as dist_tensor
+
+        class FakeDTensor:
+            pass
+
+        class ModelWithShardedParameter(_DummyModel):
+            def parameters(self, recurse=True):
+                return iter([FakeDTensor()])
+
+        monkeypatch.setattr(dist_tensor, "DTensor", FakeDTensor)
+        model = ModelWithShardedParameter()
+
+        with (
+            patch(f"{_INFRA_MODULE}.get_world_size_safe", return_value=1),
+            patch(f"{_INFRA_MODULE}._supports_logits_to_keep", return_value=True),
+            patch(f"{_INFRA_MODULE}.print_trainable_parameters"),
+            patch(f"{_INFRA_MODULE}._should_load_before_shard", return_value=False),
+            patch(f"{_INFRA_MODULE}.Checkpointer") as MockCheckpointer,
+            patch.object(model, "to", wraps=model.to) as mock_to,
+        ):
+            mock_ckpt = MockCheckpointer.return_value
+            mock_ckpt.config = MagicMock()
+            mock_ckpt.config.dequantize_base_checkpoint = False
+
+            apply_model_infrastructure(
+                model=model,
+                is_meta_device=True,
+                device=torch.device("cpu"),
+                load_base_model=True,
+                pretrained_model_name_or_path="test/model",
+            )
+
         mock_to.assert_not_called()

     def test_calls_model_to_device_when_from_config_meta(self):
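
Two test-side patterns in the diff above are worth noting: patch.object(model, "to", wraps=model.to) spies on the call while preserving the real move, and monkeypatching torch.distributed.tensor.DTensor to a local FakeDTensor exercises the sharded branch without constructing real distributed tensors or a process group. A minimal standalone sketch of the wraps-based spy, independent of the repo's fixtures:

from unittest.mock import patch

import torch
import torch.nn as nn

model = nn.Linear(4, 4)

# wraps= keeps the real behavior (the module really moves) while recording the
# call, which is what the new unsharded-path test asserts on.
with patch.object(model, "to", wraps=model.to) as mock_to:
    model.to(torch.device("cpu"), non_blocking=True)

mock_to.assert_called_once_with(torch.device("cpu"), non_blocking=True)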
