|
28 | 28 | _consume_config_overrides, |
29 | 29 | _filter_kwargs_for_init, |
30 | 30 | ) |
31 | | -from nemo_automodel._transformers.model_init import _ensure_pad_token_id |
32 | 31 | from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin |
33 | 32 |
|
34 | 33 |
|
@@ -700,237 +699,3 @@ def __init__(self): |
700 | 699 |
|
701 | 700 | assert is_custom is False |
702 | 701 | mock_wrap.assert_called_once_with(FakeModel) |
703 | | - |
704 | | - |
705 | | -# ============================================================================= |
706 | | -# Tests for pad_token_id fix and config forwarding in _init_model |
707 | | -# ============================================================================= |
708 | | - |
709 | | - |
710 | | -class TestEnsurePadTokenId: |
711 | | - """Test _ensure_pad_token_id for transformers v5 compatibility.""" |
712 | | - |
713 | | - def test_pad_token_id_set_when_missing(self): |
714 | | - """Config without pad_token_id gets it set to None.""" |
715 | | - |
716 | | - class BareConfig: |
717 | | - pass |
718 | | - |
719 | | - config = BareConfig() |
720 | | - assert not hasattr(config, "pad_token_id") |
721 | | - |
722 | | - _ensure_pad_token_id(config) |
723 | | - |
724 | | - assert hasattr(config, "pad_token_id") |
725 | | - assert config.pad_token_id is None |
726 | | - |
727 | | - def test_pad_token_id_preserved_when_present(self): |
728 | | - """Config that already has pad_token_id keeps its value unchanged.""" |
729 | | - |
730 | | - class ConfigWithPad: |
731 | | - pad_token_id = 42 |
732 | | - |
733 | | - config = ConfigWithPad() |
734 | | - _ensure_pad_token_id(config) |
735 | | - assert config.pad_token_id == 42 |
736 | | - |
737 | | - def test_sub_config_pad_token_id_patched(self): |
738 | | - """Nested sub-configs (e.g. text_config) also get pad_token_id set.""" |
739 | | - |
740 | | - class SubConfig: |
741 | | - """Mimics a HF sub-config (has to_dict).""" |
742 | | - def to_dict(self): |
743 | | - return {} |
744 | | - |
745 | | - class TopConfig: |
746 | | - pad_token_id = 0 # top-level already has it |
747 | | - |
748 | | - def __init__(self): |
749 | | - self.text_config = SubConfig() |
750 | | - self.vision_config = SubConfig() |
751 | | - |
752 | | - config = TopConfig() |
753 | | - assert not hasattr(config.text_config, "pad_token_id") |
754 | | - assert not hasattr(config.vision_config, "pad_token_id") |
755 | | - |
756 | | - _ensure_pad_token_id(config) |
757 | | - |
758 | | - # Top-level preserved |
759 | | - assert config.pad_token_id == 0 |
760 | | - # Sub-configs patched |
761 | | - assert config.text_config.pad_token_id is None |
762 | | - assert config.vision_config.pad_token_id is None |
763 | | - |
764 | | - def test_sub_config_pad_token_id_preserved_when_present(self): |
765 | | - """Sub-config that already has pad_token_id keeps its value.""" |
766 | | - |
767 | | - class SubConfig: |
768 | | - pad_token_id = 99 |
769 | | - def to_dict(self): |
770 | | - return {} |
771 | | - |
772 | | - class TopConfig: |
773 | | - def __init__(self): |
774 | | - self.text_config = SubConfig() |
775 | | - |
776 | | - config = TopConfig() |
777 | | - _ensure_pad_token_id(config) |
778 | | - |
779 | | - assert config.pad_token_id is None # top had none, got patched |
780 | | - assert config.text_config.pad_token_id == 99 # preserved |
781 | | - |
782 | | - def test_non_config_attributes_ignored(self): |
783 | | - """Attributes without to_dict (plain ints, strings, etc.) are skipped.""" |
784 | | - |
785 | | - class TopConfig: |
786 | | - def __init__(self): |
787 | | - self.hidden_size = 768 |
788 | | - self.model_type = "test" |
789 | | - self.some_list = [1, 2, 3] |
790 | | - |
791 | | - config = TopConfig() |
792 | | - _ensure_pad_token_id(config) |
793 | | - |
794 | | - assert config.pad_token_id is None |
795 | | - # Other attributes untouched |
796 | | - assert config.hidden_size == 768 |
797 | | - assert config.model_type == "test" |
798 | | - |
799 | | - def test_integration_with_init_model(self): |
800 | | - """_init_model applies _ensure_pad_token_id to the config.""" |
801 | | - |
802 | | - class BareConfig: |
803 | | - name_or_path = "test-model" |
804 | | - |
805 | | - config = BareConfig() |
806 | | - assert not hasattr(config, "pad_token_id") |
807 | | - |
808 | | - cls = MagicMock() |
809 | | - cls._model_mapping = {} |
810 | | - cls._from_config_parent_class = MagicMock(return_value=MagicMock()) |
811 | | - |
812 | | - with ( |
813 | | - patch("nemo_automodel._transformers.model_init.get_architectures", return_value=[]), |
814 | | - patch("nemo_automodel._transformers.model_init._get_mixin_wrapped_class", side_effect=lambda c: c), |
815 | | - ): |
816 | | - _init_model( |
817 | | - cls, |
818 | | - config, |
819 | | - attn_implementation="eager", |
820 | | - torch_dtype="auto", |
821 | | - quantization_config=None, |
822 | | - force_hf=False, |
823 | | - ) |
824 | | - |
825 | | - assert config.pad_token_id is None |
826 | | - |
827 | | - |
828 | | -class TestConfigForwardingInPretrainedPaths: |
829 | | - """Test that the patched config is forwarded to HF from_pretrained calls |
830 | | - so they don't reload a fresh copy missing the pad_token_id fix.""" |
831 | | - |
832 | | - def _make_cls(self, model_mapping_dict=None): |
833 | | - cls = MagicMock() |
834 | | - cls._model_mapping = model_mapping_dict or {} |
835 | | - return cls |
836 | | - |
837 | | - def _make_fake_config(self, *, has_pad_token_id=False): |
838 | | - config = MagicMock() |
839 | | - config.architectures = [] |
840 | | - config.to_dict.return_value = {} |
841 | | - if not has_pad_token_id: |
842 | | - del config.pad_token_id # ensure hasattr returns False |
843 | | - return config |
844 | | - |
845 | | - def test_force_hf_pretrained_forwards_config_in_kwargs(self): |
846 | | - """force_hf + from_pretrained path passes config in kwargs.""" |
847 | | - fake_config = self._make_fake_config() |
848 | | - fake_model = MagicMock() |
849 | | - |
850 | | - cls = self._make_cls({type(fake_config): type(fake_model)}) |
851 | | - cls._from_pretrained_parent_class = MagicMock(return_value=fake_model) |
852 | | - |
853 | | - with ( |
854 | | - patch("nemo_automodel._transformers.model_init.get_hf_config", return_value=fake_config), |
855 | | - patch("nemo_automodel._transformers.model_init._get_mixin_wrapped_class", side_effect=lambda c: c), |
856 | | - ): |
857 | | - _init_model( |
858 | | - cls, |
859 | | - "some-model-name", # string triggers is_pretrained_init=True |
860 | | - attn_implementation="eager", |
861 | | - torch_dtype="auto", |
862 | | - quantization_config=None, |
863 | | - force_hf=True, |
864 | | - ) |
865 | | - |
866 | | - # Verify _from_pretrained_parent_class was called with config in kwargs |
867 | | - call_kwargs = cls._from_pretrained_parent_class.call_args[1] |
868 | | - assert "config" in call_kwargs |
869 | | - assert call_kwargs["config"] is fake_config |
870 | | - |
871 | | - def test_fallback_hf_pretrained_forwards_config_in_kwargs(self): |
872 | | - """Fallback HF + from_pretrained path passes config in kwargs.""" |
873 | | - fake_config = self._make_fake_config() |
874 | | - fake_model = MagicMock() |
875 | | - |
876 | | - cls = self._make_cls({type(fake_config): type(fake_model)}) |
877 | | - cls._from_pretrained_parent_class = MagicMock(return_value=fake_model) |
878 | | - |
879 | | - with ( |
880 | | - patch("nemo_automodel._transformers.model_init.get_hf_config", return_value=fake_config), |
881 | | - patch("nemo_automodel._transformers.model_init.get_architectures", return_value=[]), |
882 | | - patch("nemo_automodel._transformers.model_init._get_mixin_wrapped_class", side_effect=lambda c: c), |
883 | | - ): |
884 | | - _init_model( |
885 | | - cls, |
886 | | - "some-model-name", # string triggers is_pretrained_init=True |
887 | | - attn_implementation="eager", |
888 | | - torch_dtype="auto", |
889 | | - quantization_config=None, |
890 | | - force_hf=False, |
891 | | - ) |
892 | | - |
893 | | - call_kwargs = cls._from_pretrained_parent_class.call_args[1] |
894 | | - assert "config" in call_kwargs |
895 | | - assert call_kwargs["config"] is fake_config |
896 | | - |
897 | | - def test_custom_model_pretrained_does_not_receive_config_in_kwargs(self): |
898 | | - """Custom model path must NOT get config in kwargs (it's passed positionally).""" |
899 | | - fake_config = MagicMock() |
900 | | - fake_config.architectures = ["FakeArch"] |
901 | | - fake_config.to_dict.return_value = {} |
902 | | - fake_config.torch_dtype = "bfloat16" |
903 | | - # Ensure pad_token_id is missing so the fix sets it |
904 | | - del fake_config.pad_token_id |
905 | | - |
906 | | - class FakeCustomModel(torch.nn.Module): |
907 | | - def __init__(self, config, **kwargs): |
908 | | - super().__init__() |
909 | | - self.config = config |
910 | | - # Store kwargs so we can inspect them in the test |
911 | | - self._init_kwargs = kwargs |
912 | | - |
913 | | - cls = MagicMock() |
914 | | - registry_mock = {"FakeArch": FakeCustomModel} |
915 | | - |
916 | | - with ( |
917 | | - patch("nemo_automodel._transformers.model_init.get_hf_config", return_value=fake_config), |
918 | | - patch("nemo_automodel._transformers.model_init.get_architectures", return_value=["FakeArch"]), |
919 | | - patch("nemo_automodel._transformers.model_init.ModelRegistry") as mock_registry, |
920 | | - patch("nemo_automodel._transformers.model_init._download_model_weights"), |
921 | | - ): |
922 | | - mock_registry.model_arch_name_to_cls = registry_mock |
923 | | - is_custom, model = _init_model( |
924 | | - cls, |
925 | | - "some-model-name", # string triggers is_pretrained_init=True |
926 | | - attn_implementation="eager", |
927 | | - torch_dtype="bfloat16", |
928 | | - quantization_config=None, |
929 | | - force_hf=False, |
930 | | - ) |
931 | | - |
932 | | - assert is_custom is True |
933 | | - # config was passed positionally |
934 | | - assert model.config is fake_config |
935 | | - # config must NOT be in kwargs (would cause TypeError: got multiple values) |
936 | | - assert "config" not in model._init_kwargs |
0 commit comments