Skip to content

Commit 3b73476

Browse files
HuiyingLiclaude
andcommitted
test: add unit tests for _patched_get_init_context and meta device helpers
Tests cover: - Extra args/kwargs forwarding (transformers v5.3.0 allow_all_kernels) - Meta device filtering with no_hf_meta_device context manager - Patch installation on PreTrainedModel - Nested context manager behavior Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
1 parent 4a9d6ad commit 3b73476

1 file changed

Lines changed: 115 additions & 0 deletions

File tree

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import torch
16+
from transformers import PreTrainedModel
17+
from unittest.mock import patch
18+
19+
from nemo_automodel._transformers.model_init import (
20+
_filter_meta_device_from_init_context,
21+
_patched_get_init_context,
22+
no_hf_meta_device,
23+
)
24+
25+
26+
class TestFilterMetaDeviceFromInitContext:
27+
def test_removes_meta_device(self):
28+
contexts = [torch.device("meta"), torch.float32]
29+
result = _filter_meta_device_from_init_context(contexts)
30+
assert torch.device("meta") not in result
31+
assert torch.float32 in result
32+
33+
def test_keeps_non_meta_devices(self):
34+
contexts = [torch.device("cpu"), torch.device("cuda")]
35+
result = _filter_meta_device_from_init_context(contexts)
36+
assert len(result) == 2
37+
38+
def test_empty_list(self):
39+
assert _filter_meta_device_from_init_context([]) == []
40+
41+
42+
class TestPatchedGetInitContext:
43+
def test_forwards_extra_args(self):
44+
"""Verify _patched_get_init_context forwards *args/**kwargs (transformers v5.3.0 compat)."""
45+
received_args = {}
46+
47+
def mock_original(cls, dtype, is_quantized, _is_ds_init_called, *args, **kwargs):
48+
received_args["args"] = args
49+
received_args["kwargs"] = kwargs
50+
return []
51+
52+
with patch.object(_patched_get_init_context, "__wrapped__", mock_original):
53+
_patched_get_init_context(None, torch.float32, False, False, True, extra_kwarg="test")
54+
55+
assert received_args["args"] == (True,)
56+
assert received_args["kwargs"] == {"extra_kwarg": "test"}
57+
58+
def test_forwards_allow_all_kernels(self):
59+
"""Simulate the exact transformers v5.3.0 call with allow_all_kernels param."""
60+
received_args = {}
61+
62+
def mock_original(cls, dtype, is_quantized, _is_ds_init_called, allow_all_kernels):
63+
received_args["allow_all_kernels"] = allow_all_kernels
64+
return []
65+
66+
with patch.object(_patched_get_init_context, "__wrapped__", mock_original):
67+
_patched_get_init_context(None, torch.float32, False, False, None)
68+
69+
assert received_args["allow_all_kernels"] is None
70+
71+
def test_strips_meta_device_when_disabled(self):
72+
"""When no_hf_meta_device context is active, meta devices are filtered out."""
73+
74+
def mock_original(cls, dtype, is_quantized, _is_ds_init_called, *args, **kwargs):
75+
return [torch.device("meta"), torch.float32]
76+
77+
with patch.object(_patched_get_init_context, "__wrapped__", mock_original):
78+
with no_hf_meta_device():
79+
result = _patched_get_init_context(None, torch.float32, False, False)
80+
assert torch.device("meta") not in result
81+
assert torch.float32 in result
82+
83+
def test_keeps_meta_device_by_default(self):
84+
"""Without no_hf_meta_device, meta devices are preserved."""
85+
86+
def mock_original(cls, dtype, is_quantized, _is_ds_init_called, *args, **kwargs):
87+
return [torch.device("meta"), torch.float32]
88+
89+
with patch.object(_patched_get_init_context, "__wrapped__", mock_original):
90+
result = _patched_get_init_context(None, torch.float32, False, False)
91+
assert torch.device("meta") in result
92+
93+
def test_patch_installed_on_pretrained_model(self):
94+
"""Verify the patch is actually installed on PreTrainedModel."""
95+
assert PreTrainedModel.get_init_context.__func__ is _patched_get_init_context
96+
97+
98+
class TestNoHfMetaDevice:
99+
def test_context_manager_sets_and_restores(self):
100+
from nemo_automodel._transformers.model_init import _get_hf_meta_device_disabled
101+
102+
assert not _get_hf_meta_device_disabled()
103+
with no_hf_meta_device():
104+
assert _get_hf_meta_device_disabled()
105+
assert not _get_hf_meta_device_disabled()
106+
107+
def test_nested_context_managers(self):
108+
from nemo_automodel._transformers.model_init import _get_hf_meta_device_disabled
109+
110+
with no_hf_meta_device():
111+
assert _get_hf_meta_device_disabled()
112+
with no_hf_meta_device():
113+
assert _get_hf_meta_device_disabled()
114+
assert _get_hf_meta_device_disabled()
115+
assert not _get_hf_meta_device_disabled()

0 commit comments

Comments
 (0)