feat: add pipeline parallelism support for knowledge distillation#1500
Conversation
Add PP support to `KnowledgeDistillationRecipeForNextTokenPrediction`:

- Add `_build_teacher_model_with_pp`: builds the teacher as an AutoPipeline mirroring the student's PP config, with a capture closure that stores last-stage logits in `_teacher_logits_capture` after each eval pass.
- Add `_make_pp_kd_loss_wrapper`: injects into the student schedule's `_loss_fn`; reads `_current_teacher_logits` set by the teacher eval pass and returns `(1 - ratio) * ce + ratio * kd`.
- Add `_forward_backward_step_pp`: runs the teacher eval first to capture logits, then runs the student step/eval.
- Add `_run_train_optim_step_pp`: full PP training step with grad accumulation, norm clipping, and cross-rank loss aggregation for logging.
- Add a PP-aware `run_train_validation_loop` override (validation skipped).
- `setup()` wires up all of the above and removes the old ValueError.

Bug fixes applied from the reference implementation:

- Fix the corrupted `has_packed_sequence` kwarg in the teacher PP builder.
- Guard `_forward_backward_step` against being called when PP is enabled.
- Skip the CE computation when `kd_ratio >= 1.0` (avoids a wasted forward pass).
- Add the missing `metric_logger_train.log(log_data)` call in `log_train_metrics`.
- Conditionalize the `ce_loss` display in log strings on `kd_ratio < 1.0`.

Known limitation: when `pp_microbatch_size < pp_batch_size`, only the last microbatch's teacher logits are retained; set `pp_microbatch_size == pp_batch_size` when using PP with KD.

Adds 7 unit tests covering PP-specific logic (capture closure, wrapper combination math, `kd_ratio` edge cases, buffer accumulation).
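The loss-combination rule described above, including the `kd_ratio >= 1.0` short-circuit that skips the CE term, can be sketched roughly as follows. This is an illustrative assumption, not the recipe's actual code: the helper name `combine_kd_loss` is invented here, and forward KL is just one common choice for the KD term.

```python
import torch
import torch.nn.functional as F


def combine_kd_loss(student_logits: torch.Tensor,
                    teacher_logits: torch.Tensor,
                    labels: torch.Tensor,
                    kd_ratio: float) -> torch.Tensor:
    """Sketch of (1 - kd_ratio) * ce + kd_ratio * kd (names assumed).

    KD term: forward KL between teacher and student token distributions.
    """
    kd = F.kl_div(
        F.log_softmax(student_logits, dim=-1),
        F.log_softmax(teacher_logits, dim=-1),
        log_target=True,          # teacher distribution passed in log-space
        reduction="batchmean",
    )
    if kd_ratio >= 1.0:
        # Pure distillation: skip the CE computation entirely,
        # mirroring the "avoid wasted forward pass" fix above.
        return kd
    ce = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
    )
    return (1.0 - kd_ratio) * ce + kd_ratio * kd
```

With `kd_ratio == 0.0` this degenerates to plain cross-entropy, and with identical teacher and student logits the KD term vanishes, which is the kind of edge-case math the PR's unit tests exercise.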
@akoumpa for visibility |
/ok to test 134a577 |
/ok to test 2d773fc |
/ok to test e6da11a |
/claude review |
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
/claude review |
```python
ce_tensor = (
    torch.stack(self._ce_loss_buffer).sum()
    if self._ce_loss_buffer
    else torch.tensor(0.0, device=self.dist_env.device)
)
kd_tensor = (
    torch.stack(self._kd_loss_buffer).sum()
    if self._kd_loss_buffer
    else torch.tensor(0.0, device=self.dist_env.device)
)
ce_tensor = self._dp_allreduce(ce_tensor, include_cp=True)
kd_tensor = self._dp_allreduce(kd_tensor, include_cp=True)
ce_tensor = ce_tensor.to(self.dist_env.device)
kd_tensor = kd_tensor.to(self.dist_env.device)
if self.dist_env.rank == src_rank and not self.dist_env.is_main:
    torch.distributed.send(ce_tensor, dst=0)
    torch.distributed.send(kd_tensor, dst=0)
elif self.dist_env.is_main and self.dist_env.rank != src_rank:
    torch.distributed.recv(ce_tensor, src=src_rank)
    torch.distributed.recv(kd_tensor, src=src_rank)
ce_loss = ce_tensor.cpu().item()
kd_loss = kd_tensor.cpu().item()
```
The same normalization mismatch applies to the logged KD/CE metrics in the PP path. Since `ce_loss` is a sum and `kd_loss` is a mean (before the fix above), the values in `_ce_loss_buffer` and `_kd_loss_buffer` are on different scales. Additionally, unlike `reporting_loss`, which is divided by `num_label_tokens` on line 683, `ce_loss` and `kd_loss` here are not normalized at all; they are raw allreduced values.

In the non-PP path (lines 586-587), both buffers contain per-token-normalized values (since `num_label_tokens` is passed to the loss functions). So the PP and non-PP paths log `ce_loss`/`kd_loss` on different scales, making them incomparable across runs with different parallelism configs.

After fixing `num_batch_labels=1` above, both buffers will contain sums, and you'd want to divide them by `num_label_tokens` here (similar to line 683 for `reporting_loss`).
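A minimal sketch of the normalization this comment suggests, assuming both buffers hold raw allreduced sums at this point; the helper name `normalize_logged_losses` is hypothetical:

```python
import torch


def normalize_logged_losses(ce_sum: torch.Tensor,
                            kd_sum: torch.Tensor,
                            num_label_tokens: int) -> tuple:
    """Divide allreduced loss sums by the global label-token count
    (mirroring how reporting_loss is handled), so PP and non-PP runs
    log ce_loss/kd_loss on the same per-token scale."""
    denom = max(num_label_tokens, 1)  # guard against an empty batch
    return (ce_sum / denom).item(), (kd_sum / denom).item()
```

The `max(..., 1)` guard is only there so a degenerate batch with zero label tokens does not divide by zero; the normal path divides by the true token count.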
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
/ok to test f94012a |