Commit d620b1c

fix(recipes): correct validation loss averaging in LLM KD recipe
_forward_backward_step returns per-token-averaged losses, but the validation loop accumulated them without un-averaging first. This caused val_loss to be divided twice (yielding an artificially small value) and ce_loss/kd_loss to be reported as raw sums instead of per-token means.

Multiply each per-batch loss by its num_label_tokens before accumulating, then divide by total_num_label_tokens at the end for a proper weighted average, matching the pattern used in the parent FinetuneRecipe.

Signed-off-by: khazic <khazzz1c@gmail.com>
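To make the double division concrete, here is a tiny standalone sketch with made-up numbers (batch_losses and num_tokens are hypothetical, not values from the recipe):

    # Two validation batches; each entry in batch_losses is the per-token mean
    # that _forward_backward_step would return for that batch.
    batch_losses = [2.0, 4.0]
    num_tokens = [10, 30]

    # Buggy accumulation: the per-batch means are summed, then divided by the
    # token count again, so the result is effectively divided twice.
    buggy = sum(batch_losses) / sum(num_tokens)  # 6.0 / 40 = 0.15

    # Fixed accumulation: un-average each batch first, then divide once.
    weighted = sum(l * n for l, n in zip(batch_losses, num_tokens)) / sum(num_tokens)
    # (2.0 * 10 + 4.0 * 30) / 40 = 140.0 / 40 = 3.5

The true per-token loss here is 3.5; the buggy path reports 0.15, the artificially small value the commit message describes.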
1 parent: 174ba8d

1 file changed: nemo_automodel/recipes/llm/kd.py (22 additions, 11 deletions)
@@ -797,9 +797,9 @@ def _run_validation_epoch(self, val_dataloader):
         for mp in self.model_parts:
             mp.eval()
 
-        total_loss = torch.tensor(0.0, dtype=torch.float32, device=self.dist_env.device)
-        ce_loss = torch.tensor(0.0, dtype=torch.float32, device=self.dist_env.device)
-        kd_loss = torch.tensor(0.0, dtype=torch.float32, device=self.dist_env.device)
+        total_loss = 0.0
+        total_ce_loss = 0.0
+        total_kd_loss = 0.0
         total_num_label_tokens = 0
 
         for batch in val_dataloader:
@@ -811,24 +811,35 @@ def _run_validation_epoch(self, val_dataloader):
                 num_batches=1,
                 is_train=False,
             )
+            # _forward_backward_step returns per-token-averaged losses.
+            # Multiply back by num_label_tokens to get the raw sum for
+            # correct weighted averaging across batches.
+            total_loss += local_loss.item() * num_label_tokens
+            total_ce_loss += _ce_loss.item() * num_label_tokens
+            total_kd_loss += _kd_loss.item() * num_label_tokens
             total_num_label_tokens += num_label_tokens
-            ce_loss += _ce_loss
-            kd_loss += _kd_loss
-            total_loss += local_loss
 
-        total_loss = self._dp_allreduce(total_loss, include_cp=True).item()
-        ce_loss = self._dp_allreduce(ce_loss, include_cp=True).item()
-        kd_loss = self._dp_allreduce(kd_loss, include_cp=True).item()
+        total_loss = self._dp_allreduce(
+            torch.tensor(total_loss, dtype=torch.float32, device=self.dist_env.device), include_cp=True
+        ).item()
+        total_ce_loss = self._dp_allreduce(
+            torch.tensor(total_ce_loss, dtype=torch.float32, device=self.dist_env.device), include_cp=True
+        ).item()
+        total_kd_loss = self._dp_allreduce(
+            torch.tensor(total_kd_loss, dtype=torch.float32, device=self.dist_env.device), include_cp=True
+        ).item()
         total_num_label_tokens = self._dp_allreduce(torch.tensor(total_num_label_tokens, dtype=torch.long)).item()
 
         val_loss = total_loss / max(total_num_label_tokens, 1e-8)
+        val_ce_loss = total_ce_loss / max(total_num_label_tokens, 1e-8)
+        val_kd_loss = total_kd_loss / max(total_num_label_tokens, 1e-8)
         return MetricsSample(
             step=self.step_scheduler.step,
             epoch=self.step_scheduler.epoch,
             metrics={
                 "val_loss": val_loss,
-                "ce_loss": ce_loss,
-                "kd_loss": kd_loss,
+                "ce_loss": val_ce_loss,
+                "kd_loss": val_kd_loss,
                 "lr": self.optimizer[0].param_groups[0]["lr"],
                 "num_label_tokens": total_num_label_tokens,
                 "mem": torch.cuda.max_memory_allocated() / 1024**3,
