
feat: add pipeline parallelism support for knowledge distillation #1500

Merged: akoumpa merged 10 commits into main from ssameni/feat_kd_pp on Mar 20, 2026
Conversation

Separius (Contributor) commented Mar 9, 2026

Add PP support to KnowledgeDistillationRecipeForNextTokenPrediction:

- Add `_build_teacher_model_with_pp`: builds the teacher as an AutoPipeline mirroring the student's PP config, with a capture closure that stores last-stage logits in `_teacher_logits_capture` after each eval pass.
- Add `_make_pp_kd_loss_wrapper`: injects into the student schedule's `_loss_fn`; reads `_current_teacher_logits` set by the teacher eval pass and returns (1-ratio)*ce + ratio*kd.
- Add `_forward_backward_step_pp`: runs teacher eval first to capture logits, then runs the student step/eval.
- Add `_run_train_optim_step_pp`: full PP training step with grad accumulation, norm clipping, and cross-rank loss aggregation for logging.
- Add `run_train_validation_loop` PP-aware override (validation skipped).
- `setup()` wires up all of the above; removes the old ValueError.

Bug fixes applied from reference implementation:

- Fix corrupted `has_packed_sequence` kwarg in teacher PP builder.
- Guard `_forward_backward_step` against being called when PP is enabled.
- Skip CE computation when `kd_ratio >= 1.0` (avoid a wasted forward pass).
- Add missing `metric_logger_train.log(log_data)` in `log_train_metrics`.
- Conditionalize `ce_loss` display in log strings on `kd_ratio < 1.0`.

Known limitation: when `pp_microbatch_size < pp_batch_size`, only the last microbatch's teacher logits are retained; set `pp_microbatch_size == pp_batch_size` when using PP with KD.

Add 7 unit tests covering PP-specific logic (capture closure, wrapper combination math, `kd_ratio` edge cases, buffer accumulation).
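The capture mechanism behind the known limitation can be sketched in plain Python. The class and method names below are hypothetical, not the recipe's actual API; the point is that each teacher eval pass overwrites the stored logits rather than appending them:

```python
class TeacherLogitsCapture:
    """Hypothetical sketch of the teacher-side capture closure.

    Each teacher eval pass overwrites the stored last-stage logits,
    which is why only the final microbatch's logits survive when
    pp_microbatch_size < pp_batch_size.
    """

    def __init__(self):
        self._teacher_logits_capture = None

    def __call__(self, last_stage_logits):
        # Overwrite, not append: logits from earlier microbatches are lost.
        self._teacher_logits_capture = last_stage_logits
```

This is why the PR recommends `pp_microbatch_size == pp_batch_size`: with a single microbatch per step, the overwrite is harmless.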

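The loss combination described for `_make_pp_kd_loss_wrapper` reduces to a small amount of arithmetic. This sketch uses hypothetical names, with plain floats standing in for tensors, to show the blend and the `kd_ratio >= 1.0` short-circuit:

```python
def make_kd_loss_wrapper(ce_loss_fn, kd_loss_fn, kd_ratio, get_teacher_logits):
    """Hypothetical sketch of the wrapper injected as the schedule's loss fn."""

    def loss_fn(student_logits, labels):
        # Teacher logits were stored by the preceding teacher eval pass.
        teacher_logits = get_teacher_logits()
        kd = kd_loss_fn(student_logits, teacher_logits, labels)
        if kd_ratio >= 1.0:
            # Pure distillation: skip the CE computation entirely.
            return kd
        ce = ce_loss_fn(student_logits, labels)
        return (1.0 - kd_ratio) * ce + kd_ratio * kd

    return loss_fn
```

With `kd_ratio=0.5`, a CE of 2.0 and a KD of 4.0 blend to 3.0; at `kd_ratio=1.0` the CE branch is never evaluated, matching the "skip CE computation" fix above.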
copy-pr-bot bot commented Mar 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Separius (Contributor, Author) commented Mar 9, 2026

@akoumpa for visibility

Separius mentioned this pull request Mar 9, 2026

akoumpa (Contributor) commented Mar 9, 2026

/ok to test 134a577

akoumpa (Contributor) commented Mar 11, 2026

/ok to test 2d773fc

akoumpa (Contributor) commented Mar 17, 2026

/ok to test e6da11a

akoumpa (Contributor) commented Mar 18, 2026

/claude review

Comment thread: nemo_automodel/recipes/llm/kd.py (outdated)
Comment thread: tests/unit_tests/loss/test_kd_loss.py
akoumpa (Contributor) commented Mar 18, 2026

/claude review

Comment thread nemo_automodel/recipes/llm/kd.py
Comment on lines +696 to +717
```python
ce_tensor = (
    torch.stack(self._ce_loss_buffer).sum()
    if self._ce_loss_buffer
    else torch.tensor(0.0, device=self.dist_env.device)
)
kd_tensor = (
    torch.stack(self._kd_loss_buffer).sum()
    if self._kd_loss_buffer
    else torch.tensor(0.0, device=self.dist_env.device)
)
ce_tensor = self._dp_allreduce(ce_tensor, include_cp=True)
kd_tensor = self._dp_allreduce(kd_tensor, include_cp=True)
ce_tensor = ce_tensor.to(self.dist_env.device)
kd_tensor = kd_tensor.to(self.dist_env.device)
if self.dist_env.rank == src_rank and not self.dist_env.is_main:
    torch.distributed.send(ce_tensor, dst=0)
    torch.distributed.send(kd_tensor, dst=0)
elif self.dist_env.is_main and self.dist_env.rank != src_rank:
    torch.distributed.recv(ce_tensor, src=src_rank)
    torch.distributed.recv(kd_tensor, src=src_rank)
ce_loss = ce_tensor.cpu().item()
kd_loss = kd_tensor.cpu().item()
```
Same normalization mismatch applies to the logged KD/CE metrics in the PP path. Since ce_loss is a sum and kd_loss is a mean (before the fix above), the values in _ce_loss_buffer and _kd_loss_buffer are on different scales. Additionally, unlike reporting_loss which is divided by num_label_tokens on line 683, ce_loss and kd_loss here are not normalized at all — they're raw allreduced values.

In the non-PP path (lines 586-587), both buffers contain per-token-normalized values (since num_label_tokens is passed to the loss functions). So the PP and non-PP paths log ce_loss/kd_loss on different scales, making them incomparable across runs with different parallelism configs.

After fixing num_batch_labels=1 above, both buffers will contain sums, and you'd want to divide them by num_label_tokens here (similar to line 683 for reporting_loss).
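Under the reviewer's suggestion, and assuming both buffers hold per-microbatch sums after the `num_batch_labels=1` fix, the normalization reduces to a division mirroring the `reporting_loss` treatment. The function and variable names in this sketch are hypothetical, and the cross-rank allreduce is elided:

```python
def normalize_logged_kd_metrics(ce_loss_buffer, kd_loss_buffer, num_label_tokens):
    """Hypothetical sketch: turn summed per-microbatch losses into
    per-token means, mirroring how reporting_loss is divided by
    num_label_tokens."""
    ce_sum = sum(ce_loss_buffer) if ce_loss_buffer else 0.0
    kd_sum = sum(kd_loss_buffer) if kd_loss_buffer else 0.0
    # (the DP/CP allreduce of ce_sum and kd_sum would happen here in the recipe)
    return ce_sum / num_label_tokens, kd_sum / num_label_tokens
```

With both metrics expressed as per-token means, the PP and non-PP paths would log values on the same scale regardless of parallelism config.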

Comment thread: nemo_automodel/recipes/llm/kd.py (outdated)
akoumpa and others added 3 commits March 17, 2026 21:11
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
akoumpa (Contributor) commented Mar 18, 2026

/ok to test f94012a
