
Commit 5872e02

fix: apply param freeze the right place for moe lora (#1252)
* fix: freeze weights for moe lora finetuning
* comments
* comments
* better
* lint
* fix test

---------

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
1 parent 7419674 commit 5872e02
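
In short, the bug: the base-parameter freeze ran at the moment LoRA adapters were attached, but MoE parallelization can create new parameters afterwards (e.g. `GroupedExpertsTE` in `init_token_dispatcher`), and those escaped the freeze. The sketch below illustrates that ordering hazard on a toy module; `freeze_non_lora` and the attribute names are hypothetical stand-ins for illustration, not code from this repository.

```python
import torch
import torch.nn as nn

def freeze_non_lora(model: nn.Module) -> None:
    # Same rule the commit applies: anything without "lora_" in its name is frozen.
    for name, param in model.named_parameters():
        if "lora_" not in name and param.requires_grad:
            param.requires_grad_(False)

model = nn.Linear(4, 4)                        # stands in for the base model
model.lora_a = nn.Parameter(torch.zeros(2))    # stands in for an attached adapter

freeze_non_lora(model)                         # freeze too early...
model.expert_w = nn.Parameter(torch.zeros(8))  # ...parallelization then adds a param

assert model.expert_w.requires_grad            # True: the late param escaped the freeze
freeze_non_lora(model)                         # the fix: freeze again after parallelization
assert not model.expert_w.requires_grad and model.lora_a.requires_grad
```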

3 files changed: 23 additions & 4 deletions


nemo_automodel/_transformers/infrastructure.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -80,7 +80,8 @@ def _apply_peft_and_lower_precision(
             logger.info("Enabling PEFT with Pipeline Parallelism")
             logger.info("Disabling Triton with Pipeline Parallelism Enabled.")
             peft_config.use_triton = False
-        apply_lora_to_linear_modules(model, peft_config, quantization_config=quantization_config)
+        # Skip freeze here - will do global freeze after checkpoint loading
+        apply_lora_to_linear_modules(model, peft_config, quantization_config=quantization_config, skip_freeze=True)
 
     # FP8
     if fp8_config is not None:
@@ -446,6 +447,15 @@ def apply_model_infrastructure(
         load_base_model=load_base_model,
     )
 
+    # Freeze parameters after checkpoint loading and parallelization
+    # This catches params created during parallelization (e.g., GroupedExpertsTE in init_token_dispatcher)
+    if peft_config is not None:
+        models_to_freeze = model.parts if hasattr(model, "parts") else [model]
+        for mp in models_to_freeze:
+            for name, param in mp.named_parameters():
+                if "lora_" not in name and param.requires_grad:
+                    param.requires_grad_(False)
+
     if autopipeline is None:
         print_trainable_parameters(model)  # Once model's been sharded
     # Ensure model is on the correct device; AutoPipeline takes care of it internally
```
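
For context, the new freeze loop has to handle both a plain module and a pipeline-split model that exposes its stages via `.parts`. Below is a minimal sketch of that duck-typing pattern; `PipelineParts` is a hypothetical stand-in for the real AutoPipeline object, which this diff does not show.

```python
import torch.nn as nn

class PipelineParts:
    # Hypothetical stand-in for a pipeline-split model exposing `.parts`.
    def __init__(self, *stages: nn.Module):
        self.parts = list(stages)

def freeze_all_but_lora(model) -> None:
    # Mirrors the commit's loop: iterate the stages if present, else the model itself.
    models_to_freeze = model.parts if hasattr(model, "parts") else [model]
    for mp in models_to_freeze:
        for name, param in mp.named_parameters():
            if "lora_" not in name and param.requires_grad:
                param.requires_grad_(False)

split = PipelineParts(nn.Linear(4, 4), nn.Linear(4, 4))
freeze_all_but_lora(split)
assert all(not p.requires_grad for s in split.parts for p in s.parameters())
```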

nemo_automodel/components/_peft/lora.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -464,6 +464,7 @@ def apply_lora_to_linear_modules(
     model: nn.Module,
     peft_config: PeftConfig,
     quantization_config=None,
+    skip_freeze: bool = False,
 ) -> int:
     """
     Replace selected nn.Linear layers with LinearLoRA layers (in-place).
@@ -472,6 +473,7 @@ def apply_lora_to_linear_modules(
         model: The model to apply LoRA to.
         peft_config: PEFT configuration for LoRA parameters.
         quantization_config: Optional separate QLoRA quantization configuration.
+        skip_freeze: If True, skip the global parameter freeze (caller will handle it later).
 
     Returns:
         Number of modules that were modified with LoRA.
@@ -480,8 +482,9 @@ def apply_lora_to_linear_modules(
         target_modules accepts wildcard fragments, e.g. ["q_proj", "k_proj", ".*fc.*"].
     """
     # Freeze base model parameters
-    for w in model.parameters():
-        w.requires_grad_(False)
+    if not skip_freeze:
+        for w in model.parameters():
+            w.requires_grad_(False)
 
     is_causal_lm = False
     try:
```
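
To make the `skip_freeze` semantics concrete, here is a self-contained toy analogue of the freeze-then-wrap flow. `attach_adapters` is a hypothetical, heavily simplified stand-in for `apply_lora_to_linear_modules` (which actually replaces layers with `LinearLoRA` modules); only the flag's behavior is meant to match this diff.

```python
import torch
import torch.nn as nn

def attach_adapters(model: nn.Module, skip_freeze: bool = False) -> int:
    """Toy analogue: optionally freeze base weights, then add one adapter per Linear."""
    if not skip_freeze:
        for w in model.parameters():
            w.requires_grad_(False)
    modified = 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # Attach a trainable adapter parameter; real LoRA wraps the whole layer.
            module.lora_a = nn.Parameter(torch.zeros(module.in_features))
            modified += 1
    return modified

m = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
attach_adapters(m, skip_freeze=True)  # base weights stay trainable for now
assert m[0].weight.requires_grad      # caller is expected to freeze them later
```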

tests/unit_tests/recipes/test_train_ft.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -230,7 +230,13 @@ def test_peft_without_pipeline_parallelism(caplog):
     with patch('nemo_automodel._transformers.infrastructure._supports_logits_to_keep', return_value=True):
         with patch('nemo_automodel._transformers.auto_model._verify_sdpa_support'):
             with patch('nemo_automodel._transformers.infrastructure._shard_ep_fsdp') as mock_shard:
-                mock_shard.return_value = DummyModel()
+                # Return a DummyModel with lora_dummy_param so freeze doesn't remove all trainable params
+                sharded_model = DummyModel()
+                sharded_model.register_parameter(
+                    "lora_dummy_param",
+                    nn.Parameter(torch.tensor(1.0, device=torch.device("cuda")), requires_grad=True)
+                )
+                mock_shard.return_value = sharded_model
                 with caplog.at_level(logging.INFO):
                     # This should work fine without PP
                     model = build_model(
```
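
The same fixture pattern on CPU, for readers without a GPU. This is a hedged sketch: the `DummyModel` body is an assumption (the test's real class is not shown in this diff), and the real test presumably pins the parameter to CUDA because the mocked sharding path runs on-device.

```python
import torch
import torch.nn as nn

class DummyModel(nn.Module):
    # Minimal stand-in for the test's DummyModel (assumed shape).
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(2))

sharded_model = DummyModel()
sharded_model.register_parameter(
    "lora_dummy_param",
    nn.Parameter(torch.tensor(1.0), requires_grad=True),  # CPU variant of the fixture
)

# After the recipe's global freeze, only the lora_-named parameter should train.
for name, param in sharded_model.named_parameters():
    if "lora_" not in name:
        param.requires_grad_(False)
assert [n for n, p in sharded_model.named_parameters() if p.requires_grad] == ["lora_dummy_param"]
```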
