Commit 3b86b96

fix(deepseek-v4): keep MoE routing scores and attention softmax in fp32
Two precision issues that compound across 61 layers and degrade backbone parity
vs. the reference (observed during MTP parity testing in #2191):

1. The sqrtsoftplus Gate cast routing scores back to bf16 immediately after
   computing sqrt(softplus(x.float())), losing precision for expert selection.
   The HashGate counterpart stays in fp32. Remove the .to(scores.dtype) cast
   so non-hash layers match.

2. eager_attention_with_sink ran softmax in the input dtype (bf16 under
   autocast). Force fp32 softmax for numerical stability, matching standard
   practice.

Also fix a stale docstring claiming compress-ratio attention was not yet
implemented; it has been wired in.

Signed-off-by: khazic <khazzz1c@gmail.com>
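To make the gate issue concrete, here is a minimal standalone sketch (not code from this repo) of how the removed bf16 round-trip can collapse routing scores that fp32 still distinguishes, which is exactly what perturbs top-k expert selection:

```python
import torch
import torch.nn.functional as F

# Two experts whose router logits are nearly tied.
logits = torch.tensor([1.000, 1.001])

scores_fp32 = torch.sqrt(F.softplus(logits))  # kept in fp32 (post-fix)
scores_bf16 = scores_fp32.to(torch.bfloat16)  # the removed cast (pre-fix)

print(scores_fp32[0] == scores_fp32[1])  # tensor(False): fp32 separates the experts
print(scores_bf16[0] == scores_bf16[1])  # tensor(True): bf16's 8 significand bits collapse them
```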
1 parent 41786e2 commit 3b86b96

2 files changed: 5 additions & 4 deletions

nemo_automodel/components/models/deepseek_v4/layers.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -46,8 +46,9 @@
 See ``_hc_split_sinkhorn`` for the pure-torch port of the reference mixer
 (ported from miles PR 1045's ``kernel/sinkhorn.py``).

-Sliding-window / compress-ratio attention is NOT yet implemented.
-All layers use full causal attention regardless of compress_ratios.
+Compress-ratio attention (Compressor + Indexer) is wired into
+DeepseekV4Attention.forward for layers with compress_ratio > 0.
+All layers share the same sliding-window causal mask on the local KV path.
 """

 from __future__ import annotations
@@ -473,7 +474,7 @@ def eager_attention_with_sink(
     sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
     combined = torch.cat([attn_weights, sinks.to(attn_weights.dtype)], dim=-1)
     combined = combined - combined.max(dim=-1, keepdim=True).values
-    probs = F.softmax(combined, dim=-1, dtype=combined.dtype)[..., :-1]
+    probs = F.softmax(combined, dim=-1, dtype=torch.float32)[..., :-1]
     probs = F.dropout(probs, p=dropout, training=module.training).to(value_states.dtype)
     return torch.matmul(probs, value_states).transpose(1, 2).contiguous(), probs

```
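For context, here is a self-contained sketch of the pattern this hunk fixes; the function name and tensor shapes are illustrative assumptions, not the repo's exact API. A learned per-head sink logit is appended as an extra column, the softmax runs in fp32 even when the inputs are bf16 under autocast, and the sink column is dropped before the value matmul:

```python
import torch
import torch.nn.functional as F

def eager_sink_softmax(attn_weights: torch.Tensor, sinks: torch.Tensor) -> torch.Tensor:
    """attn_weights: [batch, heads, q_len, kv_len]; sinks: [heads] learned logits."""
    b, h, q, _ = attn_weights.shape
    sink_col = sinks.reshape(1, -1, 1, 1).expand(b, h, q, 1).to(attn_weights.dtype)
    combined = torch.cat([attn_weights, sink_col], dim=-1)
    combined = combined - combined.max(dim=-1, keepdim=True).values  # stabilize the exps
    # fp32 softmax regardless of input dtype, then drop the sink column.
    return F.softmax(combined, dim=-1, dtype=torch.float32)[..., :-1]

# bf16 inputs, as under autocast; the returned probabilities are fp32.
probs = eager_sink_softmax(torch.randn(2, 4, 8, 16, dtype=torch.bfloat16), torch.zeros(4))
```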

nemo_automodel/components/moe/layers.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -355,7 +355,7 @@ def forward(
             weights = original_scores.gather(1, indices)
         elif self.score_func == "sqrtsoftplus":
             # sqrt(softplus(x)) = sqrt(log(1 + exp(x))), used in DeepSeek V4.
-            scores = torch.sqrt(F.softplus(scores.float())).to(scores.dtype)
+            scores = torch.sqrt(F.softplus(scores.float()))
             original_scores = scores

         if self.e_score_correction_bias is not None:
```
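And a hypothetical reduction of the gate's scoring path after this change (the function name and signature are invented for illustration): scores now stay in fp32 all the way through top-k selection, matching the HashGate path.

```python
import torch
import torch.nn.functional as F

def route_sqrtsoftplus(logits: torch.Tensor, top_k: int):
    """logits: [tokens, experts], possibly bf16 under autocast."""
    # sqrt(softplus(x)) = sqrt(log(1 + exp(x))); no cast back to logits.dtype,
    # so the top-k below sees full fp32 resolution.
    scores = torch.sqrt(F.softplus(logits.float()))
    weights, indices = scores.topk(top_k, dim=-1)
    return weights, indices

weights, indices = route_sqrtsoftplus(torch.randn(16, 64, dtype=torch.bfloat16), top_k=8)
```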
