Skip to content
19 changes: 16 additions & 3 deletions nemo_automodel/_transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Optional

from transformers import AutoConfig

logger = logging.getLogger(__name__)


def _should_load_before_shard(
*,
Expand All @@ -28,16 +31,26 @@ def _should_load_before_shard(
) -> bool:
"""Decide whether to load the checkpoint before FSDP/TP/EP sharding.

Load-before-shard is only safe when running single-GPU (no PP, TP, or EP),
a checkpoint actually needs loading, and no PEFT adapter is involved.
Load-before-shard is only safe when running single-GPU (no PP, TP, or EP)
and a checkpoint actually needs loading.
With any model parallelism the post-shard load path must be used to avoid
NCCL collective mismatches or key/device inconsistencies.

PEFT models skip this path and use the post-shard load so that base and
adapter weights load in the same way as multi-GPU.
"""
no_pp = autopipeline is None
no_tp = tp_size <= 1
no_ep = ep_size <= 1
no_peft = peft_config is None
need_checkpoint_load = bool(pretrained_model_name_or_path and load_base_model)
return no_pp and no_tp and no_ep and need_checkpoint_load and (peft_config is None)
result = no_pp and no_tp and no_ep and no_peft and need_checkpoint_load
logger.debug(
"[_should_load_before_shard] no_pp={} no_tp={} no_ep={} need_load={} peft={} -> {}".format(
no_pp, no_tp, no_ep, need_checkpoint_load, peft_config is not None, result
)
)
return result


def sliding_window_overwrite(model_name: str) -> dict[str, Any]:
Expand Down
42 changes: 27 additions & 15 deletions nemo_automodel/components/checkpoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def load_model(
# For models that need tensor merging and don't have an adapter, try using transformers' conversion
if is_init_step and model_type and requires_tensor_merging(model_type) and not has_state_dict_adapter:
converted_state_dict = _convert_checkpoint_with_transformers(model_state.model[0], model_path, key_mapping)
if converted_state_dict is not None:
if converted_state_dict:
# Load using full_state_dict=True to properly convert tensors to DTensors for FSDP
_load_full_state_dict_into_model(model_state.model, converted_state_dict)
return
Expand Down Expand Up @@ -1258,41 +1258,53 @@ def _convert_checkpoint_with_transformers(
# Sort by key for consistent ordering
sorted_items = sorted(checkpoint_state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))

n_converter_keys = 0
n_rename_keys = 0
for original_key, tensor in sorted_items:
# Rename the key
renamed_key, source_pattern = rename_source_key(original_key, renamings, converters)

# Check if this needs conversion
if source_pattern is not None:
n_converter_keys += 1
# This key is part of a WeightConverter operation
new_converter = deepcopy(pattern_to_converter[source_pattern])
mapping = param_name_to_mapping.setdefault(renamed_key, new_converter)
mapping.add_tensor(renamed_key, original_key, source_pattern, tensor)
else:
n_rename_keys += 1
# Simple rename or pass-through
mapping = param_name_to_mapping.setdefault(renamed_key, WeightRenaming(original_key, renamed_key))
mapping.add_tensor(renamed_key, original_key, original_key, tensor)

logging.debug(
"[convert_ckpt] {} keys matched converters, {} keys simple rename, {} total mappings".format(
n_converter_keys, n_rename_keys, len(param_name_to_mapping)
)
)

# Now apply all the conversions
for first_param_name, mapping in param_name_to_mapping.items():
try:
realized_value = mapping.convert(first_param_name, model=model, config=model.config)
for target_name, param in realized_value.items():
param = param[0] if isinstance(param, list) else param
converted_state_dict[target_name] = param
# convert() returns dict or (dict, errors) depending on transformers version
result = mapping.convert(first_param_name, model=model, config=model.config)
if isinstance(result, tuple):
realized_value = result[0]
elif isinstance(result, dict):
realized_value = result
else:
raise TypeError(
"Expected convert() to return dict or (dict, errors) tuple, got {}".format(type(result))
)
for target_name, param in realized_value.items():
param = param[0] if isinstance(param, list) else param
converted_state_dict[target_name] = param
if callable(getattr(mapping, "reset", None)):
mapping.reset()
except Exception as e:
logging.warning(f"Conversion failed for {first_param_name}: {e}")
continue

logging.info(f"Converted {len(converted_state_dict)} keys using transformers conversion mapping")
logging.debug("Converted {} keys using transformers conversion mapping".format(len(converted_state_dict)))
return converted_state_dict

except Exception as e:
logging.warning(f"Failed to convert checkpoint with transformers: {e}")
import traceback

traceback.print_exc()
logging.warning("Failed to convert checkpoint with transformers: {}".format(e))
return None


Expand Down
11 changes: 9 additions & 2 deletions nemo_automodel/components/checkpoint/stateful_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ def state_dict(self) -> dict[str, Any]:
func = partial(get_model_state_dict, options=options)
model_state_dict = {k: v for sd in map(func, self.model) for k, v in sd.items()}

# @akoumpa: the second is_peft statement above keeps buffers in the state dict;
# this filtering removes them by keeping only keys that contain "lora_".
# TODO: this is a hack; we should find a better way to do this.
if self.is_peft:
model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k}

if self.is_tied_lm_head:
# PP models don't have tied embeddings. Safe to pass in model[0] here.
model_state_dict.pop(self.lm_head_param_name, None)
Expand Down Expand Up @@ -300,6 +306,7 @@ def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True) -> No
_drop_outer_prefix(state_dict, "base_model.model.")
# DoRA: reverse the HF PEFT key rename so DCP can match model params
_rename_dora_keys_from_hf(state_dict)
# @akoumpa: I'm not sure about this code (presumably the EP-specific DCP bypass described below) -- TODO: verify.
# For EP models, DCP's set_model_state_dict silently skips EP-sharded
# LoRA params (strict=False hides the FQN mismatch caused by custom
# expert state_dict() keys like gate_up_linear.weight0). Bypass DCP.
Expand All @@ -320,8 +327,8 @@ def load_state_dict(self, state_dict: dict[str, Any], strict: bool = True) -> No
# weight tying guarantees this is identical to the embedding weight
state_dict[lm_head_param_name] = lm_head_weight.detach()

func = partial(set_model_state_dict, model_state_dict=state_dict, options=options)
list(map(func, self.model))
for model_part in self.model:
set_model_state_dict(model_part, state_dict, options=options)

def _get_base_model_state_dict(self) -> dict[str, Any]:
model_state_dict = {k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()}
Expand Down
25 changes: 22 additions & 3 deletions tests/functional_tests/checkpoint/test_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,14 @@ def test_hf_peft_checkpoint(force_hf, use_triton):
cfg.model.force_hf = force_hf

try:
# Clean up any leftover checkpoints from previous runs to avoid
# auto-detection loading stale state into the fresh model.
ckpt_dir = Path(cfg.get("checkpoint.checkpoint_dir", "checkpoints"))
if ckpt_dir.exists() and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0):
shutil.rmtree(ckpt_dir)
if torch.distributed.is_initialized():
torch.distributed.barrier()

# set use_triton value based on parsed input
expected_automodel_peft_config["use_triton"] = cfg.peft.use_triton

Expand All @@ -299,6 +307,7 @@ def test_hf_peft_checkpoint(force_hf, use_triton):
trainer.model_parts,
trainer.optimizer,
trainer.lr_scheduler,
is_peft=trainer.checkpointer.config.is_peft,
).state_dict()["optim"]
)

Expand Down Expand Up @@ -384,13 +393,23 @@ def test_hf_peft_checkpoint(force_hf, use_triton):
restored_model = restored_model.model_parts[0]
source_model_loss = get_validation_loss(trainer.model_parts[0], val_batch, trainer.loss_fn, trainer.dist_env.device)
restored_model_loss = get_validation_loss(restored_model, val_batch, trainer.loss_fn, trainer.dist_env.device)
errors = []
for (source_name, source_p), (restore_name, restore_p) in zip(trainer.model_parts[0].named_parameters(), restored_model.named_parameters()):
assert source_name == restore_name, "Parameter name mismatch"
if isinstance(source_p, torch.distributed.tensor.DTensor):
source_p = source_p.to_local()
source_p = source_p.full_tensor()
if isinstance(restore_p, torch.distributed.tensor.DTensor):
restore_p = restore_p.to_local()
assert torch.allclose(source_p, restore_p), "Parameter value mismatch for " + source_name
restore_p = restore_p.full_tensor()
if not torch.allclose(source_p, restore_p):
errors.append(("Parameter value mismatch for " + source_name, source_p, restore_p))
if errors:
print("Parameter value mismatches:")
for error in errors:
print(error[0])
print(error[1])
print(error[2])
print("-"*80)
raise Exception("Parameter value mismatches")
assert torch.allclose(source_model_loss, restored_model_loss), "Model loss mismatch"

# compare the recipe configs
Expand Down
Loading