feat(recipes): add VLM knowledge distillation recipe with chunked KD loss (#2205)
Conversation
Add a ``chunk_size`` knob to ``KDLoss`` that processes valid tokens in chunks when computing forward KL. Only one ``[chunk_size, vocab_size]`` fp32 probability/log-prob tensor is materialized at a time, which keeps peak memory bounded for large-vocab VLMs while remaining numerically identical to the unchunked path (verified by new unit tests). The TP path is unaffected; chunking is opt-in via ``chunk_size > 0``. Signed-off-by: khazic <khazzz1c@gmail.com>
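A minimal numpy sketch of the chunked forward-KL idea (illustrative only — the actual ``KDLoss`` operates on torch tensors and upcasts to fp32; function and variable names here are hypothetical):

```python
import numpy as np

def forward_kl_chunked(student_logits, teacher_logits, chunk_size=0):
    """Forward KL(teacher || student) summed over tokens.

    Sketch of the chunked path: with chunk_size > 0, only one
    [chunk_size, vocab_size] slice of log-probs/probs exists at a time,
    bounding peak memory. chunk_size == 0 is the unchunked path.
    """
    def log_softmax(x):
        z = x - x.max(axis=-1, keepdims=True)
        return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

    n = student_logits.shape[0]
    step = chunk_size if chunk_size > 0 else n
    total = 0.0
    for i in range(0, n, step):
        t_logp = log_softmax(teacher_logits[i:i + step])
        s_logp = log_softmax(student_logits[i:i + step])
        # KL(t || s) = sum_v p_t(v) * (log p_t(v) - log p_s(v))
        total += (np.exp(t_logp) * (t_logp - s_logp)).sum()
    return total
```

Because softmax is row-wise, summing per-chunk contributions gives the same result as the unchunked computation up to floating-point summation order, which is why the chunked path can be numerically equivalent.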
Add ``KnowledgeDistillationRecipeForVLM`` under ``recipes/vlm/kd.py``. The recipe extends ``FinetuneRecipeForVLM`` with a frozen teacher ``NeMoAutoModelForImageTextToText``, a KD loss term, and a ``kd_ratio`` linear mix between CE and KD losses. The training loop forwards multimodal inputs (pixel_values, image_grid_thw, etc.) to both teacher and student, frees intermediate activations eagerly to keep peak memory low, and reports CE/KD sub-losses alongside the combined loss in validation metrics. Pipeline parallelism is not supported. Signed-off-by: khazic <khazzz1c@gmail.com>
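The ``kd_ratio`` linear mix between CE and KD can be sketched as below; the exact weighting convention (which term ``kd_ratio`` scales) is an assumption, not taken from the recipe's code:

```python
def mix_losses(ce_loss, kd_loss, kd_ratio):
    # Hypothetical convention: kd_ratio weights the KD term and
    # (1 - kd_ratio) the CE term, so kd_ratio == 0 reduces to plain
    # fine-tuning and kd_ratio == 1 to pure distillation.
    return (1.0 - kd_ratio) * ce_loss + kd_ratio * kd_loss
```

Reporting the CE and KD sub-losses separately, as the recipe does in validation metrics, makes it possible to see which term dominates the combined value at a given ``kd_ratio``.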
Add an example YAML that distills Qwen3.5-9B (teacher) into Qwen3.5-4B (student) on the public ``mmoukouba/MedPix-VQA`` medical-image VQA dataset. The config exercises the chunked KD loss (``kd_loss_fn.chunk_size: 512``), freezes the student's vision and audio towers, and uses FSDP2. Signed-off-by: khazic <khazzz1c@gmail.com>
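A hypothetical fragment of such a config — key names and values here are illustrative placeholders, not the actual schema of the shipped example YAML:

```yaml
# Illustrative only; consult the example YAML in the PR for real key names.
kd_loss_fn:
  chunk_size: 512       # chunked forward KL; 0 would disable chunking
kd_ratio: 0.5           # linear CE/KD mix weight (placeholder value)
freeze_vision_tower: true
```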
/ok to test ce2e186
Full 300-step sanity run on the example config — Qwen3.5-9B (teacher) → Qwen3.5-4B (student) on ``mmoukouba/MedPix-VQA``. Per-step totals match.
CE loss drops ~24% (1.69 → 1.29) end-to-end; KD loss drops ~5% (2.03 → 1.92). The asymmetry is expected with forward KL.
Raw log excerpt (early / mid / final windows)
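As a quick check, the quoted relative drops follow from the reported endpoints:

```python
# Relative drops implied by the 300-step run's reported loss endpoints.
ce_drop = (1.69 - 1.29) / 1.69   # CE: 1.69 -> 1.29
kd_drop = (2.03 - 1.92) / 2.03   # KD: 2.03 -> 1.92
print(f"CE: {ce_drop:.1%}, KD: {kd_drop:.1%}")  # CE: 23.7%, KD: 5.4%
```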
Clarification on scope: this PR intentionally implements the initial/base VLM KD path only, targeting dense models with the current example configuration. Support and validation for tensor parallelism, expert parallelism, pipeline parallelism, and MoE-specific behavior will be handled in a follow-up PR. This PR is meant to establish the basic dense VLM KD recipe and chunked KD loss path first.
/claude review |
```python
raise ValueError("Both student and teacher must specify pretrained_model_name_or_path")
student_tokenizer = NeMoAutoTokenizer.from_pretrained(student_name, trust_remote_code=trust_remote_code)
teacher_tokenizer = NeMoAutoTokenizer.from_pretrained(teacher_name, trust_remote_code=trust_remote_code)
if student_tokenizer.vocab_size != teacher_tokenizer.vocab_size:
```
Nit: The LLM KD recipe uses .detach().clone() here (recipes/llm/kd.py:399), which ensures the teacher output's underlying storage can be fully freed. Without .clone(), teacher_logits shares storage with teacher_out.logits, so deleting teacher_out may not release all associated memory. In practice this is likely fine (logits is usually a separate allocation), but for consistency with the LLM KD recipe and the PR's stated goal of "eagerly freeing intermediate activations":
Suggested change:
```python
teacher_logits = getattr(teacher_out, "logits", teacher_out).detach().clone()
```
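The storage-sharing point can be illustrated with a numpy analogy — a view shares its parent's buffer much like ``detach()`` shares a tensor's storage, while ``.copy()`` (like ``.clone()``) allocates independently. This is an analogy, not the recipe's code:

```python
import numpy as np

# Analogy: numpy view ~ torch detach() (shares storage),
#          numpy copy ~ torch detach().clone() (independent allocation).
big = np.zeros(1_000_000)
view = big[:10]          # shares big's buffer; keeps it reachable
copy = big[:10].copy()   # owns its own small buffer

assert np.shares_memory(view, big)
assert not np.shares_memory(copy, big)
```

After ``del big``, the large buffer stays alive as long as ``view`` exists, while ``copy`` lets it be freed — the same reason ``.detach().clone()`` guarantees the teacher output's storage can be released.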
HuiyingLi
left a comment
Thank you! Looking forward to followup prs.
What does this PR do?
Adds a knowledge distillation recipe for vision-language models, plus a memory-efficient chunked path in the existing KD loss to keep peak fp32 memory bounded for large-vocab VLMs.
Closes #2195.
Changelog
- feat(loss): add ``chunk_size`` parameter to ``KDLoss`` that materializes the fp32 ``[chunk_size, vocab_size]`` probability matrix one chunk at a time. Numerically identical to the unchunked path; TP path unchanged. New unit tests cover chunked vs. unchunked equivalence, temperature scaling, and ``num_batch_labels``.
- feat(recipes): add ``KnowledgeDistillationRecipeForVLM`` under ``nemo_automodel/recipes/vlm/kd.py``. Mirrors the LLM KD recipe structure: frozen teacher ``NeMoAutoModelForImageTextToText``, KD loss term, ``kd_ratio`` linear mix between CE and KD. Forwards multimodal inputs (``pixel_values``, ``image_grid_thw``, etc.) to both teacher and student; eagerly frees intermediate activations; reports CE/KD sub-losses in validation metrics. Respects ``freeze_vision_tower`` on the student. Pipeline parallelism is not supported in this initial version.
- docs(examples): add ``examples/vlm_kd/qwen3_5/qwen3_5_vl_4b_kd.yaml`` distilling Qwen3.5-9B → Qwen3.5-4B on the public ``mmoukouba/MedPix-VQA`` dataset. Exercises chunked KD (``chunk_size: 512``), freezes vision and audio towers on the student, FSDP2.

Before your PR is "Ready for review"
Pre checks:
Additional Information
``pytest tests/unit_tests/loss/test_kd_loss.py`` — 28 passed locally.