
Commit 909af74

update tests

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
1 parent 35f81fc

3 files changed: 130 additions & 237 deletions


tests/unit_tests/_transformers/test_auto_model.py

Lines changed: 0 additions & 235 deletions
@@ -28,7 +28,6 @@
     _consume_config_overrides,
     _filter_kwargs_for_init,
 )
-from nemo_automodel._transformers.model_init import _ensure_pad_token_id
 from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin
 
 
@@ -700,237 +699,3 @@ def __init__(self):
 
         assert is_custom is False
         mock_wrap.assert_called_once_with(FakeModel)
-
-
-# =============================================================================
-# Tests for pad_token_id fix and config forwarding in _init_model
-# =============================================================================
-
-
-class TestEnsurePadTokenId:
-    """Test _ensure_pad_token_id for transformers v5 compatibility."""
-
-    def test_pad_token_id_set_when_missing(self):
-        """Config without pad_token_id gets it set to None."""
-
-        class BareConfig:
-            pass
-
-        config = BareConfig()
-        assert not hasattr(config, "pad_token_id")
-
-        _ensure_pad_token_id(config)
-
-        assert hasattr(config, "pad_token_id")
-        assert config.pad_token_id is None
-
-    def test_pad_token_id_preserved_when_present(self):
-        """Config that already has pad_token_id keeps its value unchanged."""
-
-        class ConfigWithPad:
-            pad_token_id = 42
-
-        config = ConfigWithPad()
-        _ensure_pad_token_id(config)
-        assert config.pad_token_id == 42
-
-    def test_sub_config_pad_token_id_patched(self):
-        """Nested sub-configs (e.g. text_config) also get pad_token_id set."""
-
-        class SubConfig:
-            """Mimics a HF sub-config (has to_dict)."""
-            def to_dict(self):
-                return {}
-
-        class TopConfig:
-            pad_token_id = 0  # top-level already has it
-
-            def __init__(self):
-                self.text_config = SubConfig()
-                self.vision_config = SubConfig()
-
-        config = TopConfig()
-        assert not hasattr(config.text_config, "pad_token_id")
-        assert not hasattr(config.vision_config, "pad_token_id")
-
-        _ensure_pad_token_id(config)
-
-        # Top-level preserved
-        assert config.pad_token_id == 0
-        # Sub-configs patched
-        assert config.text_config.pad_token_id is None
-        assert config.vision_config.pad_token_id is None
-
-    def test_sub_config_pad_token_id_preserved_when_present(self):
-        """Sub-config that already has pad_token_id keeps its value."""
-
-        class SubConfig:
-            pad_token_id = 99
-            def to_dict(self):
-                return {}
-
-        class TopConfig:
-            def __init__(self):
-                self.text_config = SubConfig()
-
-        config = TopConfig()
-        _ensure_pad_token_id(config)
-
-        assert config.pad_token_id is None  # top had none, got patched
-        assert config.text_config.pad_token_id == 99  # preserved
-
-    def test_non_config_attributes_ignored(self):
-        """Attributes without to_dict (plain ints, strings, etc.) are skipped."""
-
-        class TopConfig:
-            def __init__(self):
-                self.hidden_size = 768
-                self.model_type = "test"
-                self.some_list = [1, 2, 3]
-
-        config = TopConfig()
-        _ensure_pad_token_id(config)
-
-        assert config.pad_token_id is None
-        # Other attributes untouched
-        assert config.hidden_size == 768
-        assert config.model_type == "test"
-
-    def test_integration_with_init_model(self):
-        """_init_model applies _ensure_pad_token_id to the config."""
-
-        class BareConfig:
-            name_or_path = "test-model"
-
-        config = BareConfig()
-        assert not hasattr(config, "pad_token_id")
-
-        cls = MagicMock()
-        cls._model_mapping = {}
-        cls._from_config_parent_class = MagicMock(return_value=MagicMock())
-
-        with (
-            patch("nemo_automodel._transformers.model_init.get_architectures", return_value=[]),
-            patch("nemo_automodel._transformers.model_init._get_mixin_wrapped_class", side_effect=lambda c: c),
-        ):
-            _init_model(
-                cls,
-                config,
-                attn_implementation="eager",
-                torch_dtype="auto",
-                quantization_config=None,
-                force_hf=False,
-            )
-
-        assert config.pad_token_id is None
-
-
-class TestConfigForwardingInPretrainedPaths:
-    """Test that the patched config is forwarded to HF from_pretrained calls
-    so they don't reload a fresh copy missing the pad_token_id fix."""
-
-    def _make_cls(self, model_mapping_dict=None):
-        cls = MagicMock()
-        cls._model_mapping = model_mapping_dict or {}
-        return cls
-
-    def _make_fake_config(self, *, has_pad_token_id=False):
-        config = MagicMock()
-        config.architectures = []
-        config.to_dict.return_value = {}
-        if not has_pad_token_id:
-            del config.pad_token_id  # ensure hasattr returns False
-        return config
-
-    def test_force_hf_pretrained_forwards_config_in_kwargs(self):
-        """force_hf + from_pretrained path passes config in kwargs."""
-        fake_config = self._make_fake_config()
-        fake_model = MagicMock()
-
-        cls = self._make_cls({type(fake_config): type(fake_model)})
-        cls._from_pretrained_parent_class = MagicMock(return_value=fake_model)
-
-        with (
-            patch("nemo_automodel._transformers.model_init.get_hf_config", return_value=fake_config),
-            patch("nemo_automodel._transformers.model_init._get_mixin_wrapped_class", side_effect=lambda c: c),
-        ):
-            _init_model(
-                cls,
-                "some-model-name",  # string triggers is_pretrained_init=True
-                attn_implementation="eager",
-                torch_dtype="auto",
-                quantization_config=None,
-                force_hf=True,
-            )
-
-        # Verify _from_pretrained_parent_class was called with config in kwargs
-        call_kwargs = cls._from_pretrained_parent_class.call_args[1]
-        assert "config" in call_kwargs
-        assert call_kwargs["config"] is fake_config
-
-    def test_fallback_hf_pretrained_forwards_config_in_kwargs(self):
-        """Fallback HF + from_pretrained path passes config in kwargs."""
-        fake_config = self._make_fake_config()
-        fake_model = MagicMock()
-
-        cls = self._make_cls({type(fake_config): type(fake_model)})
-        cls._from_pretrained_parent_class = MagicMock(return_value=fake_model)
-
-        with (
-            patch("nemo_automodel._transformers.model_init.get_hf_config", return_value=fake_config),
-            patch("nemo_automodel._transformers.model_init.get_architectures", return_value=[]),
-            patch("nemo_automodel._transformers.model_init._get_mixin_wrapped_class", side_effect=lambda c: c),
-        ):
-            _init_model(
-                cls,
-                "some-model-name",  # string triggers is_pretrained_init=True
-                attn_implementation="eager",
-                torch_dtype="auto",
-                quantization_config=None,
-                force_hf=False,
-            )
-
-        call_kwargs = cls._from_pretrained_parent_class.call_args[1]
-        assert "config" in call_kwargs
-        assert call_kwargs["config"] is fake_config
-
-    def test_custom_model_pretrained_does_not_receive_config_in_kwargs(self):
-        """Custom model path must NOT get config in kwargs (it's passed positionally)."""
-        fake_config = MagicMock()
-        fake_config.architectures = ["FakeArch"]
-        fake_config.to_dict.return_value = {}
-        fake_config.torch_dtype = "bfloat16"
-        # Ensure pad_token_id is missing so the fix sets it
-        del fake_config.pad_token_id
-
-        class FakeCustomModel(torch.nn.Module):
-            def __init__(self, config, **kwargs):
-                super().__init__()
-                self.config = config
-                # Store kwargs so we can inspect them in the test
-                self._init_kwargs = kwargs
-
-        cls = MagicMock()
-        registry_mock = {"FakeArch": FakeCustomModel}
-
-        with (
-            patch("nemo_automodel._transformers.model_init.get_hf_config", return_value=fake_config),
-            patch("nemo_automodel._transformers.model_init.get_architectures", return_value=["FakeArch"]),
-            patch("nemo_automodel._transformers.model_init.ModelRegistry") as mock_registry,
-            patch("nemo_automodel._transformers.model_init._download_model_weights"),
-        ):
-            mock_registry.model_arch_name_to_cls = registry_mock
-            is_custom, model = _init_model(
-                cls,
-                "some-model-name",  # string triggers is_pretrained_init=True
-                attn_implementation="eager",
-                torch_dtype="bfloat16",
-                quantization_config=None,
-                force_hf=False,
-            )
-
-        assert is_custom is True
-        # config was passed positionally
-        assert model.config is fake_config
-        # config must NOT be in kwargs (would cause TypeError: got multiple values)
-        assert "config" not in model._init_kwargs

tests/unit_tests/datasets/vlm/test_collate_fns.py

Lines changed: 18 additions & 0 deletions
@@ -523,6 +523,24 @@ def test_kimi_vl_collate_fn_extracts_images(collate_mod, monkeypatch):
     assert forward_call["images"] == ["test_image.jpg"]
 
 
+def test_kimi_vl_collate_fn_passes_add_special_tokens_false(collate_mod, monkeypatch):
+    """Test that kimi_vl_collate_fn passes add_special_tokens=False to processor."""
+    processor = DummyKimiVLProcessor()
+
+    labels_stub = torch.tensor([[10, 11, 12, 13, 14]], dtype=torch.long)
+    monkeypatch.setattr(
+        collate_mod, "build_labels", lambda *args, **kwargs: labels_stub, raising=True
+    )
+
+    examples = [{"conversation": CONVERSATION}]
+    collate_mod.kimi_vl_collate_fn(examples, processor)
+
+    assert len(processor.forward_calls) == 1
+    forward_call = processor.forward_calls[0]
+    assert "add_special_tokens" in forward_call
+    assert forward_call["add_special_tokens"] is False
+
+
 def test_kimi_vl_collate_fn_multiple_examples(collate_mod, monkeypatch):
     """Test kimi_vl_collate_fn handles multiple examples."""
    processor = DummyKimiVLProcessor()
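
The new test only verifies that the flag reaches the processor call; the production change it covers presumably threads add_special_tokens=False into that call because the chat template already inserts the special tokens. A sketch of that shape — everything beyond the test's fixtures (the apply_chat_template usage, padding, and label wiring) is an assumption, not the real collate_fns source:

def kimi_vl_collate_fn_sketch(examples, processor, build_labels):
    # Hypothetical: mirrors only what the test asserts, namely one processor
    # call per batch that carries add_special_tokens=False.
    conversations = [ex["conversation"] for ex in examples]
    texts = [processor.apply_chat_template(c, tokenize=False) for c in conversations]
    # The template output already contains special tokens, so the tokenizer
    # must not add them a second time.
    batch = processor(text=texts, add_special_tokens=False, return_tensors="pt", padding=True)
    batch["labels"] = build_labels(batch, processor)
    return batch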

tests/unit_tests/recipes/test_finetune_vlm_helpers.py

Lines changed: 112 additions & 2 deletions
@@ -154,7 +154,7 @@ def get(self, key, default=None):
 
     class FreezeConfig:
         def to_dict(self):
-            return {"freeze_embeddings": True, "freeze_language_model": False}
+            return {"freeze_language_model": False, "freeze_vision_tower": True}
 
     with patch('nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep', return_value=True):
         model = build_model(
@@ -166,7 +166,117 @@ def to_dict(self):
 
     # Verify freeze_config was passed to model instantiation
     assert "freeze_config" in captured_kwargs
-    assert captured_kwargs["freeze_config"] == {"freeze_embeddings": True, "freeze_language_model": False}
+    assert captured_kwargs["freeze_config"] == {"freeze_language_model": False, "freeze_vision_tower": True}
+
+
+def test_build_model_passes_moe_config_from_parallelizer_config():
+    """Test that cfg_moe as MoEParallelizerConfig is forwarded directly."""
+    from nemo_automodel._transformers import NeMoAutoModelForImageTextToText
+    from nemo_automodel.components.moe.config import MoEParallelizerConfig
+
+    captured_kwargs = {}
+
+    class CapturingModelConfig:
+        def __init__(self):
+            self._target_ = NeMoAutoModelForImageTextToText.from_pretrained
+
+        def instantiate(self, **kwargs):
+            captured_kwargs.update(kwargs)
+            return DummyModel()
+
+        def get(self, key, default=None):
+            return getattr(self, key, default)
+
+    cfg_model = CapturingModelConfig()
+    moe_cfg = MoEParallelizerConfig()
+
+    with patch('nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep', return_value=True):
+        build_model(
+            cfg_model=cfg_model,
+            cfg_freeze=None,
+            cfg_peft=None,
+            seed=123,
+            cfg_moe=moe_cfg,
+            activation_checkpointing=True,
+        )
+
+    assert "moe_config" in captured_kwargs
+    assert captured_kwargs["moe_config"] is moe_cfg
+    assert captured_kwargs["activation_checkpointing"] is True
+
+
+def test_build_model_passes_moe_config_from_dict_like():
+    """Test that cfg_moe with to_dict() is converted to MoEParallelizerConfig."""
+    from nemo_automodel._transformers import NeMoAutoModelForImageTextToText
+    from nemo_automodel.components.moe.config import MoEParallelizerConfig
+
+    captured_kwargs = {}
+
+    class CapturingModelConfig:
+        def __init__(self):
+            self._target_ = NeMoAutoModelForImageTextToText.from_pretrained
+
+        def instantiate(self, **kwargs):
+            captured_kwargs.update(kwargs)
+            return DummyModel()
+
+        def get(self, key, default=None):
+            return getattr(self, key, default)
+
+    class DictLikeMoeConfig:
+        def to_dict(self):
+            return {
+                "activation_checkpointing": True,  # should be stripped
+                "_target_": "some.target",  # should be stripped
+            }
+
+    cfg_model = CapturingModelConfig()
+
+    with patch('nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep', return_value=True):
+        build_model(
+            cfg_model=cfg_model,
+            cfg_freeze=None,
+            cfg_peft=None,
+            seed=123,
+            cfg_moe=DictLikeMoeConfig(),
+            activation_checkpointing=False,
+        )
+
+    assert "moe_config" in captured_kwargs
+    assert isinstance(captured_kwargs["moe_config"], MoEParallelizerConfig)
+    assert captured_kwargs["activation_checkpointing"] is False
+
+
+def test_build_model_no_moe_config_when_cfg_moe_is_none():
+    """Test that moe_config and activation_checkpointing are not in kwargs when cfg_moe is None."""
+    from nemo_automodel._transformers import NeMoAutoModelForImageTextToText
+
+    captured_kwargs = {}
+
+    class CapturingModelConfig:
+        def __init__(self):
+            self._target_ = NeMoAutoModelForImageTextToText.from_pretrained
+
+        def instantiate(self, **kwargs):
+            captured_kwargs.update(kwargs)
+            return DummyModel()
+
+        def get(self, key, default=None):
+            return getattr(self, key, default)
+
+    cfg_model = CapturingModelConfig()
+
+    with patch('nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep', return_value=True):
+        build_model(
+            cfg_model=cfg_model,
+            cfg_freeze=None,
+            cfg_peft=None,
+            seed=123,
+            cfg_moe=None,
+        )
+
+    assert "moe_config" not in captured_kwargs
+    assert "activation_checkpointing" not in captured_kwargs
 
 
 # -----------------------------------------------------------------------------
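
Taken together, the three new tests specify build_model's cfg_moe handling: an actual MoEParallelizerConfig is forwarded by identity, a dict-like config is converted with the _target_ and activation_checkpointing keys stripped from its to_dict() output, and None suppresses both moe_config and activation_checkpointing in the instantiate kwargs. A hypothetical helper satisfying exactly those three cases — not the recipe's actual code, and MoEParallelizerConfig(**fields) assumes a keyword-constructible config:

def _resolve_moe_kwargs(cfg_moe, activation_checkpointing=False):
    from nemo_automodel.components.moe.config import MoEParallelizerConfig

    if cfg_moe is None:
        # No MoE configured: forward neither key.
        return {}
    if isinstance(cfg_moe, MoEParallelizerConfig):
        moe_config = cfg_moe  # forwarded as-is, object identity preserved
    else:
        # Dict-like config: drop keys that are not parallelizer fields.
        fields = {
            k: v
            for k, v in cfg_moe.to_dict().items()
            if k not in ("_target_", "activation_checkpointing")
        }
        moe_config = MoEParallelizerConfig(**fields)
    return {"moe_config": moe_config, "activation_checkpointing": activation_checkpointing}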
