1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@
 
 Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.
 
+- **`trainer.loss` (default loss)**: Replaced DPPO-Binary TV + Kimi-K2.5 KL with **IcePop** (INTELLECT-3, [arxiv](https://arxiv.org/abs/2512.16144)) as the default loss. Removed `dppo_mask_low`, `dppo_mask_high`, and `kl_tau`. Added `icepop_ratio_low` (default: `0.2`, α) and `icepop_ratio_high` (default: `5.0`, β) for double-sided importance-ratio masking — tokens outside `[α, β]` are dropped, not clipped. The KL penalty term is gone; the double-sided ratio mask is what keeps updates inside the trust region. `adv_tau` and `teacher_tau` are unchanged. (2026-05-02)
 - **`orchestrator.advantage.length_shaping` → `orchestrator.advantage.length_penalty`**: The boolean `length_shaping` flag has been replaced by `length_penalty: Literal["tokens", "turns"] | None` (default: `None`). `length_shaping = true` becomes `length_penalty = "tokens"`; `length_shaping = false` becomes `length_penalty = None`. The new `"turns"` option applies the same correctness-gated efficiency shaping using trajectory turn count instead of completion-token count. (2026-05-01)
 - **`AdvantageInputs` API**: Replaced the `rewards`/`completion_lengths`/`num_turns` tensor fields with a single `rollouts: list[list[vf.RolloutOutput]]` (grouped by problem). Custom advantage functions can now access any rollout metadata. Existing custom advantages must update their signatures and extract per-rollout fields directly (e.g. `torch.tensor([[r["reward"] for r in g] for g in inputs.rollouts])`). (2026-05-01)
 - **`orchestrator.teacher_rollout_model` now requires `orchestrator.use_sft_loss = true`**: External teacher rollout configs no longer rely on `trainer.loss.type = "sft"` to select SFT loss. Existing hard-distill configs must set `orchestrator.use_sft_loss = true` alongside `orchestrator.teacher_rollout_model`; the orchestrator validates the pair and stamps training samples with the per-batch SFT loss bool. (2026-04-26)
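Taken together, the first, second, and fourth entries above are pure config migrations. A minimal before/after sketch as Python dicts; only the key paths and the listed defaults come from the changelog, and the teacher model name is a made-up example:

```python
# Hypothetical before/after config fragments for the migration described in
# the changelog above. Key paths follow the entries; example values are made up.
old_config = {
    "trainer": {
        "loss": {"type": "default", "dppo_mask_low": 0.2, "dppo_mask_high": 0.2, "kl_tau": 1e-3},
    },
    "orchestrator": {
        "advantage": {"length_shaping": True},
        "teacher_rollout_model": "example-teacher",  # hypothetical model name
    },
}

new_config = {
    "trainer": {
        # dppo_* and kl_tau are removed; the IcePop ratio bounds replace them.
        "loss": {"type": "default", "icepop_ratio_low": 0.2, "icepop_ratio_high": 5.0},
    },
    "orchestrator": {
        # length_shaping = true -> "tokens"; length_shaping = false -> None.
        "advantage": {"length_penalty": "tokens"},
        "teacher_rollout_model": "example-teacher",
        "use_sft_loss": True,  # now required alongside teacher_rollout_model
    },
}
```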
11 changes: 7 additions & 4 deletions src/prime_rl/configs/trainer.py
@@ -649,15 +649,18 @@ class CheckpointConfig(BaseConfig):
 
 
 class DefaultLossConfig(BaseModel):
-    """Config for the default loss."""
+    """Config for the default loss (IcePop, INTELLECT-3 / arXiv:2512.16144)."""
 
     type: Literal["default"] = "default"
 
-    dppo_mask_low: Annotated[float, Field(ge=0, description="The low threshold for masking tokens.")] = 0.2
-    dppo_mask_high: Annotated[float, Field(ge=0, description="The high threshold for masking tokens.")] = 0.2
+    icepop_ratio_low: Annotated[
+        float, Field(ge=0, description="Lower bound α on the importance ratio. Tokens below are dropped.")
+    ] = 0.2
+    icepop_ratio_high: Annotated[
+        float, Field(ge=0, description="Upper bound β on the importance ratio. Tokens above are dropped.")
+    ] = 5.0
     adv_tau: Annotated[float, Field(ge=0, description="The tau for advantages.")] = 1.0
     teacher_tau: Annotated[float, Field(ge=0, description="The tau for teacher logprobs.")] = 0.0
-    kl_tau: Annotated[float, Field(ge=0, description="The tau for KL divergence.")] = 1e-3
 
 
 class SFTLossConfig(BaseModel):
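As a sanity check on the new model, a standalone sketch of how it behaves under pydantic validation; the class body is abbreviated from the diff above, with the `description` metadata dropped for brevity:

```python
# Minimal standalone sketch of the updated config model (abbreviated from the
# diff above) showing the new defaults and the ge=0 bounds in action.
from typing import Annotated, Literal

from pydantic import BaseModel, Field, ValidationError


class DefaultLossConfig(BaseModel):
    type: Literal["default"] = "default"
    icepop_ratio_low: Annotated[float, Field(ge=0)] = 0.2
    icepop_ratio_high: Annotated[float, Field(ge=0)] = 5.0
    adv_tau: Annotated[float, Field(ge=0)] = 1.0
    teacher_tau: Annotated[float, Field(ge=0)] = 0.0


print(DefaultLossConfig())
# -> type='default' icepop_ratio_low=0.2 icepop_ratio_high=5.0 adv_tau=1.0 teacher_tau=0.0

try:
    DefaultLossConfig(icepop_ratio_low=-0.1)  # ge=0 rejects negative bounds
except ValidationError as err:
    print(err)
```

Note that nothing here enforces `icepop_ratio_low <= icepop_ratio_high`; with low > high every token would fall outside the band and be dropped.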
35 changes: 12 additions & 23 deletions src/prime_rl/trainer/rl/loss.py
@@ -106,38 +106,28 @@ def _safe_mean(values: Tensor, mask: Tensor) -> Tensor:
 
 def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossOutputs:
     """
-    DPPO+KL loss, combining:
-    - DPPO-Binary TV Loss (https://arxiv.org/pdf/2602.04879)
-    - Kimi-K2.5 KL Loss (https://arxiv.org/pdf/2602.02276)
-
-    The mask is conditioned on the advantage sign: for positive advantages,
-    we mask tokens whose probability increased too much (trust region violation
-    in the upweight direction); for negative advantages, we mask tokens whose
-    probability decreased too much (trust region violation in the downweight
-    direction).
+    IcePop loss (INTELLECT-3, https://arxiv.org/abs/2512.16144).
+
+    Token-level masked importance sampling: tokens whose ratio
+    π_train / π_infer falls outside [α, β] are dropped (gradient set
+    to 0), not clipped. No KL penalty — the double-sided mask is what
+    keeps the update inside the trust region.
     """
     trainer_logprobs = inputs.trainer_logprobs
     inference_logprobs = inputs.inference_logprobs
     teacher_logprobs = inputs.teacher_logprobs
     advantages = inputs.advantages
     loss_mask = inputs.loss_mask
 
-    trainer_probs = torch.exp(trainer_logprobs)
-    inference_probs = torch.exp(inference_logprobs)
-    probs_diff = trainer_probs - inference_probs
-    dppo_invalid_mask_high = probs_diff > loss_config.dppo_mask_high
-    dppo_invalid_mask_low = probs_diff < -loss_config.dppo_mask_low
-    dppo_invalid_mask = torch.where(advantages > 0, dppo_invalid_mask_high, dppo_invalid_mask_low)
-
-    is_masked = dppo_invalid_mask
-    is_masked_high = (advantages > 0) & dppo_invalid_mask_high
-    is_masked_low = (advantages < 0) & dppo_invalid_mask_low
-    keep_mask = loss_mask & ~is_masked
-
     log_importance_ratio = trainer_logprobs - inference_logprobs
     importance_ratio = torch.exp(log_importance_ratio)
     mismatch_kl = importance_ratio - log_importance_ratio - 1
 
+    is_masked_low = importance_ratio.detach() < loss_config.icepop_ratio_low
+    is_masked_high = importance_ratio.detach() > loss_config.icepop_ratio_high
+    is_masked = is_masked_low | is_masked_high
+    keep_mask = loss_mask & ~is_masked
+
     advantages = loss_config.adv_tau * advantages
     if teacher_logprobs is not None:
         teacher_kl = teacher_logprobs - trainer_logprobs
@@ -146,8 +136,7 @@ def default_loss_fn(inputs: LossInputs, loss_config: DefaultLossConfig) -> LossOutputs:
         teacher_kl = None
 
     pg_loss = keep_mask * advantages * importance_ratio
-    kl_loss = loss_mask * log_importance_ratio**2
-    loss = (-pg_loss + loss_config.kl_tau * kl_loss).sum()
+    loss = (-pg_loss).sum()
 
     metrics = {
         "mismatch_kl": _safe_mean(mismatch_kl, loss_mask),  # all trainable tokens
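A self-contained numeric sketch of the double-sided masking, mirroring the logic in the diff above with made-up tensors and the new default bounds:

```python
# Standalone illustration of IcePop-style double-sided ratio masking, using
# the new defaults (alpha=0.2, beta=5.0). All tensor values are made up.
import torch

alpha, beta = 0.2, 5.0

trainer_logprobs = torch.tensor([-0.5, -2.0, -0.1, -4.0])
inference_logprobs = torch.tensor([-0.6, -0.1, -2.5, -0.2])
advantages = torch.tensor([1.0, 1.0, -1.0, 0.5])
loss_mask = torch.ones(4, dtype=torch.bool)

importance_ratio = torch.exp(trainer_logprobs - inference_logprobs)
# ratios: exp(0.1)≈1.11, exp(-1.9)≈0.15, exp(2.4)≈11.02, exp(-3.8)≈0.02

# Tokens whose ratio leaves [alpha, beta] are dropped, not clipped. The
# trainer masks on importance_ratio.detach(), so the mask carries no gradient.
is_masked = (importance_ratio < alpha) | (importance_ratio > beta)
keep_mask = loss_mask & ~is_masked  # -> [True, False, False, False]

pg_loss = keep_mask * advantages * importance_ratio
loss = (-pg_loss).sum()  # only the first token contributes to the update
```

With these numbers only the first token survives the band; the other three are zeroed out entirely rather than being pulled back to the nearest bound, which is the "dropped, not clipped" behavior the changelog entry describes.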