6 changes: 3 additions & 3 deletions examples/puzzletron/mbridge_distillation/README.md
@@ -67,11 +67,11 @@ Run distillation directly from HuggingFace checkpoints (student and teacher) wit
```bash
torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.py \
--student_hf_path /path/to/student/huggingface/checkpoint \
--student_hf_model meta-llama/Llama-3.1-8B-Instruct \
--teacher_hf_path /path/to/teacher/huggingface/checkpoint \
--data_paths 1.0 /path/to/hf_datasets/wikitext-103-v1/Salesforce--wikitext_wikitext-103-v1_train_text_document \
--output_dir /path/to/distilled/checkpoint \
--hf-export-path /path/to/exported/hf/model \
--hf-model meta-llama/Llama-3.1-8B-Instruct \
--hf_export_path /path/to/exported/hf/model \
--seq_length 4096 \
--tp_size 8 \
--pp_size 1 \
@@ -90,7 +90,7 @@ torchrun --nproc_per_node=8 examples/puzzletron/mbridge_distillation/distill_hf.

- Add `--trust_remote_code` if student or teacher checkpoints need HuggingFace custom modeling code.
- The distilled Megatron-Bridge checkpoint will be saved to `--output_dir/checkpoints/iter_<train_iters>`.
- Add `--hf-export-path` (or `--hf_export_path`) to automatically export the final checkpoint to HuggingFace format after distillation. When exporting, you must also provide `--hf-model` / `--hf_model` as the HuggingFace model ID for the export template (e.g., `meta-llama/Llama-3.1-8B-Instruct`). It should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation).
- Add `--hf_export_path` to automatically export the final checkpoint to HuggingFace format after distillation. When exporting, you must also provide `--student_hf_model` as the HuggingFace model ID for the export template (e.g., `meta-llama/Llama-3.1-8B-Instruct`). It should match the base architecture of the student model. The exported model can be evaluated for accuracy using the evaluation tools described in the main [README.md](../README.md#evaluation).
- For production use, use larger datasets like [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1) and train for more iterations. See the [Megatron-Bridge distillation tutorial](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/megatron_bridge#distillation) for best practices.
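
As a quick smoke test of the exported model before running the full evaluation, it can be loaded directly with `transformers` — a minimal sketch, assuming the export directory follows the standard HuggingFace causal-LM layout (the paths and prompt below are placeholders):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

export_dir = "/path/to/exported/hf/model"  # same value as --hf_export_path

# Pass trust_remote_code=True here as well if the student uses custom modeling code.
tokenizer = AutoTokenizer.from_pretrained(export_dir)
model = AutoModelForCausalLM.from_pretrained(export_dir, torch_dtype="auto")

inputs = tokenizer("Knowledge distillation trains a small student to match", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```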

## MMLU Evaluation Results
87 changes: 29 additions & 58 deletions examples/puzzletron/mbridge_distillation/distill_hf.py
@@ -22,11 +22,11 @@

import argparse
import os
import traceback
import shutil

import megatron.bridge.models.distillation_provider
import torch
from megatron.bridge import AutoBridge
from megatron.bridge.models.distillation_provider import convert_to_distillation_provider
from megatron.bridge.recipes.utils.optimizer_utils import (
distributed_fused_adam_with_cosine_annealing,
)
@@ -40,39 +40,16 @@
TokenizerConfig,
TrainingConfig,
)
from megatron.bridge.training.distill import distill
from megatron.bridge.training.post_training.distillation import ModelOptDistillConfig
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.distributed import DistributedDataParallelConfig

# Import heterogeneous bridges BEFORE AutoBridge.from_hf_pretrained() is called to ensure
# registration takes precedence. The @MegatronModelBridge.register_bridge decorator registers
# bridges when the module is imported. If both LlamaBridge and PuzzletronLlamaAnyModelBridge
# register for the same source (LlamaForCausalLM), the dispatch system uses the last registration.
#
# Note: Currently, bridges are also registered when distillation_provider is imported
# below (via mbridge/__init__.py), but this import will be needed once DistillationProvider
# is upstreamed to Megatron-Bridge and we no longer import from modelopt.torch.puzzletron.
# Import to register heterogeneous bridges (side effect)
import modelopt.torch.puzzletron.export.mbridge # noqa: F401
import modelopt.torch.utils.distributed as dist

# Use local copy of distillation_provider with fix for heterogeneous models
# TODO: Remove this local copy once fix is upstreamed to Megatron-Bridge
from modelopt.torch.puzzletron.export.mbridge.distillation_provider import (
DistillationProvider,
convert_to_distillation_provider,
)
from modelopt.torch.puzzletron.export.mbridge.export_mbridge_to_hf import (
export_to_hf_and_copy_config,
)
from modelopt.torch.utils import print_rank_0

# Patch upstream module BEFORE importing distill() so isinstance checks work with our local DistillationProvider
# This must happen before distill() is imported because distill.py imports DistillationProvider at module load time
megatron.bridge.models.distillation_provider.DistillationProvider = DistillationProvider

# Import distill() AFTER patching so it uses the patched DistillationProvider
from megatron.bridge.training.distill import distill # noqa: E402

SEED = 1234


Expand All @@ -84,13 +61,13 @@ def get_args():
"--student_hf_path",
type=str,
required=True,
help="HuggingFace model name or path for the student (e.g. Qwen/Qwen3-0.6B)",
help="HuggingFace model path for the student in puzzletron any_model format",
)
parser.add_argument(
"--teacher_hf_path",
type=str,
required=True,
help="HuggingFace model name or path for the teacher (e.g. Qwen/Qwen3-8B)",
help="HuggingFace model path for the teacher in puzzletron any_model format",
)
parser.add_argument("--trust_remote_code", action="store_true", help="Trust remote code")
# Parallelism arguments
Expand Down Expand Up @@ -145,28 +122,30 @@ def get_args():
# Export arguments
parser.add_argument(
"--hf_export_path",
"--hf-export-path",
type=str,
default=None,
help=(
"Path where to save the HuggingFace export. "
"If provided, exports checkpoint to HF format after distillation."
"If provided, exports last iteration checkpoint to HF format after distillation."
),
)
parser.add_argument(
"--hf_model",
"--hf-model",
"--student_hf_model",
type=str,
required=True,
help="HuggingFace model ID to use as template for export (e.g., meta-llama/Llama-3.1-8B-Instruct). "
"Should match the base architecture of the student model.",
required=False,
default=None,
help="HuggingFace model ID to use as template for export (e.g., Qwen/Qwen3-0.6B). "
"Should match the base architecture of the student model if --hf_export_path is provided.",
)
args = parser.parse_args()

# Sanity checks
if not args.use_mock_data and not args.data_paths:
raise ValueError("Must provide either --data_paths or set --use_mock_data.")

if args.hf_export_path and not args.student_hf_model:
raise ValueError("Must provide --student_hf_model if --hf_export_path is provided.")

print_rank_0("\n==================== Arguments ====================")
for k, v in args.__dict__.items():
print_rank_0(f"{k:<35} {v}")
Expand Down Expand Up @@ -288,42 +267,34 @@ def _build_model_provider(hf_path):

# Export to HuggingFace format if hf_export_path is provided
if args.hf_export_path:
# Wait for all ranks to finish distillation before export
if torch.distributed.is_initialized():
torch.distributed.barrier()

print_rank_0(f"Exporting final distilled ckpt to HF format to {args.hf_export_path}")
# Save rank before destroying process group (dist.rank() won't work after destruction)
is_rank_0 = dist.rank() == 0

# Destroy process group on all ranks - export_ckpt will create its own temporary one
# This prevents cleanup from hanging (cleanup tries to barrier, but rank 0 would be gone)
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
dist.cleanup()

# Only rank 0 exports
if is_rank_0:
try:
export_to_hf_and_copy_config(
student_hf_path=args.student_hf_path,
checkpoint_dir=checkpoint_dir,
train_iters=args.train_iters,
hf_export_path=args.hf_export_path,
hf_model=args.hf_model,
trust_remote_code=args.trust_remote_code,
)
except Exception as e:
print(f"⚠️ Export failed: {e}")
traceback.print_exc()
export_bridge = AutoBridge.from_hf_pretrained(
args.student_hf_model, trust_remote_code=args.trust_remote_code
)
export_bridge.export_ckpt(
megatron_path=f"{checkpoint_dir}/iter_{args.train_iters:07d}",
hf_path=args.hf_export_path,
show_progress=True,
strict=True,
)

# save config from student_model to hf_export_path
shutil.copy(f"{args.student_hf_path}/config.json", f"{args.hf_export_path}/config.json")


if __name__ == "__main__":
dist.setup()
args = get_args()
try:
main(args)
except Exception as e:
print_rank_0(f"✗ MAIN FAILED: {type(e).__name__}: {e}")
print_rank_0(f"Traceback:\n{traceback.format_exc()}")
raise
finally:
dist.cleanup()
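
For reference, the new export path in `main()` can also be reproduced as a small standalone step, e.g. to re-export an already distilled checkpoint without re-running training. A minimal single-process sketch based on the export block above; the paths, template model ID, and iteration count are placeholders:

```python
import shutil

# Importing the puzzletron mbridge package registers the heterogeneous bridges
# as a side effect, mirroring the import in distill_hf.py.
import modelopt.torch.puzzletron.export.mbridge  # noqa: F401
from megatron.bridge import AutoBridge

student_hf_path = "/path/to/student/huggingface/checkpoint"  # --student_hf_path
student_hf_model = "meta-llama/Llama-3.1-8B-Instruct"        # --student_hf_model (export template)
checkpoint_dir = "/path/to/distilled/checkpoint/checkpoints" # <output_dir>/checkpoints
train_iters = 100                                            # --train_iters used for the run
hf_export_path = "/path/to/exported/hf/model"                # --hf_export_path

# Build the export template from the base architecture, then convert the Megatron checkpoint.
bridge = AutoBridge.from_hf_pretrained(student_hf_model, trust_remote_code=False)
bridge.export_ckpt(
    megatron_path=f"{checkpoint_dir}/iter_{train_iters:07d}",
    hf_path=hf_export_path,
    show_progress=True,
    strict=True,
)

# As in the script, copy the student's config.json over the exported one.
shutil.copy(f"{student_hf_path}/config.json", f"{hf_export_path}/config.json")
```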
190 changes: 0 additions & 190 deletions modelopt/torch/puzzletron/export/mbridge/distillation_provider.py

This file was deleted.
