diff --git a/nemo_automodel/_transformers/infrastructure.py b/nemo_automodel/_transformers/infrastructure.py index dbb2419a4..da747c973 100644 --- a/nemo_automodel/_transformers/infrastructure.py +++ b/nemo_automodel/_transformers/infrastructure.py @@ -394,14 +394,11 @@ def apply_model_infrastructure( ) # Handle checkpointer config updates if checkpointer is provided - dequantize_base_checkpoint = False if checkpointer is not None: if checkpointer.config.dequantize_base_checkpoint is None: - # try to infer whether the base weights are quantized checkpointer.config.dequantize_base_checkpoint = hasattr( getattr(model, "config", None), "quantization_config" ) - dequantize_base_checkpoint = checkpointer.config.dequantize_base_checkpoint # Apply PEFT and lower precision if configured # When on meta device, wrap in init_empty_weights() so new LoRA modules are also on meta device @@ -444,7 +441,7 @@ def apply_model_infrastructure( # hold a list copy of the model state dict keys before any parallelization. To be used during checkpoint saving in safetensors format. pre_shard_hf_state_dict_keys = list( - _maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=dequantize_base_checkpoint).keys() + _maybe_adapt_state_dict_to_hf(model, model.state_dict(), quantization=False).keys() ) # Apply freezing before sharding diff --git a/nemo_automodel/components/checkpoint/addons.py b/nemo_automodel/components/checkpoint/addons.py index ed8b94c38..930f6f3f0 100644 --- a/nemo_automodel/components/checkpoint/addons.py +++ b/nemo_automodel/components/checkpoint/addons.py @@ -68,6 +68,7 @@ def pre_save(self, **kwargs) -> None: _maybe_save_custom_model_code(original_model_path, hf_metadata_dir) # save the config.json file if hasattr(model_part, "config"): + _maybe_strip_quantization_config(model_part) with open(os.path.join(hf_metadata_dir, "config.json"), "w") as f: f.write(model_part.config.to_json_string()) # save the generation_config.json file @@ -329,6 +330,27 @@ def _extract_target_modules(model: nn.Module) -> list[str]: return sorted(list(final_target_modules)) +def _maybe_strip_quantization_config(model_part: nn.Module) -> None: + """Remove ``quantization_config`` from the HF config when no parameters are quantized. + + Models loaded from quantized checkpoints (e.g. mxfp4 GPT-OSS) carry a + ``quantization_config`` on their ``config`` object. After dequantization + all parameters are standard floating-point, but the stale config entry would + still be written to the saved ``config.json``. This strips it so the output + checkpoint is a clean bf16 checkpoint, consistent with e.g. + ``unsloth/gpt-oss-20b-BF16``. + """ + config = getattr(model_part, "config", None) + if config is None or not hasattr(config, "quantization_config"): + return + + _QUANTIZED_DTYPES = frozenset({torch.uint8, torch.int8}) + if any(p.dtype in _QUANTIZED_DTYPES for p in model_part.parameters()): + return + + delattr(config, "quantization_config") + + def _maybe_save_custom_model_code(original_model_path: str | None, hf_metadata_dir: str) -> None: """ Save the custom model code if it exists. This function preserves the original directory structure. diff --git a/tests/functional_tests/checkpoint/create_gptoss_2l_mxfp4.py b/tests/functional_tests/checkpoint/create_gptoss_2l_mxfp4.py new file mode 100644 index 000000000..d2989f011 --- /dev/null +++ b/tests/functional_tests/checkpoint/create_gptoss_2l_mxfp4.py @@ -0,0 +1,233 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Create a minimal 2-layer GPT-OSS checkpoint with mxfp4-quantized expert weights. + +The mxfp4 block/scale geometry uses the same hardcoded (G=90, B=16) that the +GPTOSSStateDictAdapter.convert_single_tensor_to_hf produces, so that the DCP +planner sees matching shapes when loading. + +Dequantization produces: + blocks (E, dim, G, B) → (E, G*B*2, dim) [after reshape + transpose] + +where G*B*2 = 90*16*2 = 2880. + +For both projections the adapter reads dim = tensor.shape[-1]: + gate_and_up_projs (E, hidden, 2*inter) → dim = 2*inter → 2*inter = 2880 → inter = 1440? + down_projs (E, inter, hidden) → dim = hidden → hidden = 2880 + +BUT the *dequanted* tensor must match the internal shape: + gate_up_proj → (E, 2880, 2*inter) needs 2880 = hidden ✓ + down_proj → (E, 2880, hidden) needs 2880 = inter + +So both hidden_size AND intermediate_size must equal 2880. + +Run standalone: + python tests/functional_tests/checkpoint/create_gptoss_2l_mxfp4.py \\ + --output-dir /tmp/gptoss_2l_mxfp4 \\ + --tokenizer-dir $TEST_DATA_DIR/hf_mixtral_2l/ +""" + +import argparse +import json +import os +import shutil + +import torch +from safetensors.torch import save_file + +_G, _B = 90, 16 # hardcoded in GPTOSSStateDictAdapter.convert_single_tensor_to_hf + +_VOCAB_SIZE = 32000 +_HIDDEN = 2880 # forced: dequant produces (E, 2880, dim), must equal hidden for down_proj +_INTER = 2880 # forced: dequant produces (E, 2880, dim), must equal inter for down_proj +_HEADS = 90 # 90 * 32 = 2880 +_KV_HEADS = 2 +_HEAD_DIM = 32 +_EXPERTS = 2 +_LAYERS = 2 + + +def _build_config() -> dict: + return { + "architectures": ["GptOssForCausalLM"], + "model_type": "gpt_oss", + "vocab_size": _VOCAB_SIZE, + "hidden_size": _HIDDEN, + "num_attention_heads": _HEADS, + "num_key_value_heads": _KV_HEADS, + "head_dim": _HEAD_DIM, + "num_hidden_layers": _LAYERS, + "intermediate_size": _INTER, + "max_position_embeddings": 512, + "rms_norm_eps": 1e-6, + "sliding_window": 256, + "layer_types": ["full_attention", "sliding_attention"], + "num_local_experts": _EXPERTS, + "num_experts_per_tok": 2, + "router_aux_loss_coef": 0.01, + "rope_scaling": { + "rope_type": "yarn", + "factor": 32.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "truncate": False, + "original_max_position_embeddings": 512, + }, + "torch_dtype": "bfloat16", + "quantization_config": { + "quant_method": "mxfp4", + "modules_to_not_convert": [ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head", + ], + }, + "tie_word_embeddings": False, + } + + +def _build_tensors() -> dict[str, torch.Tensor]: + """Return a state-dict with mxfp4 expert weights and bf16 dense weights.""" + t: dict[str, torch.Tensor] = {} + bf = torch.bfloat16 + + t["model.embed_tokens.weight"] = torch.randn(_VOCAB_SIZE, _HIDDEN, dtype=bf) + t["lm_head.weight"] = torch.randn(_VOCAB_SIZE, _HIDDEN, dtype=bf) + t["model.norm.weight"] = torch.ones(_HIDDEN, dtype=bf) + + up_proj_dim = 2 * _INTER # gated activation → gate + up concatenated + kv_dim = _KV_HEADS * _HEAD_DIM + + for li in range(_LAYERS): + p = f"model.layers.{li}" + + # ── Attention ── + t[f"{p}.self_attn.q_proj.weight"] = torch.randn(_HEADS * _HEAD_DIM, _HIDDEN, dtype=bf) + t[f"{p}.self_attn.q_proj.bias"] = torch.zeros(_HEADS * _HEAD_DIM, dtype=bf) + t[f"{p}.self_attn.k_proj.weight"] = torch.randn(kv_dim, _HIDDEN, dtype=bf) + t[f"{p}.self_attn.k_proj.bias"] = torch.zeros(kv_dim, dtype=bf) + t[f"{p}.self_attn.v_proj.weight"] = torch.randn(kv_dim, _HIDDEN, dtype=bf) + t[f"{p}.self_attn.v_proj.bias"] = torch.zeros(kv_dim, dtype=bf) + t[f"{p}.self_attn.o_proj.weight"] = torch.randn(_HIDDEN, _HEADS * _HEAD_DIM, dtype=bf) + t[f"{p}.self_attn.o_proj.bias"] = torch.zeros(_HIDDEN, dtype=bf) + t[f"{p}.self_attn.sinks"] = torch.zeros(_HEADS, dtype=torch.float32) + + # ── Router ── + t[f"{p}.mlp.router.weight"] = torch.randn(_EXPERTS, _HIDDEN, dtype=bf) + t[f"{p}.mlp.router.bias"] = torch.zeros(_EXPERTS, dtype=bf) + + # ── Layer norms ── + t[f"{p}.input_layernorm.weight"] = torch.ones(_HIDDEN, dtype=bf) + t[f"{p}.post_attention_layernorm.weight"] = torch.ones(_HIDDEN, dtype=bf) + + # ── Expert biases (bf16, not quantized) ── + t[f"{p}.mlp.experts.gate_up_proj_bias"] = torch.zeros(_EXPERTS, up_proj_dim, dtype=bf) + t[f"{p}.mlp.experts.down_proj_bias"] = torch.zeros(_EXPERTS, _HIDDEN, dtype=bf) + + # ── MXFP4 blocks / scales ── + # gate_and_up_projs internal shape: (E, hidden, 2*inter) + # adapter: dim = tensor.shape[-1] = 2*inter = up_proj_dim + # blocks = (E, up_proj_dim, G, B) + t[f"{p}.mlp.experts.gate_up_proj_blocks"] = torch.randint( + 0, 256, (_EXPERTS, up_proj_dim, _G, _B), dtype=torch.uint8 + ) + t[f"{p}.mlp.experts.gate_up_proj_scales"] = torch.full((_EXPERTS, up_proj_dim, _G), 127, dtype=torch.uint8) + + # down_projs internal shape: (E, inter, hidden) + # adapter: dim = tensor.shape[-1] = hidden + # blocks = (E, hidden, G, B) + t[f"{p}.mlp.experts.down_proj_blocks"] = torch.randint(0, 256, (_EXPERTS, _HIDDEN, _G, _B), dtype=torch.uint8) + t[f"{p}.mlp.experts.down_proj_scales"] = torch.full((_EXPERTS, _HIDDEN, _G), 127, dtype=torch.uint8) + + return t + + +def _build_index(tensors: dict[str, torch.Tensor], filename: str) -> dict: + total_bytes = 0 + weight_map: dict[str, str] = {} + for fqn, tensor in tensors.items(): + total_bytes += tensor.numel() * tensor.element_size() + weight_map[fqn] = filename + return {"metadata": {"total_size": total_bytes}, "weight_map": weight_map} + + +def _verify_mxfp4(output_dir: str, config: dict, expected_mxfp4_keys: list[str]) -> None: + """Re-open the saved checkpoint and verify mxfp4 tensors are present and correct.""" + from safetensors import safe_open + + st_path = os.path.join(output_dir, "model.safetensors") + with safe_open(st_path, framework="pt", device="cpu") as f: + saved_keys = set(f.keys()) + + for k in expected_mxfp4_keys: + assert k in saved_keys, f"mxfp4 key missing from saved checkpoint: {k}" + + with safe_open(st_path, framework="pt", device="cpu") as f: + for k in expected_mxfp4_keys: + t = f.get_tensor(k) + assert t.dtype == torch.uint8, f"{k}: expected uint8 but got {t.dtype}" + if "_blocks" in k: + assert t.shape[-2:] == (_G, _B), f"{k}: expected last dims ({_G}, {_B}) but got {t.shape[-2:]}" + elif "_scales" in k: + assert t.shape[-1] == _G, f"{k}: expected last dim {_G} but got {t.shape[-1]}" + + with open(os.path.join(output_dir, "config.json")) as f: + saved_config = json.load(f) + qcfg = saved_config.get("quantization_config", {}) + assert qcfg.get("quant_method") == "mxfp4", ( + f"config.json quant_method should be 'mxfp4', got {qcfg.get('quant_method')!r}" + ) + + print(f" ✓ verified {len(expected_mxfp4_keys)} mxfp4 keys (uint8, correct shapes, config.json)") + + +def create_checkpoint(output_dir: str, tokenizer_dir: str) -> None: + os.makedirs(output_dir, exist_ok=True) + + config = _build_config() + tensors = _build_tensors() + safetensors_name = "model.safetensors" + + with open(os.path.join(output_dir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + save_file(tensors, os.path.join(output_dir, safetensors_name)) + + index = _build_index(tensors, safetensors_name) + with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f: + json.dump(index, f, indent=2) + + for fname in os.listdir(tokenizer_dir): + src = os.path.join(tokenizer_dir, fname) + if os.path.isfile(src) and ("token" in fname.lower() or fname == "special_tokens_map.json"): + shutil.copy2(src, os.path.join(output_dir, fname)) + + total_mb = sum(t.numel() * t.element_size() for t in tensors.values()) / (1 << 20) + mxfp4_keys = [k for k in tensors if "_blocks" in k or "_scales" in k] + print(f"Created GPT-OSS 2L mxfp4 checkpoint in {output_dir}") + print(f" tensor keys: {len(tensors)} entries") + print(f" mxfp4 keys: {len(mxfp4_keys)}") + print(f" total size: {total_mb:.1f} MB") + + _verify_mxfp4(output_dir, config, mxfp4_keys) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", required=True) + parser.add_argument("--tokenizer-dir", required=True) + args = parser.parse_args() + create_checkpoint(args.output_dir, args.tokenizer_dir) diff --git a/tests/functional_tests/checkpoint/test_hf_consolidated_gptoss_mxfp4.py b/tests/functional_tests/checkpoint/test_hf_consolidated_gptoss_mxfp4.py new file mode 100644 index 000000000..1fb3483e2 --- /dev/null +++ b/tests/functional_tests/checkpoint/test_hf_consolidated_gptoss_mxfp4.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functional test: GPT-OSS mxfp4 base checkpoint → bf16 fine-tune → consolidated safetensors. + +Validates that stale mxfp4 FQNs (_blocks/_scales) in the base checkpoint +index do not produce invalid safetensors during consolidation. +""" + +import shutil +from pathlib import Path + +import torch +import torch.distributed +from safetensors import safe_open + +from nemo_automodel.components.config._arg_parser import parse_args_and_load_config +from nemo_automodel.recipes.llm.train_ft import TrainFinetuneRecipeForNextTokenPrediction + + +def test_consolidated_gptoss_mxfp4_checkpoint(): + """Load mxfp4 GPT-OSS → train → save bf16 consolidated → reload with HF.""" + + script_path = Path(__file__).parent.resolve() + cfg = parse_args_and_load_config(script_path / "llama3_2" / "llama3_2_1b_hellaswag.yaml") + + trainer = TrainFinetuneRecipeForNextTokenPrediction(cfg) + trainer.setup() + trainer.run_train_validation_loop() + + ckpt_root = Path(trainer.checkpointer.config.checkpoint_dir) / "epoch_0_step_9" + consolidated_dir = ckpt_root / "model" / "consolidated" + + # --- 1. Verify checkpoint directory structure --- + expected_paths = [ + "model", + "optim", + "step_scheduler.pt", + "config.yaml", + "losses.json", + "model/consolidated", + "model/consolidated/config.json", + "model/consolidated/model.safetensors.index.json", + ] + for rel in expected_paths: + p = ckpt_root / rel + assert p.exists(), f"Expected {p} to exist" + + # At least one consolidated safetensors file must exist + consolidated_st = list(consolidated_dir.glob("model-*.safetensors")) + assert len(consolidated_st) >= 1, "No consolidated safetensors files found" + + # --- 2. Verify consolidated safetensors is loadable by safe_open --- + loaded_keys: set[str] = set() + for st_file in consolidated_st: + with safe_open(str(st_file), framework="pt", device="cpu") as f: + loaded_keys.update(f.keys()) + + assert len(loaded_keys) > 0, "Consolidated safetensors has no keys" + # No mxfp4 phantom keys should be present + for key in loaded_keys: + assert "_blocks" not in key, f"Phantom mxfp4 key leaked: {key}" + assert "_scales" not in key, f"Phantom mxfp4 key leaked: {key}" + + # --- 3. Verify all consolidated tensors are well-formed --- + for st_file in consolidated_st: + with safe_open(str(st_file), framework="pt", device="cpu") as f: + for key in f.keys(): + tensor = f.get_tensor(key) + assert tensor.shape.numel() > 0, f"Empty tensor: {key}" + assert not torch.isnan(tensor).any(), f"NaN in tensor: {key}" + assert tensor.dtype in (torch.bfloat16, torch.float32), f"Unexpected dtype {tensor.dtype} for {key}" + + # --- cleanup --- + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + ckpt_base = Path(trainer.checkpointer.config.checkpoint_dir) + if ckpt_base.exists(): + shutil.rmtree(ckpt_base) + torch.distributed.barrier() diff --git a/tests/functional_tests/hf_dcp/L2_HF_Consolidated_GPTOSS_MXFP4_Checkpoint.sh b/tests/functional_tests/hf_dcp/L2_HF_Consolidated_GPTOSS_MXFP4_Checkpoint.sh new file mode 100644 index 000000000..8f78e4baa --- /dev/null +++ b/tests/functional_tests/hf_dcp/L2_HF_Consolidated_GPTOSS_MXFP4_Checkpoint.sh @@ -0,0 +1,57 @@ +# Copyright (c) 2020-2025, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/bin/bash +set -xeuo pipefail + +export PYTHONPATH=${PYTHONPATH:-}:$(pwd) +export CUDA_VISIBLE_DEVICES="0,1" + +GPTOSS_MODEL_DIR=/tmp/gptoss_2l_mxfp4_$$ +trap 'rm -rf "$GPTOSS_MODEL_DIR"' EXIT + +python tests/functional_tests/checkpoint/create_gptoss_2l_mxfp4.py \ + --output-dir "$GPTOSS_MODEL_DIR" \ + --tokenizer-dir "$TEST_DATA_DIR/hf_mixtral_2l/" + +TRANSFORMERS_OFFLINE=1 TORCH_COMPILE_DISABLE=1 \ +python -m torch.distributed.run --nproc_per_node=2 --nnodes=1 -m coverage run --data-file=/workspace/.coverage --source=/workspace/ --parallel-mode \ +-m pytest tests/functional_tests/checkpoint/test_hf_consolidated_gptoss_mxfp4.py \ + --config examples/llm_finetune/llama3_2/llama3_2_1b_squad.yaml \ + --model.pretrained_model_name_or_path "$GPTOSS_MODEL_DIR" \ + --step_scheduler.max_steps 10 \ + --step_scheduler.global_batch_size 16 \ + --step_scheduler.local_batch_size 8 \ + --dataset.tokenizer.pretrained_model_name_or_path "$GPTOSS_MODEL_DIR" \ + --validation_dataset.tokenizer.pretrained_model_name_or_path "$GPTOSS_MODEL_DIR" \ + --dataset.dataset_name $HF_CACHE/squad/ \ + --validation_dataset.dataset_name $HF_CACHE/squad/ \ + --validation_dataset.padding true \ + --dataset.limit_dataset_samples 1000 \ + --dataset.padding true \ + --dataloader.collate_fn.pad_seq_len_divisible 512 \ + --validation_dataloader.collate_fn.pad_seq_len_divisible 512 \ + --dataset.seq_length 512 \ + --validation_dataset.seq_length 512 \ + --step_scheduler.ckpt_every_steps 10 \ + --checkpoint.enabled true \ + --checkpoint.checkpoint_dir checkpoints/ \ + --checkpoint.model_save_format safetensors \ + --checkpoint.save_consolidated true \ + --distributed.dp_size 2 \ + --distributed.ep_size 2 \ + --distributed.tp_size 1 \ + --distributed.cp_size 1 \ + --distributed.pp_size 1 \ + --distributed.sequence_parallel false diff --git a/tests/functional_tests/hf_dcp/test_hf_dcp.py b/tests/functional_tests/hf_dcp/test_hf_dcp.py index 9b78fb2bc..baf832a21 100644 --- a/tests/functional_tests/hf_dcp/test_hf_dcp.py +++ b/tests/functional_tests/hf_dcp/test_hf_dcp.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.utils.test_utils import run_test_script import shutil + import pytest +from tests.utils.test_utils import run_test_script + try: import qwen_vl_utils # noqa: F401 + _has_qwen_vl_utils = True except ImportError: _has_qwen_vl_utils = False @@ -38,6 +41,7 @@ HF_CONSOLIDATED_FSDP2_VLM_FILENAME = "L2_HF_Consolidated_FSDP2_VLM_Checkpoint.sh" HF_CONSOLIDATED_FSDP2_LLM_SCALAR_WEIGHT_FILENAME = "L2_HF_Consolidated_FSDP2_LLM_Checkpoint_Scalar_Param.sh" HF_CONSOLIDATED_PP2_LLM_FILENAME = "L2_HF_Consolidated_PP2_LLM_Checkpoint.sh" +HF_CONSOLIDATED_GPTOSS_MXFP4_FILENAME = "L2_HF_Consolidated_GPTOSS_MXFP4_Checkpoint.sh" FLASHOPTIM_DCP_ROUNDTRIP_FILENAME = "L2_FlashOptim_DCP_Roundtrip.sh" @@ -108,5 +112,14 @@ def test_hf_consolidated_pp2_llm_checkpoint(self): # remove the checkpoint directory shutil.rmtree("checkpoints/", ignore_errors=True) + def test_hf_consolidated_gptoss_mxfp4_checkpoint(self): + try: + run_test_script(TEST_FOLDER, HF_CONSOLIDATED_GPTOSS_MXFP4_FILENAME) + finally: + shutil.rmtree("checkpoints/", ignore_errors=True) + def test_flashoptim_dcp_roundtrip(self): - run_test_script(TEST_FOLDER, FLASHOPTIM_DCP_ROUNDTRIP_FILENAME) + try: + run_test_script(TEST_FOLDER, FLASHOPTIM_DCP_ROUNDTRIP_FILENAME) + finally: + shutil.rmtree("checkpoints/", ignore_errors=True) diff --git a/tests/unit_tests/checkpoint/test_addons.py b/tests/unit_tests/checkpoint/test_addons.py index ccd2e341e..d6c68f233 100644 --- a/tests/unit_tests/checkpoint/test_addons.py +++ b/tests/unit_tests/checkpoint/test_addons.py @@ -20,6 +20,7 @@ from nemo_automodel.components.checkpoint.addons import ( _extract_target_modules, _maybe_save_custom_model_code, + _maybe_strip_quantization_config, ) from nemo_automodel.components.checkpoint.stateful_wrappers import ModelState @@ -232,3 +233,41 @@ def test_biencoder_target_modules_remapped(self): assert all("lm_q" not in m for m in result) +class TestMaybeStripQuantizationConfig: + """Tests for _maybe_strip_quantization_config.""" + + @staticmethod + def _make_config_with_quant(): + cfg = type("Config", (), {})() + cfg.quantization_config = {"quant_method": "mxfp4"} + return cfg + + def test_strips_quantization_config_when_all_params_bf16(self): + """quantization_config is removed when all params are standard floating-point.""" + model = nn.Linear(4, 4, dtype=torch.bfloat16) + model.config = self._make_config_with_quant() + + _maybe_strip_quantization_config(model) + assert not hasattr(model.config, "quantization_config") + + def test_keeps_quantization_config_when_uint8_params_exist(self): + """quantization_config is preserved when quantized (uint8) parameters exist.""" + model = nn.Module() + model.register_parameter("weight", nn.Parameter(torch.ones(4, 4, dtype=torch.uint8), requires_grad=False)) + model.config = self._make_config_with_quant() + + _maybe_strip_quantization_config(model) + assert hasattr(model.config, "quantization_config") + + def test_noop_when_no_quantization_config(self): + """No error when config has no quantization_config attribute.""" + model = nn.Linear(4, 4) + model.config = type("Config", (), {})() + + _maybe_strip_quantization_config(model) + assert not hasattr(model.config, "quantization_config") + + def test_noop_when_no_config(self): + """No error when model has no config attribute.""" + model = nn.Linear(4, 4) + _maybe_strip_quantization_config(model) diff --git a/tests/unit_tests/checkpoint/test_checkpointing.py b/tests/unit_tests/checkpoint/test_checkpointing.py index f5f74cb1b..985ccd8bd 100644 --- a/tests/unit_tests/checkpoint/test_checkpointing.py +++ b/tests/unit_tests/checkpoint/test_checkpointing.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -267,7 +266,6 @@ def test_hf_linear_is_not_custom(self): def test_module_from_custom_namespace_is_custom(self): """A class whose __module__ starts with nemo_automodel.components.models. is custom.""" - model = torch.nn.Module() # Simulate a custom model by patching __module__ on the class's MRO FakeCustom = type("FakeCustom", (torch.nn.Module,), {}) FakeCustom.__module__ = "nemo_automodel.components.models.deepseek_v3.model" @@ -358,7 +356,7 @@ class TestLoadModelCustomModelGuard: def _make_checkpointer(self): """Create a minimally configured Checkpointer for testing.""" - from nemo_automodel.components.checkpoint.checkpointing import CheckpointingConfig, Checkpointer + from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig config = CheckpointingConfig( enabled=True, @@ -418,9 +416,7 @@ def test_custom_model_skips_fast_path_uses_dcp(self, mock_load_full, mock_load_h with ( patch("os.path.exists", return_value=True), - patch( - "nemo_automodel.components.checkpoint.checkpointing.ModelState" - ) as MockModelState, + patch("nemo_automodel.components.checkpoint.checkpointing.ModelState") as MockModelState, patch( "nemo_automodel.components.checkpoint.checkpointing._maybe_adapt_state_dict_to_hf", side_effect=lambda m, sd, **kw: sd,