
Commit 61a5116

feat(recipes): add VLM knowledge distillation recipe with chunked KD loss (#2205)
* feat(loss): add chunked KD loss for memory-efficient distillation

  Add a ``chunk_size`` knob to ``KDLoss`` that processes valid tokens in chunks
  when computing forward KL. Only one ``[chunk_size, vocab_size]`` fp32
  probability/log-prob tensor is materialized at a time, which keeps peak memory
  bounded for large-vocab VLMs while remaining numerically identical to the
  unchunked path (verified by new unit tests). The TP path is unaffected;
  chunking is opt-in via ``chunk_size > 0``.

* feat(recipes): add VLM knowledge distillation recipe

  Add ``KnowledgeDistillationRecipeForVLM`` under ``recipes/vlm/kd.py``. The
  recipe extends ``FinetuneRecipeForVLM`` with a frozen teacher
  ``NeMoAutoModelForImageTextToText``, a KD loss term, and a ``kd_ratio`` linear
  mix between CE and KD losses. The training loop forwards multimodal inputs
  (pixel_values, image_grid_thw, etc.) to both teacher and student, frees
  intermediate activations eagerly to keep peak memory low, and reports CE/KD
  sub-losses alongside the combined loss in validation metrics. Pipeline
  parallelism is not supported.

* docs(examples): add Qwen3.5 VLM KD example config

  Add an example YAML that distills Qwen3.5-9B (teacher) into Qwen3.5-4B
  (student) on the public ``mmoukouba/MedPix-VQA`` medical-image VQA dataset.
  The config exercises the chunked KD loss (``kd_loss_fn.chunk_size: 512``),
  freezes the student's vision and audio towers, and uses FSDP2.

Signed-off-by: khazic <khazzz1c@gmail.com>
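The commit message describes a ``kd_ratio`` linear mix between the CE and KD losses, but the mixing code itself is not part of the visible diff. A minimal sketch of what such a blend looks like, assuming the standard convex combination; the helper name ``mix_losses`` is illustrative, not NeMo's API:

```python
def mix_losses(ce_loss: float, kd_loss: float, kd_ratio: float) -> float:
    """Linearly blend cross-entropy and KD losses.

    kd_ratio = 0.0 -> pure CE (plain finetuning);
    kd_ratio = 1.0 -> pure KD (match the teacher distribution only).
    """
    if not 0.0 <= kd_ratio <= 1.0:
        raise ValueError("kd_ratio must lie in [0, 1]")
    return (1.0 - kd_ratio) * ce_loss + kd_ratio * kd_loss

# With the example config's kd_ratio of 0.5, both terms are weighted equally.
print(mix_losses(2.0, 4.0, 0.5))  # 3.0
```

With ``kd_ratio: 0.5`` (as in the example config below), halving either sub-loss moves the combined loss by the same amount, which is why the recipe reports CE and KD sub-losses separately in validation metrics.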
1 parent adc20e2 commit 61a5116

4 files changed

Lines changed: 752 additions & 4 deletions

File tree

examples/vlm_kd/qwen3_5/qwen3_5_vl_4b_kd.yaml (new file)

Lines changed: 124 additions & 0 deletions

@@ -0,0 +1,124 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# VLM Knowledge Distillation: Qwen3.5-9B (teacher) → Qwen3.5-4B (student)
# Dataset: MedPix-VQA (medical image VQA with real images)
#
# To run:
#   automodel examples/vlm_kd/qwen3_5/qwen3_5_vl_4b_kd.yaml --nproc-per-node 8

recipe: KnowledgeDistillationRecipeForVLM

step_scheduler:
  global_batch_size: 16
  local_batch_size: 1
  ckpt_every_steps: 200
  val_every_steps: 50
  num_epochs: 2
  max_steps: 300

dist_env:
  backend: nccl
  timeout_minutes: 10

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 42
  ranked: true

# Student
model:
  _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
  pretrained_model_name_or_path: Qwen/Qwen3.5-4B
  attn_implementation: sdpa

# Teacher
teacher_model:
  _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
  pretrained_model_name_or_path: Qwen/Qwen3.5-9B
  attn_implementation: sdpa

processor:
  _target_: transformers.AutoProcessor.from_pretrained
  pretrained_model_name_or_path: Qwen/Qwen3.5-4B

checkpoint:
  enabled: true
  checkpoint_dir: checkpoints/qwen3_5_vl_4b_kd/
  model_save_format: safetensors
  save_consolidated: true

distributed:
  _target_: nemo_automodel.components.distributed.fsdp2.FSDP2Manager
  tp_size: 1
  cp_size: 1
  pp_size: 1
  dp_replicate_size: 1
  ep_size: 1
  sequence_parallel: false

clip_grad_norm:
  max_norm: 1.0

loss_fn:
  _target_: nemo_automodel.components.loss.masked_ce.MaskedCrossEntropy

# KD hyper-params
kd_ratio: 0.5
kd_loss_fn:
  _target_: nemo_automodel.components.loss.kd_loss.KDLoss
  ignore_index: -100
  temperature: 1.0
  fp32_upcast: true
  chunk_size: 512

optimizer:
  _target_: torch.optim.AdamW
  lr: 5e-5
  weight_decay: 0.01
  betas: [0.9, 0.95]
  eps: 1e-8

lr_scheduler:
  lr_warmup_steps: 30
  lr_decay_style: cosine

# MedPix-VQA dataset (medical images + VQA)
dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: train

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn

validation_dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_medpix_dataset
  path_or_dataset: mmoukouba/MedPix-VQA
  split: validation

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  num_workers: 0
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn

# Student: freeze vision+audio towers, train language model only
freeze_config:
  freeze_vision_tower: true
  freeze_audio_tower: true
  freeze_language_model: false
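The config's run command launches 8 processes, each with ``local_batch_size: 1``, against a ``global_batch_size`` of 16, so the step scheduler has to accumulate gradients across micro-batches. A sketch of the standard arithmetic (the function name is illustrative, not NeMo's API):

```python
def grad_accum_steps(global_batch: int, local_batch: int, dp_world_size: int) -> int:
    """Micro-batches each rank must accumulate before an optimizer step."""
    per_step = local_batch * dp_world_size  # samples consumed per forward/backward across ranks
    if global_batch % per_step != 0:
        raise ValueError("global batch must be divisible by local_batch * dp_world_size")
    return global_batch // per_step

# Example config: global 16, local 1, 8 data-parallel ranks.
print(grad_accum_steps(16, 1, 8))  # 2
```

Keeping ``local_batch_size`` at 1 is a reasonable default here: holding both a 9B teacher and a 4B student plus multimodal activations per rank is the memory bottleneck, and the global batch is recovered through accumulation.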

nemo_automodel/components/loss/kd_loss.py

Lines changed: 39 additions & 3 deletions
@@ -89,6 +89,37 @@ def _kl_forward_tp(
     return ce_local  # shape: [valid_tokens]


+def _kl_forward_chunked(
+    t_logits: torch.Tensor,
+    s_logits: torch.Tensor,
+    chunk_size: int,
+) -> torch.Tensor:
+    """Compute per-token sum(P * log Q) in chunks to reduce peak memory.
+
+    Processes ``chunk_size`` tokens at a time so that only one chunk's worth of the
+    ``[chunk_size, vocab_size]`` fp32 probability matrix is live at any moment.
+
+    Args:
+        t_logits: Teacher logits, shape ``[num_valid_tokens, vocab_size]``.
+        s_logits: Student logits, shape ``[num_valid_tokens, vocab_size]``.
+        chunk_size: Number of tokens per chunk.
+
+    Returns:
+        Per-token sum(P * log Q), shape ``[num_valid_tokens]``.
+    """
+    num_tokens = t_logits.shape[0]
+    kl_parts: list[torch.Tensor] = []
+    for start in range(0, num_tokens, chunk_size):
+        end = min(start + chunk_size, num_tokens)
+        t_chunk = t_logits[start:end]
+        s_chunk = s_logits[start:end]
+        teacher_prob = F.softmax(t_chunk, dim=-1, dtype=torch.float32)
+        student_logprob = F.log_softmax(s_chunk, dim=-1, dtype=torch.float32)
+        inf_mask = torch.isinf(s_chunk)
+        kl_parts.append(torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0).sum(-1))
+    return torch.cat(kl_parts, dim=0)
+
+
 class KDLoss(nn.Module):
     """Forward KL divergence loss for knowledge distillation.

@@ -108,6 +139,10 @@ class KDLoss(nn.Module):
         tp_group: Explicit TP process group. When ``None`` (default) the group is inferred from
             the DTensor placement of ``student_logits``, or the non-TP path is used for plain
             tensors.
+        chunk_size: When positive, valid tokens are processed in chunks of this size to avoid
+            materializing the full ``[num_valid_tokens, vocab_size]`` probability matrix in fp32.
+            Reduces peak memory at the cost of slightly more kernel launches. ``0`` (default)
+            disables chunking. Ignored when using the TP path.
     """

     def __init__(
@@ -116,12 +151,14 @@ def __init__(
         temperature: float = 1.0,
         fp32_upcast: bool = True,
         tp_group: Optional[torch.distributed.ProcessGroup] = None,
+        chunk_size: int = 0,
     ):
         super().__init__()
         self.ignore_index = ignore_index
         self.temperature = temperature
         self.fp32_upcast = fp32_upcast
         self.tp_group = tp_group
+        self.chunk_size = chunk_size

     def forward(
         self,
@@ -191,12 +228,11 @@ def forward(
         # Compute per-token negative cross-entropy: sum(P * log Q).
         if tp_group is not None:
             kl_per_token = _kl_forward_tp(t_logits, s_logits, tp_group)
+        elif self.chunk_size > 0:
+            kl_per_token = _kl_forward_chunked(t_logits, s_logits, self.chunk_size)
         else:
             teacher_prob = F.softmax(t_logits, dim=-1, dtype=torch.float32)
             student_logprob = F.log_softmax(s_logits, dim=-1, dtype=torch.float32)
-            # mask out infinities originating *only* from student logits
-            # (teacher logits infs are extremely rare and do not
-            # affect gradients w.r.t. student parameters).
             inf_mask = torch.isinf(s_logits)
             kl_per_token = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0).sum(-1).view(-1)
