
Commit d3dd2a1

Authored by HuiyingLi <willwin.lee@gmail.com>
fix: vlm refac fixes (#1268)
* multiple vlm fixes
* revert model init
* update tests

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Parent: 5872e02 · Commit: d3dd2a1

6 files changed: 162 additions & 14 deletions


examples/vlm_finetune/qwen3/qwen3_omni_moe_30b_te_deepep.yaml

Lines changed: 11 additions & 10 deletions
@@ -32,16 +32,17 @@ rng:
 model:
   _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
   pretrained_model_name_or_path: Qwen/Qwen3-Omni-30B-A3B-Instruct
-  # Customize this backend for fine grained control
-  # backend:
-  #   _target_: nemo_automodel.components.models.common.BackendConfig
-  #   attn: sdpa
-  #   linear: te
-  #   rms_norm: te
-  #   experts: te
-  #   dispatcher: deepep
-  #   fake_balanced_gate: false
-  #   enable_hf_state_dict_adapter: true
+  # Customize this backend for fine grained control
+  backend:
+    _target_: nemo_automodel.components.models.common.BackendConfig
+    attn: sdpa
+    linear: te
+    rms_norm: te
+    rope_fusion: false
+    experts: te
+    enable_deepep: true
+    fake_balanced_gate: false
+    enable_hf_state_dict_adapter: true


 checkpoint:
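
For reference, a `_target_:`-style block like the `backend:` section above is conventionally resolved by importing the dotted path and calling it with the remaining keys as keyword arguments. Below is a minimal sketch of that pattern; the `resolve_target` helper and the `types.SimpleNamespace` stand-in are illustrative assumptions, not nemo_automodel's actual config loader.

```python
# Sketch only: how a "_target_" config block is typically resolved.
# resolve_target is a hypothetical helper; nemo_automodel's real loader may differ.
import importlib

import yaml


def resolve_target(node: dict):
    """Import the dotted "_target_" path and call it with the remaining keys."""
    node = dict(node)
    module_path, _, attr = node.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    return target(**node)


cfg = yaml.safe_load("""
backend:
  _target_: types.SimpleNamespace  # stand-in for BackendConfig in this sketch
  attn: sdpa
  linear: te
  enable_deepep: true
""")

backend = resolve_target(cfg["backend"])
print(backend.attn, backend.enable_deepep)  # -> sdpa True
```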

examples/vlm_finetune/qwen3/qwen3_vl_moe_30b_te_deepep.yaml

Lines changed: 2 additions & 1 deletion
@@ -40,8 +40,9 @@ model:
     attn: sdpa
     linear: te
     rms_norm: te
+    rope_fusion: false
     experts: te
-    dispatcher: deepep
+    enable_deepep: true
     fake_balanced_gate: false
     enable_hf_state_dict_adapter: true


nemo_automodel/components/datasets/vlm/collate_fns.py

Lines changed: 1 addition & 0 deletions
@@ -308,6 +308,7 @@ def kimi_vl_collate_fn(
         "return_tensors": "pt",
         "padding": True,
         "truncation": True,
+        "add_special_tokens": False,
     }
     if max_length is not None:
         processor_kwargs["max_length"] = max_length
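
The added `"add_special_tokens": False` matters because the text handed to the processor has already been rendered by a chat template, which inserts its own special tokens; letting the tokenizer prepend or append another set would duplicate them. A small illustration of the effect with an off-the-shelf Hugging Face tokenizer (used purely as an analogy; the Kimi-VL processor is assumed to tokenize similarly):

```python
# Illustration of why add_special_tokens=False is needed once a chat template
# has already placed special tokens in the text. bert-base-uncased is just an
# example tokenizer, not the Kimi-VL processor.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
templated = "[CLS] hello [SEP]"  # pretend a chat template already added these

dup = tok(templated, add_special_tokens=True)["input_ids"]
ok = tok(templated, add_special_tokens=False)["input_ids"]

print(dup)  # [101, 101, 7592, 102, 102] -- [CLS]/[SEP] duplicated
print(ok)   # [101, 7592, 102]           -- ids match the templated string
```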

nemo_automodel/recipes/vlm/finetune.py

Lines changed: 18 additions & 1 deletion
@@ -88,13 +88,14 @@ def build_model(
     cfg_freeze,
     cfg_peft,
     seed,
-    freeze_embeddings=True,
     cfg_fp8=None,
     cfg_compile=None,
     device_mesh=None,
     moe_mesh=None,
     distributed_config=None,
     pipeline_config=None,
+    cfg_moe=None,
+    activation_checkpointing=False,
 ) -> tuple[nn.Module | AutoPipeline, list["Optimizer"]]: # noqa: F821
     """Build and initialize a model for VLM.

@@ -111,6 +112,20 @@ def build_model(
         "pipeline_config": pipeline_config,
         "freeze_config": cfg_freeze.to_dict() if cfg_freeze is not None else None,
     }
+
+    if cfg_moe is not None:
+        from nemo_automodel.components.moe.config import MoEParallelizerConfig
+
+        if isinstance(cfg_moe, MoEParallelizerConfig):
+            kwargs["moe_config"] = cfg_moe
+        else:
+            moe_dict = cfg_moe.to_dict() if hasattr(cfg_moe, "to_dict") else dict(cfg_moe)
+            # activation_checkpointing is handled separately; strip config keys
+            moe_dict.pop("activation_checkpointing", None)
+            moe_dict.pop("_target_", None)
+            kwargs["moe_config"] = MoEParallelizerConfig(**moe_dict)
+        kwargs["activation_checkpointing"] = activation_checkpointing
+
     if cfg_fp8 is not None:
         fp8_config = build_fp8_config(cfg_fp8)
         kwargs["fp8_config"] = fp8_config
@@ -556,6 +571,8 @@ def setup(self):
             moe_mesh=self.moe_mesh,
             distributed_config=self.distributed_config,
             pipeline_config=self.pipeline_config,
+            cfg_moe=self.dist_setup.moe_config,
+            activation_checkpointing=self.dist_setup.activation_checkpointing,
         )
         self.optimizer = build_optimizer(model, self.cfg.optimizer, self.distributed_config, self.device_mesh)

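
The new `cfg_moe` branch in `build_model` normalizes two input shapes into one `moe_config` kwarg: an already-built `MoEParallelizerConfig` passes through unchanged, while a dict-like config node is converted after stripping keys its constructor does not accept. A self-contained sketch of the same pattern follows, using a stand-in dataclass (`expert_parallel_size` is a hypothetical field for illustration, not the real class's schema):

```python
# Standalone sketch of the cfg_moe normalization pattern. ParallelizerConfig
# is a stand-in for nemo_automodel's MoEParallelizerConfig.
from dataclasses import dataclass


@dataclass
class ParallelizerConfig:
    expert_parallel_size: int = 1  # hypothetical field for illustration


def normalize_moe_config(cfg_moe):
    """Accept either a ready config object or a dict-like config node."""
    if isinstance(cfg_moe, ParallelizerConfig):
        return cfg_moe
    moe_dict = cfg_moe.to_dict() if hasattr(cfg_moe, "to_dict") else dict(cfg_moe)
    # Keys consumed elsewhere (or by the config loader) must not reach __init__.
    moe_dict.pop("activation_checkpointing", None)
    moe_dict.pop("_target_", None)
    return ParallelizerConfig(**moe_dict)


# Both call styles yield the same type:
assert isinstance(normalize_moe_config(ParallelizerConfig()), ParallelizerConfig)
assert isinstance(
    normalize_moe_config({"expert_parallel_size": 2, "_target_": "x"}),
    ParallelizerConfig,
)
```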

tests/unit_tests/datasets/vlm/test_collate_fns.py

Lines changed: 18 additions & 0 deletions
@@ -523,6 +523,24 @@ def test_kimi_vl_collate_fn_extracts_images(collate_mod, monkeypatch):
     assert forward_call["images"] == ["test_image.jpg"]


+def test_kimi_vl_collate_fn_passes_add_special_tokens_false(collate_mod, monkeypatch):
+    """Test that kimi_vl_collate_fn passes add_special_tokens=False to processor."""
+    processor = DummyKimiVLProcessor()
+
+    labels_stub = torch.tensor([[10, 11, 12, 13, 14]], dtype=torch.long)
+    monkeypatch.setattr(
+        collate_mod, "build_labels", lambda *args, **kwargs: labels_stub, raising=True
+    )
+
+    examples = [{"conversation": CONVERSATION}]
+    collate_mod.kimi_vl_collate_fn(examples, processor)
+
+    assert len(processor.forward_calls) == 1
+    forward_call = processor.forward_calls[0]
+    assert "add_special_tokens" in forward_call
+    assert forward_call["add_special_tokens"] is False
+
+
 def test_kimi_vl_collate_fn_multiple_examples(collate_mod, monkeypatch):
     """Test kimi_vl_collate_fn handles multiple examples."""
     processor = DummyKimiVLProcessor()

tests/unit_tests/recipes/test_finetune_vlm_helpers.py

Lines changed: 112 additions & 2 deletions
@@ -154,7 +154,7 @@ def get(self, key, default=None):

     class FreezeConfig:
         def to_dict(self):
-            return {"freeze_embeddings": True, "freeze_language_model": False}
+            return {"freeze_language_model": False, "freeze_vision_tower": True}

     with patch('nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep', return_value=True):
         model = build_model(
@@ -166,7 +166,117 @@ def to_dict(self):

     # Verify freeze_config was passed to model instantiation
     assert "freeze_config" in captured_kwargs
-    assert captured_kwargs["freeze_config"] == {"freeze_embeddings": True, "freeze_language_model": False}
+    assert captured_kwargs["freeze_config"] == {"freeze_language_model": False, "freeze_vision_tower": True}
+
+
+def test_build_model_passes_moe_config_from_parallelizer_config():
+    """Test that cfg_moe as MoEParallelizerConfig is forwarded directly."""
+    from nemo_automodel._transformers import NeMoAutoModelForImageTextToText
+    from nemo_automodel.components.moe.config import MoEParallelizerConfig
+
+    captured_kwargs = {}
+
+    class CapturingModelConfig:
+        def __init__(self):
+            self._target_ = NeMoAutoModelForImageTextToText.from_pretrained
+
+        def instantiate(self, **kwargs):
+            captured_kwargs.update(kwargs)
+            return DummyModel()
+
+        def get(self, key, default=None):
+            return getattr(self, key, default)
+
+    cfg_model = CapturingModelConfig()
+    moe_cfg = MoEParallelizerConfig()
+
+    with patch('nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep', return_value=True):
+        build_model(
+            cfg_model=cfg_model,
+            cfg_freeze=None,
+            cfg_peft=None,
+            seed=123,
+            cfg_moe=moe_cfg,
+            activation_checkpointing=True,
+        )
+
+    assert "moe_config" in captured_kwargs
+    assert captured_kwargs["moe_config"] is moe_cfg
+    assert captured_kwargs["activation_checkpointing"] is True
+
+
+def test_build_model_passes_moe_config_from_dict_like():
+    """Test that cfg_moe with to_dict() is converted to MoEParallelizerConfig."""
+    from nemo_automodel._transformers import NeMoAutoModelForImageTextToText
+    from nemo_automodel.components.moe.config import MoEParallelizerConfig
+
+    captured_kwargs = {}
+
+    class CapturingModelConfig:
+        def __init__(self):
+            self._target_ = NeMoAutoModelForImageTextToText.from_pretrained
+
+        def instantiate(self, **kwargs):
+            captured_kwargs.update(kwargs)
+            return DummyModel()
+
+        def get(self, key, default=None):
+            return getattr(self, key, default)
+
+    class DictLikeMoeConfig:
+        def to_dict(self):
+            return {
+                "activation_checkpointing": True,  # should be stripped
+                "_target_": "some.target",  # should be stripped
+            }
+
+    cfg_model = CapturingModelConfig()
+
+    with patch('nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep', return_value=True):
+        build_model(
+            cfg_model=cfg_model,
+            cfg_freeze=None,
+            cfg_peft=None,
+            seed=123,
+            cfg_moe=DictLikeMoeConfig(),
+            activation_checkpointing=False,
+        )
+
+    assert "moe_config" in captured_kwargs
+    assert isinstance(captured_kwargs["moe_config"], MoEParallelizerConfig)
+    assert captured_kwargs["activation_checkpointing"] is False
+
+
+def test_build_model_no_moe_config_when_cfg_moe_is_none():
+    """Test that moe_config and activation_checkpointing are not in kwargs when cfg_moe is None."""
+    from nemo_automodel._transformers import NeMoAutoModelForImageTextToText
+
+    captured_kwargs = {}
+
+    class CapturingModelConfig:
+        def __init__(self):
+            self._target_ = NeMoAutoModelForImageTextToText.from_pretrained
+
+        def instantiate(self, **kwargs):
+            captured_kwargs.update(kwargs)
+            return DummyModel()
+
+        def get(self, key, default=None):
+            return getattr(self, key, default)
+
+    cfg_model = CapturingModelConfig()
+
+    with patch('nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep', return_value=True):
+        build_model(
+            cfg_model=cfg_model,
+            cfg_freeze=None,
+            cfg_peft=None,
+            seed=123,
+            cfg_moe=None,
+        )
+
+    assert "moe_config" not in captured_kwargs
+    assert "activation_checkpointing" not in captured_kwargs


 # -----------------------------------------------------------------------------
