Merge branch 'dkorzekwa/any_model' into dkorzekwa/any_model_mbridge_distillation
2 parents 6b2afe5 + f8f36a5 commit 5be52e7

File tree: 8 files changed, +1724 −2 lines changed
Lines changed: 123 additions & 0 deletions
# Knowledge Distillation with NeMo AutoModel

This guide shows how to run knowledge distillation on Puzzletron-compressed AnyModel (heterogeneous) checkpoints using **NeMo AutoModel**. AutoModel enables efficient training of any HuggingFace model with a unified API; here we extend it to load heterogeneous checkpoints and use a TP-friendly KD loss.

## Overview

1. **AutoModel + AnyModel**: We monkey-patch NeMo AutoModel so `from_pretrained(..., anymodel_descriptor=..., block_configs_path=...)` can load heterogeneous checkpoints. The patch uses ModelOpt's `ModelDescriptorFactory` and `deci_x_patcher` to apply per-layer configs during model init.
2. **Custom KD recipe**: For distillation we use a custom recipe (`recipe.py`) that adds pipeline-parallel (PP) support, better logging, and a TP-friendly KD loss. Pretraining is unchanged and uses AutoModel's built-in recipe. Once the AutoModel repo gains these features, the custom recipe can be dropped in favor of the upstream KD recipe.
3. **KD loss** (`loss.py`): A TP-aware KD loss on precomputed logits only; CE is computed separately and mixed in via `kd_ratio`.
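The monkey-patching idea in item 1 can be sketched as follows. This is a minimal, self-contained illustration: `AutoModelStub` and the dict it returns are stand-ins, not the real NeMo AutoModel or ModelOpt API; the real patch hands the two extra arguments to ModelOpt's descriptor machinery instead of recording them.

```python
import functools

# Stand-in for the AutoModel class; the real patch targets NeMoAutoModelForCausalLM.
class AutoModelStub:
    @classmethod
    def from_pretrained(cls, path, **kwargs):
        # Pretend to load a model; return a dict so the example is self-contained.
        return {"path": path, **kwargs}

_orig_from_pretrained = AutoModelStub.from_pretrained.__func__

@functools.wraps(_orig_from_pretrained)
def _patched_from_pretrained(cls, path, *, anymodel_descriptor=None,
                             block_configs_path=None, **kwargs):
    # The real patch passes these two arguments to ModelOpt's
    # ModelDescriptorFactory / deci_x_patcher so per-layer configs are applied
    # during init; here we just record them to show the signature extension.
    model = _orig_from_pretrained(cls, path, **kwargs)
    model["anymodel_descriptor"] = anymodel_descriptor
    model["block_configs_path"] = block_configs_path
    return model

AutoModelStub.from_pretrained = classmethod(_patched_from_pretrained)

model = AutoModelStub.from_pretrained("ckpt_dir/", anymodel_descriptor="llama")
print(model["anymodel_descriptor"])  # llama
```

The key point is ordering: the patch must replace the classmethod before any `from_pretrained` call, which is why the run entrypoint applies it first.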
**Supported parallelisms**

- **FSDP**: fully supported.
- **Pipeline parallelism (PP)**: supported for most models; exceptions are models whose layer naming does not follow the usual HuggingFace convention.
- **Tensor parallelism (TP) / sequence parallelism (SP)**: mostly supported; a known exception is GPT-OSS due to sink tokens (AutoModel has the same limitation; it is not specific to AnyModel).
- **Context parallelism (CP)**: supported for all models tested.
- **Expert parallelism (EP)**: not supported. AutoModel relies on custom (non-HuggingFace) model implementations for EP, which conflicts with the goal of supporting any HF model.
## Setup

**Requirements**

- NeMo AutoModel (install from source or use a container that provides it).
- ModelOpt installed (`pip install nvidia-modelopt` or install from the Model-Optimizer repo).
- For KD: this example's `recipe.py`, `loss.py`, and `patch_automodel.py` (the run entrypoint always applies the patch before loading models).
**Environment**

Set `PYTHONPATH` so that the Model-Optimizer root is on the path (for ModelOpt and, if you run this example as a module, for `automodel_distillation`):

```bash
export PYTHONPATH="/path/to/Model-Optimizer:${PYTHONPATH}"
```
If you use a NeMo AutoModel container, ensure the AutoModel package is installed (e.g. clone AutoModel and `pip install -e .`). Upgrade HuggingFace Transformers if you hit version incompatibilities:

```bash
python -m pip install -e /path/to/AutoModel
python -m pip install -U omegaconf fire transformers
```
## Configuration

- **pretrain.yaml** – Pretrain/finetune on an AnyModel checkpoint. Set `model.pretrained_model_name_or_path` and `model.anymodel_descriptor` (e.g. `gpt_oss_20b`, `llama`, `qwen2`, `qwen3`). Optional: `model.block_configs_path`; if omitted, block configs are auto-detected from `<checkpoint_dir>/block_configs.json`.
- **kd.yaml** – Knowledge distillation. Set `model.pretrained_model_name_or_path` and `model.anymodel_descriptor` for the student, and `teacher_model.pretrained_model_name_or_path` and `teacher_model.anymodel_descriptor` for the teacher.

Paths and descriptors can be overridden from the command line (see below).
## Run

**Apply the patch and run KD**

Before loading models, the run entrypoint calls `apply_patch()` so that `from_pretrained` accepts `anymodel_descriptor` and `block_configs_path`. Then it loads the config and runs the chosen recipe.
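A minimal sketch of what such an entrypoint can look like (this is not the actual `run.py`; the commented-out `apply_patch()` call and the recipe dispatch are placeholders showing only the ordering and the CLI shape):

```python
import argparse

def main(argv=None):
    parser = argparse.ArgumentParser(description="KD / pretrain entrypoint (sketch)")
    parser.add_argument("--mode", choices=["pretrain", "kd"], required=True)
    parser.add_argument("-c", "--config", required=True)
    # Unrecognized trailing args (e.g. model.anymodel_descriptor=llama) become
    # dotlist config overrides.
    args, overrides = parser.parse_known_args(argv)

    # Step 1: patch AutoModel BEFORE any model is constructed, so that
    # from_pretrained accepts anymodel_descriptor / block_configs_path:
    #   from patch_automodel import apply_patch
    #   apply_patch()

    # Step 2: load the YAML config, apply the overrides, and dispatch to the
    # built-in pretrain recipe or the custom KD recipe (recipe.py).
    return args.mode, args.config, overrides

mode, config, overrides = main(
    ["--mode", "kd", "-c", "kd.yaml", "model.anymodel_descriptor=llama"]
)
```

`parse_known_args` is what lets the dotlist overrides pass through the parser untouched, mirroring the `key=value` overrides shown in the commands below.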
Run from the **automodel_distillation** directory so that `run.py` can import `patch_automodel` and `recipe`:

```bash
cd /path/to/Model-Optimizer/examples/puzzletron/automodel_distillation
torchrun --nproc_per_node=2 \
  -m run \
  --mode kd \
  -c kd.yaml
```
Override config values (e.g. paths and descriptor) on the command line:

```bash
torchrun --nproc_per_node=2 \
  -m run \
  --mode kd \
  -c kd.yaml \
  model.pretrained_model_name_or_path=/path/to/student \
  model.anymodel_descriptor=gpt_oss_20b \
  teacher_model.pretrained_model_name_or_path=/path/to/teacher \
  teacher_model.anymodel_descriptor=gpt_oss_20b
```
**Pretrain (uses AutoModel's built-in recipe)**

```bash
torchrun --nproc_per_node=2 \
  -m run \
  --mode pretrain \
  -c pretrain.yaml \
  model.pretrained_model_name_or_path=/path/to/checkpoint \
  model.anymodel_descriptor=gpt_oss_20b
```
**Note:** If you run from a different layout (e.g. from the Model-Optimizer repo root or under another package name), set `PYTHONPATH` to include this directory so `run` can import `patch_automodel` and `recipe`, and ensure the config's `kd_loss_fn._target_` (e.g. `loss.KDLoss`) resolves to the correct module.
## Example: Running on a cluster

Below is an example job setup: start from the NeMo AutoModel container, clone AutoModel `main`, install it and upgrade Transformers, then run KD from a directory that contains your config and run script (e.g. a copy of this example or the RealAnyModel layout).
89+
90+
```bash
91+
# Submit interactive job (example with your cluster’s submit_job)
92+
submit_job --partition interactive --time 2.0 \
93+
--image nvcr.io/nvidia/nemo-automodel:25.11.00 \
94+
--mounts "/path/to/AutoModel/:/opt/Automodel/,/lustre:/lustre" \
95+
--interactive --gpu 2 --skip_image_check --email_mode=never \
96+
--command='bash'
97+
98+
# Inside the container
99+
source /opt/venv/bin/activate
100+
cd /opt/Automodel/
101+
python -m pip install -e .
102+
python -m pip install -U omegaconf fire transformers
103+
python -m pip uninstall nvidia-modelopt
104+
cd /path/to/Model-Optimizer
105+
python -m pip install -e .
106+
107+
# Run KD (from your project dir that has run.py, kd.yaml, patch_automodel, loss, recipe)
108+
cd ./examples/puzzletron/automodel_distillation/
109+
torchrun --nproc_per_node 2 -m run --mode kd -c kd.yaml 2>&1 | tee logs
110+
```
Use your own paths for mounts, checkpoint dirs, and config overrides as needed.

## Files in this example

| File | Purpose |
|------|---------|
| `patch_automodel.py` | Monkey-patch so `from_pretrained` accepts `anymodel_descriptor` and `block_configs_path`; uses ModelOpt's `deci_x_patcher`. |
| `loss.py` | `KDLoss`: TP-aware KD on precomputed logits (CE is mixed in via `kd_ratio` in the recipe). |
| `recipe.py` | Custom KD recipe (PP support, logging, TP-friendly KD). |
| `run.py` | Entrypoint: applies the patch, then runs pretrain or KD using the config. |
| `pretrain.yaml` | Pretrain config (no hardcoded paths; override on CLI). |
| `kd.yaml` | KD config (no hardcoded paths; override on CLI). |
Lines changed: 133 additions & 0 deletions
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Knowledge distillation: student and teacher are AnyModel checkpoints.
# Requires apply_patch() from patch_automodel. Set model and teacher_model paths and descriptors.
# anymodel_descriptor must match a ModelOpt ModelDescriptorFactory name (e.g. gpt_oss_20b, llama, qwen2, qwen3).
#
# KD loss (kd_loss_fn._target_): use loss.KDLoss for TP-aware KD on precomputed logits.
# CE is computed by loss_fn and mixed with KD via kd_ratio in the recipe.
# If running under a different package name, use that module path (e.g. automodel_distillation.loss.KDLoss).
#
# To run:
#   torchrun --nproc_per_node <N> -m automodel_distillation.run --mode kd -c kd.yaml
# Override: model.pretrained_model_name_or_path=/path/to/student model.anymodel_descriptor=llama ...

step_scheduler:
  global_batch_size: 128
  local_batch_size: 4
  ckpt_every_steps: 200
  val_every_steps: 100
  num_epochs: 2

dist_env:
  backend: nccl
  timeout_minutes: 5

rng:
  _target_: nemo_automodel.components.training.rng.StatefulRNG
  seed: 1111
  ranked: true

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: ./heterogeneous_ckpts/meta-llama-Llama-3.1-8B-Instruct/  # student checkpoint dir
  anymodel_descriptor: llama  # e.g. gpt_oss_20b, llama, qwen2, qwen3
  force_hf: true
  torch_dtype: bf16
  trust_remote_code: true

teacher_model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: ./heterogeneous_ckpts/meta-llama-Llama-3.1-8B-Instruct-teacher/  # teacher checkpoint dir
  anymodel_descriptor: llama  # same format as model.anymodel_descriptor
  force_hf: true
  torch_dtype: bf16
  trust_remote_code: true

checkpoint:
  enabled: true
  checkpoint_dir: checkpoints/
  model_save_format: safetensors
  save_consolidated: false

distributed:
  dp_size: none
  tp_size: 2
  cp_size: 1
  ep_size: 1
  sequence_parallel: false
  pp_size: 1
  pipeline:
    pp_schedule: interleaved1f1b
    pp_microbatch_size: 1
    scale_grads_in_schedule: false
    round_virtual_stages_to_pp_multiple: up
    dtype: bf16

distributed_config:
  _target_: nemo_automodel.components.distributed.config.FSDP2Config
  activation_checkpointing: false

compile_config:
  enabled: true

packed_sequence:
  packed_sequence_size: 1024
  split_across_pack: false

loss_fn:
  _target_: nemo_automodel.components.loss.te_parallel_ce.TEParallelCrossEntropy

# 0 = pure CE (better to run pretrain instead of loading a teacher and not using it)
# 1 = pure KD (common practice for puzzletron distillation)
kd_ratio: 1.0

kd_loss_fn:
  _target_: loss.KDLoss
  ignore_index: -100
  temperature: 1.0
  fp32_upcast: true

optimizer:
  _target_: torch.optim.Adam
  betas: [0.9, 0.999]
  eps: 1.0e-8
  lr: 1.0e-5
  weight_decay: 0

dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: train

dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater
  shuffle: false

validation_dataset:
  _target_: nemo_automodel.components.datasets.llm.squad.make_squad_dataset
  dataset_name: rajpurkar/squad
  split: validation

validation_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  collate_fn: nemo_automodel.components.datasets.utils.default_collater

# wandb:
#   project: <your_project>
#   entity: <your_entity>
#   name: <your_run_name>
#   save_dir: <your_save_dir>
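The `kd_ratio` blend described in the config comments (0 = pure CE, 1 = pure KD) amounts to a simple convex combination of the two losses. A minimal sketch, where the helper name `mix_losses` is illustrative and not part of the recipe API:

```python
def mix_losses(ce_loss: float, kd_loss: float, kd_ratio: float) -> float:
    """Blend CE and KD losses: kd_ratio=0 -> pure CE, kd_ratio=1 -> pure KD."""
    assert 0.0 <= kd_ratio <= 1.0
    return (1.0 - kd_ratio) * ce_loss + kd_ratio * kd_loss

print(mix_losses(2.5, 4.0, 1.0))   # 4.0 (pure KD, as configured above)
print(mix_losses(2.5, 4.0, 0.25))  # 2.875
```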
Lines changed: 141 additions & 0 deletions
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor import DTensor, Shard


def _infer_tp_group_from_dtensor(tensor: "torch.Tensor"):
    """Return device_mesh process group if tensor is a DTensor sharded on vocab (logits last dim, lm_head dim 0)."""
    if not isinstance(tensor, DTensor):
        return None
    # Vocab sharding: Shard on last dim (logits) or Shard(0) (weight matrix)
    has_shard = any(isinstance(p, Shard) for p in tensor.placements)
    if not has_shard:
        return None
    return tensor.device_mesh.get_group()


def _kl_forward_tp(
    t_logits: torch.Tensor,
    s_logits: torch.Tensor,
    tp_group,
) -> torch.Tensor:
    """
    Compute KL (negative cross entropy sum(P*log Q)) with tensor parallelism.
    Returns per-token negative cross entropy (sum over vocab).
    """
    teacher_max = t_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(teacher_max, op=dist.ReduceOp.MAX, group=tp_group)
    output_teacher = t_logits - teacher_max

    denom_teacher = torch.exp(output_teacher).sum(dim=-1, keepdim=True)
    dist.all_reduce(denom_teacher, op=dist.ReduceOp.SUM, group=tp_group)
    teacher_prob = torch.exp(output_teacher - torch.log(denom_teacher.clamp(min=1e-12)))

    student_max = s_logits.max(dim=-1, keepdim=True).values
    dist.all_reduce(student_max, op=dist.ReduceOp.MAX, group=tp_group)
    output_student = s_logits - student_max.detach()

    denom_student = torch.exp(output_student).sum(dim=-1, keepdim=True)
    dist.all_reduce(denom_student, op=dist.ReduceOp.SUM, group=tp_group)
    student_log_prob = output_student - torch.log(denom_student.clamp(min=1e-12))

    term = teacher_prob * student_log_prob
    inf_mask = torch.isinf(s_logits)
    term = torch.masked_fill(term, inf_mask, 0.0)
    ce_local = term.sum(dim=-1)
    dist.all_reduce(ce_local, op=dist.ReduceOp.SUM, group=tp_group)
    return ce_local.view(-1)


class KDLoss(nn.Module):
    """TP-aware KD on precomputed logits."""

    def __init__(
        self,
        ignore_index: int = -100,
        temperature: float = 1.0,
        fp32_upcast: bool = True,
        tp_group=None,
        **kwargs,
    ):
        super().__init__()
        self.ignore_index = ignore_index
        self.temperature = temperature
        self.fp32_upcast = fp32_upcast
        self.tp_group = tp_group

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        num_batch_labels: int | None = None,
    ) -> torch.Tensor:
        valid_mask = (labels != self.ignore_index).view(-1)
        if valid_mask.sum() == 0:
            return student_logits.new_tensor(0.0)

        if student_logits.ndim > 2:
            student_logits = student_logits.view(-1, student_logits.shape[-1])
        if teacher_logits.ndim > 2:
            teacher_logits = teacher_logits.view(-1, teacher_logits.shape[-1])
        if labels.ndim > 1:
            labels = labels.view(-1)

        tp_group = self.tp_group
        if isinstance(student_logits, DTensor) and tp_group is None:
            tp_group = _infer_tp_group_from_dtensor(student_logits)

        if tp_group is not None:
            if isinstance(student_logits, DTensor):
                student_logits = student_logits.to_local()
            if isinstance(teacher_logits, DTensor):
                teacher_logits = teacher_logits.to_local()
        else:
            if isinstance(student_logits, DTensor):
                student_logits = student_logits.full_tensor()
            if isinstance(teacher_logits, DTensor):
                teacher_logits = teacher_logits.full_tensor()

        t_logits = teacher_logits[valid_mask]
        s_logits = student_logits[valid_mask]

        if self.fp32_upcast:
            t_logits = t_logits.float()
            s_logits = s_logits.float()
        if self.temperature != 1.0:
            t_logits = t_logits.mul(1.0 / self.temperature)
            s_logits = s_logits.mul(1.0 / self.temperature)

        if tp_group is not None:
            kl_per_token = _kl_forward_tp(t_logits, s_logits, tp_group)
        else:
            teacher_prob = F.softmax(t_logits, dim=-1, dtype=torch.float32)
            student_logprob = F.log_softmax(s_logits, dim=-1, dtype=torch.float32)
            inf_mask = torch.isinf(s_logits)
            kl_per_token = (
                torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0.0).sum(-1).view(-1)
            )

        if self.temperature != 1.0:
            kl_per_token = kl_per_token * (self.temperature**2)

        if num_batch_labels is not None:
            return -torch.sum(kl_per_token) / num_batch_labels
        return -torch.mean(kl_per_token)
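For intuition, the non-TP branch of `KDLoss` (with temperature 1 and no ignored labels) reduces to the mean over tokens of the teacher/student cross entropy, minus sum of p_teacher * log q_student, which equals KL(p || q) plus the teacher entropy H(p). A quick standalone check of that identity, using the same softmax operations as the loss above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s_logits = torch.randn(4, 8)  # student logits: 4 tokens, vocab of 8
t_logits = torch.randn(4, 8)  # teacher logits

p = F.softmax(t_logits, dim=-1)          # teacher probabilities
log_q = F.log_softmax(s_logits, dim=-1)  # student log-probabilities
kd = -(p * log_q).sum(-1).mean()         # what KDLoss returns in this setting

# KL(p || q) = sum(p * log p) - sum(p * log q), so kd = KL + teacher entropy
kl = (p * (torch.log(p) - log_q)).sum(-1).mean()
entropy = -(p * torch.log(p)).sum(-1).mean()
print(torch.allclose(kd, kl + entropy, atol=1e-5))  # True
```

Since the teacher entropy is constant with respect to the student, minimizing this cross entropy is equivalent to minimizing the KL divergence from teacher to student.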
