
feat(recipes): add VLM knowledge distillation recipe with chunked KD loss #2205

Merged
HuiyingLi merged 3 commits into NVIDIA-NeMo:main from khazic:khazic/feat/vlm_kd on May 12, 2026

Conversation

@khazic (Contributor) commented May 11, 2026

What does this PR do ?

Adds a knowledge distillation recipe for vision-language models, plus a memory-efficient chunked path in the existing KD loss to keep peak fp32 memory bounded for large-vocab VLMs.

Closes #2195.

Changelog

  • feat(loss): add chunk_size parameter to KDLoss so that only one fp32 [chunk_size, vocab_size] probability matrix is materialized at a time (see the sketch after this list). Numerically identical to the unchunked path; the TP path is unchanged. New unit tests cover chunked vs. unchunked equivalence, temperature scaling, and num_batch_labels.
  • feat(recipes): add KnowledgeDistillationRecipeForVLM under nemo_automodel/recipes/vlm/kd.py. Mirrors the LLM KD recipe structure: frozen teacher NeMoAutoModelForImageTextToText, KD loss term, kd_ratio linear mix between CE and KD. Forwards multimodal inputs (pixel_values, image_grid_thw, etc.) to both teacher and student; eagerly frees intermediate activations; reports CE/KD sub-losses in validation metrics. Respects freeze_vision_tower on the student. Pipeline parallelism is not supported in this initial version.
  • docs(examples): add examples/vlm_kd/qwen3_5/qwen3_5_vl_4b_kd.yaml distilling Qwen3.5-9B → Qwen3.5-4B on the public mmoukouba/MedPix-VQA dataset. The config exercises chunked KD (chunk_size: 512), freezes the student's vision and audio towers, and uses FSDP2.
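
For illustration, here is a minimal sketch of the chunked forward-KL pattern the first changelog item describes. The function name, the pre-flattened [num_valid_tokens, vocab_size] layout, and the assumption that padding is already filtered out are all hypothetical; the actual KDLoss API may differ.

import torch
import torch.nn.functional as F

def chunked_kd_loss(student_logits: torch.Tensor,
                    teacher_logits: torch.Tensor,
                    temperature: float = 1.0,
                    chunk_size: int = 512) -> torch.Tensor:
    # Inputs: [num_valid_tokens, vocab_size] logits, padding already filtered out.
    n = student_logits.size(0)
    total = torch.zeros((), dtype=torch.float32, device=student_logits.device)
    for start in range(0, n, chunk_size):
        s = student_logits[start:start + chunk_size].float() / temperature
        t = teacher_logits[start:start + chunk_size].float() / temperature
        # Only one [chunk_size, vocab_size] fp32 prob/log-prob pair is live per
        # iteration; the previous chunk's intermediates are freed on rebinding.
        log_p_student = F.log_softmax(s, dim=-1)
        p_teacher = F.softmax(t, dim=-1)
        # Teacher-to-student cross-entropy, i.e. H(P_t) + KL(P_t || P_s).
        total = total + -(p_teacher * log_p_student).sum()
    return total / max(n, 1)

Because each per-chunk sum accumulates into a single fp32 scalar before the final division, the result matches the unchunked mean up to floating-point summation order.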

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

Additional Information

khazic added 3 commits May 11, 2026 16:39
Add a ``chunk_size`` knob to ``KDLoss`` that processes valid tokens in
chunks when computing forward KL. Only one ``[chunk_size, vocab_size]``
fp32 probability/log-prob tensor is materialized at a time, which keeps
peak memory bounded for large-vocab VLMs while remaining numerically
identical to the unchunked path (verified by new unit tests).

The TP path is unaffected; chunking is opt-in via ``chunk_size > 0``.

Signed-off-by: khazic <khazzz1c@gmail.com>
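
Hypothetical opt-in usage of the new knob; the import and the exact KDLoss constructor arguments are assumed for illustration, not taken from this PR:

# KDLoss import omitted; the module path is not shown on this PR page.
# chunk_size > 0 opts into chunking, the default keeps the original path.
kd_loss_fn = KDLoss(temperature=1.0, chunk_size=512)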
Add ``KnowledgeDistillationRecipeForVLM`` under ``recipes/vlm/kd.py``.
The recipe extends ``FinetuneRecipeForVLM`` with a frozen teacher
``NeMoAutoModelForImageTextToText``, a KD loss term, and a ``kd_ratio``
linear mix between CE and KD losses.

The training loop forwards multimodal inputs (pixel_values,
image_grid_thw, etc.) to both teacher and student, frees intermediate
activations eagerly to keep peak memory low, and reports CE/KD
sub-losses alongside the combined loss in validation metrics. Pipeline
parallelism is not supported.

Signed-off-by: khazic <khazzz1c@gmail.com>
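
A minimal sketch of the CE/KD mix described above (hypothetical names, not the recipe's exact code):

def mix_losses(ce_loss: float, kd_loss: float, kd_ratio: float = 0.5) -> float:
    # Linear interpolation between plain cross-entropy and the KD term.
    return (1.0 - kd_ratio) * ce_loss + kd_ratio * kd_loss

# e.g. step 75 of the sanity run reported below:
# 0.5 * 1.7002 + 0.5 * 2.1826 = 1.9414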
Add an example YAML that distills Qwen3.5-9B (teacher) into Qwen3.5-4B
(student) on the public ``mmoukouba/MedPix-VQA`` medical-image VQA
dataset. The config exercises the chunked KD loss
(``kd_loss_fn.chunk_size: 512``), freezes the student's vision and
audio towers, and uses FSDP2.

Signed-off-by: khazic <khazzz1c@gmail.com>
copy-pr-bot (Bot) commented May 11, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@khazic khazic changed the title from "[recipes] feat: add VLM knowledge distillation recipe with chunked KD loss" to "feat(recipes): add VLM knowledge distillation recipe with chunked KD loss" on May 11, 2026
@HuiyingLi (Contributor) commented:
/ok to test ce2e186

@khazic (Contributor, Author) commented May 11, 2026

Full 300-step sanity run on the example config — Qwen3.5-9B (teacher) → Qwen3.5-4B (student) on mmoukouba/MedPix-VQA, FSDP2 on 8×H100, chunk_size=512, kd_ratio=0.5, temperature=1.0. Total wall time ≈ 71 min.

Per-step totals match (1 - kd_ratio)*ce + kd_ratio*kd exactly (e.g. step 75: 0.5*1.7002 + 0.5*2.1826 = 1.9414 ✓). No NaNs, no divergence; eagerly freeing intermediate activations keeps peak memory bounded despite variance in image sizes.

window   steps     ce mean   kd mean   total mean   peak mem
early    75–99     1.69      2.03      1.86         33.5 GiB
mid      126–149   1.62      1.99      1.80         45.6 GiB
final    276–299   1.29      1.92      1.60         36.0 GiB

CE loss drops ~24% (1.69 → 1.29) end-to-end, while KD loss drops only ~5% (2.03 → 1.92). The asymmetry is expected: computed as the teacher→student cross-entropy, the KD term equals H(P_teacher) + KL(P_teacher‖P_student), and with a frozen teacher the entropy term is a constant offset that never decays, so the reducible signal is concentrated in the CE component.
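
Writing out the identity makes the offset explicit (a standard decomposition, added here for reference):

\mathcal{L}_{\mathrm{KD}} = -\sum_i P_{\mathrm{teacher}}(i)\,\log P_{\mathrm{student}}(i) = H(P_{\mathrm{teacher}}) + \mathrm{KL}\big(P_{\mathrm{teacher}} \,\Vert\, P_{\mathrm{student}}\big)

Only the KL term can shrink toward zero as the student matches the teacher; H(P_teacher) is fixed because the teacher is frozen.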

Raw log excerpt (early / mid / final windows)
step  75 | loss 1.9414 | ce 1.7002 | kd 2.1826 | lr 4.65e-05 | mem 27.6 GiB | tps 1295
step  80 | loss 1.6199 | ce 1.5234 | kd 1.7163 | lr 4.58e-05 | mem 33.5 GiB | tps 1176
step  90 | loss 1.7475 | ce 1.4456 | kd 2.0494 | lr 4.40e-05 | mem 28.5 GiB | tps  956
step  99 | loss 1.9515 | ce 1.7086 | kd 2.1944 | lr 4.22e-05 | mem 27.7 GiB | tps 1078
step 126 | loss 1.7391 | ce 1.4892 | kd 1.9889 | lr 3.58e-05 | mem 27.4 GiB | tps  962
step 131 | loss 1.9853 | ce 1.6907 | kd 2.2799 | lr 3.45e-05 | mem 45.6 GiB | tps 1182
step 140 | loss 1.4328 | ce 1.1651 | kd 1.7005 | lr 3.21e-05 | mem 36.9 GiB | tps  953
step 149 | loss 1.7113 | ce 1.5619 | kd 1.8607 | lr 2.95e-05 | mem 32.3 GiB | tps 1040
step 280 | loss 1.4842 | ce 1.0968 | kd 1.8717 | lr 1.10e-06 | mem 26.5 GiB | tps 1182
step 285 | loss 1.4599 | ce 1.0999 | kd 1.8199 | lr 8.28e-07 | mem 27.4 GiB | tps  955
step 294 | loss 1.4891 | ce 1.1008 | kd 1.8775 | lr 5.42e-07 | mem 30.9 GiB | tps  919
step 298 | loss 1.3651 | ce 1.0605 | kd 1.6697 | lr 5.02e-07 | mem 30.7 GiB | tps 1120
step 299 | loss 1.4761 | ce 1.3185 | kd 1.6337 | lr 5.00e-07 | mem 33.6 GiB | tps  996

@khazic (Contributor, Author) commented May 11, 2026

Clarification on scope: this PR intentionally implements the initial/base VLM KD path only, targeting dense models with the current example configuration (tp_size=1, ep_size=1, pp_size=1).

Support and validation for tensor parallelism, expert parallelism, pipeline parallelism, and MoE-specific behavior will be handled in a follow-up PR. This PR is meant to establish the basic dense VLM KD recipe and chunked KD loss path first.

@HuiyingLi (Contributor) commented:

/claude review

Review thread (diff context):

    raise ValueError("Both student and teacher must specify pretrained_model_name_or_path")
    student_tokenizer = NeMoAutoTokenizer.from_pretrained(student_name, trust_remote_code=trust_remote_code)
    teacher_tokenizer = NeMoAutoTokenizer.from_pretrained(teacher_name, trust_remote_code=trust_remote_code)
    if student_tokenizer.vocab_size != teacher_tokenizer.vocab_size:
Contributor:

Nit: The LLM KD recipe uses .detach().clone() here (recipes/llm/kd.py:399), which ensures the teacher output's underlying storage can be fully freed. Without .clone(), teacher_logits shares storage with teacher_out.logits, so deleting teacher_out may not release all associated memory. In practice this is likely fine (logits is usually a separate allocation), but for consistency with the LLM KD recipe and the PR's stated goal of "eagerly freeing intermediate activations":

Suggested change
-teacher_logits = getattr(teacher_out, "logits", teacher_out).detach()
+teacher_logits = getattr(teacher_out, "logits", teacher_out).detach().clone()
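
A quick illustration of the storage-sharing point the nit raises (generic PyTorch, not code from this PR):

import torch

logits = torch.randn(4, 8)       # stands in for teacher_out.logits
alias = logits.detach()          # detach() returns a view sharing the same storage
copy = logits.detach().clone()   # clone() copies into independent storage
del logits                       # `alias` still keeps the original storage alive; `copy` does not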

@HuiyingLi (Contributor) left a comment:

Thank you! Looking forward to follow-up PRs.

@HuiyingLi HuiyingLi merged commit 61a5116 into NVIDIA-NeMo:main May 12, 2026
64 of 65 checks passed

Linked issue: Knowledge distillation recipe for VLMs (#2195)