
feat(deepseek-v4): add Multi-Token Prediction (MTP) training support #2191

Open

khazic wants to merge 12 commits into NVIDIA-NeMo:main from khazic:khazic/feat/deepseek-v4-flash-mtp

Conversation

@khazic
Contributor

@khazic khazic commented May 8, 2026

Summary

Adds Multi-Token Prediction (MTP) training support for DeepSeek V4 (Flash). MTP layers run as standard pre-norm attention + MoE blocks (no HC machinery), with rotary embeddings shared from the main backbone. The auxiliary loss is computed via the recipe-side calculate_mtp_loss and added to the main CE loss.

What's in this PR

Model side

  • components/models/common/mtp/: model-agnostic scaffold (MTPConfig, MTPModule, roll_tensor).
  • components/models/deepseek_v4/mtp.py: V4-specific DeepseekV4MTPSublayer and build_deepseek_v4_mtp factory. compress_ratios is forced to None for MTP attention to avoid IndexError past the backbone layer count; rotary refs are stored via object.__setattr__ so they don't pollute state_dict (sketched after this list).
  • components/models/deepseek_v4/model.py: DeepseekV4ForCausalLM now constructs self.mtp when num_nextn_predict_layers > 0 and returns a DeepseekV4CausalLMOutput dataclass (logits + optional mtp_per_depth_h).
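
A minimal sketch of the non-registered rotary reference from the mtp.py bullet above (class and attribute names here are illustrative, not the PR's exact code):

    import torch.nn as nn

    class SublayerWithSharedRotary(nn.Module):
        def __init__(self, rotary_emb: nn.Module):
            super().__init__()
            self.norm = nn.LayerNorm(16)  # registered normally: shows up in state_dict
            # nn.Module.__setattr__ would register rotary_emb as a child module,
            # leaking its buffers into this module's state_dict. Bypassing it
            # with object.__setattr__ stores a plain Python reference instead.
            object.__setattr__(self, "_rotary_emb", rotary_emb)

    rotary = nn.Module()
    layer = SublayerWithSharedRotary(rotary)
    assert layer._rotary_emb is rotary
    assert "_rotary_emb" not in dict(layer.named_modules())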

State-dict adapter

  • from_hf runs MTP layers (layers.{N+k}.*) through the same dequantize / aggregate-experts / rename pipeline as the backbone (renumber to layers.{k}.*, run pipeline, re-prefix to mtp.layers.{k}.*). Previously MTP keys bypassed dequantization and FP8/FP4 buffers were left raw. The routing is sketched after this list.
  • to_hf rewrites mtp.layers.{k}.* into model.layers.{N+k}.* and runs the unified split / rename / quantize path; an explicit fallback strips the leftover model. prefix for fusion-only modules (eh_proj / enorm / hnorm / final_layernorm) that have no entry in the rename table.
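
A rough sketch of that from_hf routing, with a run_backbone_pipeline stand-in for the shared dequantize / aggregate-experts / rename pipeline (helper name and signature are ours, not the adapter's actual API):

    import re

    _LAYER_KEY = re.compile(r"^(model\.)?layers\.(\d+)\.")

    def route_mtp_keys(hf_state: dict, num_layers: int, run_backbone_pipeline):
        """Split MTP keys out of an HF state dict, convert them with the
        backbone pipeline, and re-prefix the result (illustrative sketch)."""
        backbone, mtp = {}, {}
        for key, value in hf_state.items():
            m = _LAYER_KEY.match(key)
            if m and int(m.group(2)) >= num_layers:
                # renumber layers.{N+k}.* -> layers.{k}.* so the pipeline's
                # per-layer logic applies unchanged
                k = int(m.group(2)) - num_layers
                mtp[_LAYER_KEY.sub(f"layers.{k}.", key)] = value
            else:
                backbone[key] = value
        converted = {f"mtp.{k}": v for k, v in run_backbone_pipeline(mtp).items()}
        return run_backbone_pipeline(backbone), converted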

Recipe (recipes/llm/train_ft.py)

  • calculate_mtp_loss: per-depth CE through the configured loss class (FusedLinearCE / MaskedCE), summed with loss_scaling_factor / D weighting (sketched after this list).
  • _forward_backward_step (non-PP branch) reads out.mtp_per_depth_h and adds the MTP loss to the main loss.
  • _mtp_is_enabled(cfg, model_parts) + setup-time guard: raises NotImplementedError if pipeline parallelism is enabled together with MTP, since the PP schedule does not currently aggregate the MTP auxiliary loss. PP + MTP is intentionally deferred to a follow-up PR.
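
An illustrative reading of that per-depth weighting, with plain cross-entropy standing in for the configured loss class (all names here are ours):

    import torch
    import torch.nn.functional as F

    def mtp_loss_sketch(per_depth_h, lm_head, labels, loss_scaling_factor):
        # Depth d predicts the token (d + 1) positions ahead, so labels are
        # rolled one step further left at each depth. A real implementation
        # also masks the positions that wrap around the sequence end.
        D = len(per_depth_h)
        total = per_depth_h[0].new_zeros(())
        shifted = labels
        for h in per_depth_h:
            shifted = torch.roll(shifted, shifts=-1, dims=-1)
            logits = lm_head(h)  # [B, T, V]
            total = total + F.cross_entropy(logits.flatten(0, 1), shifted.flatten())
        return total * (loss_scaling_factor / D)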

Tests

  • test_deepseek_v4_mtp.py: config / construction / forward / backward / state-dict coverage.
  • test_dsv4_state_dict_adapter.py: MTP round-trip for layer rename, FP8 dequantize, expert aggregation, and the fusion-only fallback in both directions.
  • test_dsv4_model_smoke.py: updated to read .logits from the new dataclass output.

Overlap with #2161

PR #2161 (Nemotron V3 MTP) introduces the same calculate_mtp_loss helper and the same non-PP integration in _forward_backward_step. Those two regions are byte-identical between the branches.

This is intentional — both PRs need the same recipe-side scaffolding, and the model-agnostic MTP base (components/models/common/mtp/) is shared. When #2161 lands first, those duplicated lines will be auto-resolved on rebase, and this PR will reduce to the V4-specific changes (model, MTP sublayer, adapter, PP guard, V4 tests).

Test plan

wandb: https://wandb.ai/Nemo-automodel/huiyingl_workspace?nw=nwuserhuiyingl

khazic added 2 commits May 8, 2026 17:20
- Add model-agnostic MTP scaffold (MTPConfig, MTPModule, roll_tensor) under
  nemo_automodel/components/models/common/mtp/
- Add DeepseekV4MTPSublayer: pre-norm attention+MoE blocks without HC
  machinery; compress_ratios forced to None to avoid IndexError; rotary
  embeddings stored as non-registered references via object.__setattr__
- Add build_mtp_config_from_hf and build_deepseek_v4_mtp factory functions
- Add DeepseekV4CausalLMOutput dataclass so forward returns logits + optional
  mtp_per_depth_h list for MTP loss computation in train_ft.py
- Update DeepseekV4ForCausalLM.__init__ to construct MTP module when
  num_nextn_predict_layers > 0
- Update state_dict_adapter.py: from_hf splits MTP keys and converts back
- Add calculate_mtp_loss to train_ft.py and wire into _forward_backward_step
- Add 8 unit tests covering config, construction, forward, backward, state dict

Signed-off-by: khazic <khazzz1c@gmail.com>
State-dict adapter:
- from_hf: route MTP layers (layers.{N+k}.*) through dequantize +
  aggregate-experts + rename pipeline by renumbering them as layers.{k}.*
  and re-prefixing the result to mtp.layers.{k}.*. Previously MTP keys
  bypassed dequantization, leaving FP8/FP4 buffers undequantized.
- to_hf: rewrite mtp.layers.{k}.* into model.layers.{N+k}.* and run the
  unified split / rename / quantize path; strip the leftover model.
  prefix for fusion-only modules (eh_proj, enorm, hnorm, final_layernorm)
  that have no entry in the rename table.
- Drop dead _apply_inverse_rename helper.

Recipe (train_ft.py):
- Add _mtp_is_enabled(cfg, model_parts) helper that detects MTP via
  YAML override (model.config.num_nextn_predict_layers) or via an
  enabled mtp_config attribute on any constructed submodule.
- Raise NotImplementedError in setup() when PP and MTP are both
  enabled. The PP schedule does not aggregate the MTP auxiliary loss,
  so the MTP head would silently receive no gradients. PP + MTP
  wiring is intentionally deferred to a follow-up PR.
- Add TODO marker in _forward_backward_step PP branch pointing at the
  same follow-up.

Tests:
- Fix test_forward_shape / test_backward to read .logits from the new
  DeepseekV4CausalLMOutput dataclass returned by forward.
- Add MTP round-trip coverage: layer rename, FP8 dequantize, expert
  aggregation, to_hf rename / split / quantize, and the fusion-only
  fallback for both directions.

Signed-off-by: khazic <khazzz1c@gmail.com>
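
A condensed, self-contained form of that setup-time guard (function and argument names are ours; the real check lives in setup() and uses the _mtp_is_enabled helper quoted later in this thread):

    def assert_pp_mtp_compatible(pp_size: int, mtp_enabled: bool) -> None:
        # The PP schedule does not aggregate the MTP auxiliary loss, so the
        # MTP head would silently receive no gradients under PP.
        if pp_size > 1 and mtp_enabled:
            raise NotImplementedError(
                "PP + MTP is not supported yet; disable pipeline parallelism "
                "or set num_nextn_predict_layers to 0."
            )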
@copy-pr-bot

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@HuiyingLi
Contributor

/ok to test 3990e0c

DeepSeek-V4 HF safetensors emit MTP layer keys in two forms:

* ``model.layers.{N+k}.*`` for the standard self_attn / mlp / norms
  (carries the canonical ``model.`` prefix like every backbone block).
* ``layers.{N+k}.*`` for V4's MTP-only fusion modules (``eh_proj``,
  ``enorm``, ``hnorm``, ``final_layernorm``) which sit outside the
  HF ``model.`` namespace.

The previous split regex (``r"^layers\.(\d+)\."``) only matched the
unprefixed form, so the prefixed self_attn / mlp / norms keys silently
fell into the backbone bucket. They were then renamed by the standard
backbone pipeline and ended up at ``model.layers.{N+k}.*`` in the
converted state dict — but the model only has ``model.layers.{0..N-1}``,
so DCP load dropped them and ``model.mtp.layers[*].*`` started from
random init. End result: MTP-enabled training silently ran without
loading the MTP head weights from the HF checkpoint.

Repro on a tiny config (num_hidden_layers=2, num_nextn_predict_layers=1):

    Model expects 38 mtp.* state_dict keys
    adapter.from_hf produced  4 mtp.* keys (the 4 unprefixed fusion ones)
    35 mtp.* keys MISSING, 24 keys leaked to model.layers.2.* (dropped)

Make the regex prefix-tolerant (``^(model\.)?layers\.(\d+)\.``) and use
the second capture group as the layer index. After the fix, the same
repro produces 0 missing / 0 extra, and a save→load round-trip via
to_hf -> from_hf reconstructs every mtp.* key the model exposes.

Add a regression test ``test_from_hf_renames_mtp_layer_with_model_prefix``
that exercises the prefixed form so this cannot silently regress again.
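
A minimal demonstration of the regex change (both patterns are quoted from this message; the driver loop is ours):

    import re

    old = re.compile(r"^layers\.(\d+)\.")            # pre-fix: misses the prefixed form
    new = re.compile(r"^(model\.)?layers\.(\d+)\.")  # prefix-tolerant replacement

    for key in ("layers.2.eh_proj.weight",                   # unprefixed fusion key
                "model.layers.2.self_attn.q_proj.weight"):   # prefixed backbone-style key
        hit = "hit" if old.match(key) else "miss"
        # the layer index now lives in the second capture group
        print(f"{key}: old={hit}, new idx={new.match(key).group(2)}")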

Signed-off-by: khazic <khazzz1c@gmail.com>
@HuiyingLi
Contributor

/ok to test c228ec4

@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the waiting-on-customer (Waiting on the original author to respond) label May 10, 2026
HuiyingLi pushed a commit that referenced this pull request May 10, 2026
…32 (#2201)

Two precision issues that compound across 61 layers and degrade
backbone parity vs reference (observed during MTP parity testing
in #2191):

1. sqrtsoftplus Gate cast routing scores back to bf16 immediately
   after computing sqrt(softplus(x.float())), losing precision for
   expert selection. The HashGate counterpart stays in fp32. Remove
   the .to(scores.dtype) cast so non-hash layers match.

2. eager_attention_with_sink ran softmax in the input dtype (bf16
   under autocast). Force fp32 softmax for numerical stability,
   matching standard practice.

Also fix a stale docstring claiming compress-ratio attention was
not yet implemented — it has been wired in.

Signed-off-by: khazic <khazzz1c@gmail.com>
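
An illustrative form of fix (2) from the commit above; fix (1) is the complementary change of not casting the gate's routing scores back down (the function name here is ours, not the repo's):

    import torch

    def softmax_fp32(scores: torch.Tensor) -> torch.Tensor:
        # Run the softmax in fp32 even when scores arrive as bf16 under
        # autocast, then cast back for the value matmul.
        return torch.softmax(scores.float(), dim=-1).to(scores.dtype)

    scores = torch.randn(2, 4, 8, 8, dtype=torch.bfloat16)
    assert softmax_fp32(scores).dtype == torch.bfloat16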
@hemildesai
Contributor

/claude review

Comment thread on nemo_automodel/recipes/llm/train_ft.py (outdated)
Comment on lines 166 to 197
def _mtp_is_enabled(cfg, model_parts) -> bool:
    """Return True if Multi-Token Prediction is enabled for this run.

    Checks both signals because either may be missing depending on how the
    model was constructed:

    * YAML override / explicit DeepseekV4Config: the
      ``model.config.num_nextn_predict_layers`` field is the user-facing
      knob and is present on the cfg before any model is built.
    * Constructed model: V4's ``ForCausalLM.__init__`` materializes
      ``self.mtp_config``. Walking ``modules()`` catches it on the root
      or on any submodule that retained the attribute after wrapping.

    The module walk alone isn't sufficient: pipeline-parallel wrapping can
    replace the V4 root with a stage container that no longer exposes
    ``mtp_config``, in which case only the cfg lookup catches MTP.
    """
    n = int(cfg.get("model.config.num_nextn_predict_layers", 0) or 0)
    if n > 0:
        return True
    for mp in model_parts:
        if mp is None:
            continue
        for sub in mp.modules():
            mc = getattr(sub, "mtp_config", None)
            if mc is not None and getattr(mc, "enabled", False):
                return True
    return False


def build_model(
    cfg_model,
Contributor


_mtp_is_enabled is defined but never called anywhere in the codebase. If this is scaffolding for a follow-up PR, consider deferring it to that PR to avoid dead code on main.

claude[bot]
claude Bot previously approved these changes May 11, 2026
Contributor

@claude claude Bot left a comment


LGTM — clean refactor from hardcoded DSV4 pipeline checks to a hook-based system, with comprehensive MTP support wired through state dict, forward, loss, and PP paths. Test coverage is thorough.

One note: _mtp_is_enabled (train_ft.py:166-197) appears unused — flagged inline.

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor

/ok to test 95cde23

HuiyingLi added 2 commits May 11, 2026 17:19
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor

/claude review

@HuiyingLi
Contributor

/ok to test 52fe6cd

Comment on lines +204 to +239
            future-token embedding (typically the model's input embedding
            layer).
        position_ids: Position ids matching ``input_ids``. When supplied,
            rolled cumulatively per depth in lockstep with ``input_ids``
            (so slot ``t`` carries the original position of the rolled
            token) and forwarded to each sublayer via ``block_kwargs``.
            Required for RoPE-using sublayers; ignored by sublayers that
            don't consume it.
        **block_kwargs: Forwarded to each sublayer's ``__call__`` (e.g.
            ``attention_mask``).

    Returns:
        List of length ``num_depths`` containing the hidden state
        produced at each depth.
    """
    D = self.num_depths
    P = self.pattern_length
    per_depth_h: list[torch.Tensor] = []
    cur_input_ids = input_ids
    cur_position_ids = position_ids
    for d in range(D):
        cur_input_ids = roll_tensor(cur_input_ids, shifts=-1, dim=-1)
        if cur_position_ids is not None:
            cur_position_ids = roll_tensor(cur_position_ids, shifts=-1, dim=-1)

        decoder_input = embed_fn(cur_input_ids)
        for s in range(P):
            sublayer = self.layers[d * P + s]
            kwargs = dict(block_kwargs)
            if cur_position_ids is not None:
                kwargs["position_ids"] = cur_position_ids
            if s == 0:
                kwargs["embed_input"] = decoder_input
            hidden_states = sublayer(hidden_states, **kwargs)
        per_depth_h.append(hidden_states)
    return per_depth_h
Contributor


nit: MTPModule, MTPConfig, roll_tensor, and parse_mtp_layer_pattern are added here but have no direct unit tests in this PR. The DSV4 tests exercise MTPConfig and roll_tensor indirectly (via build_mtp_config_from_hf and DeepseekV4MTPModule), but MTPModule itself and parse_mtp_layer_pattern have zero coverage.

The PR description notes that PR #2161 (Nemotron V3) shares this code and will presumably add tests — just flagging in case that PR is delayed or the merge order changes.
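
If the sibling PR slips, a direct test along these lines would close the roll_tensor gap. The import path is inferred from the PR description, and the left-shift invariant asserted here is the one MTPModule.forward relies on; whether roll_tensor zero-fills the vacated last slot (or returns extra values) is an assumption we deliberately leave unasserted:

    import torch
    from nemo_automodel.components.models.common.mtp import roll_tensor  # path per PR description

    def test_roll_tensor_shifts_left():
        x = torch.arange(8).reshape(2, 4)
        y = roll_tensor(x, shifts=-1, dim=-1)
        assert y.shape == x.shape
        # slot t must carry the original token from slot t+1; the fill of
        # the last slot is implementation-defined, so it is not checked
        assert torch.equal(y[..., :-1], x[..., 1:])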

Comment on lines +920 to +930
logits = getattr(output, "logits", output)
hidden_states = get_final_hidden_states(output)
mtp_per_depth_h = getattr(output, "mtp_per_depth_h", None)
scaling_factor = getattr(output, "mtp_loss_scaling_factor", get_mtp_loss_scaling_factor(self.model))

if isinstance(output, tuple):
    logits = output[0]
    hidden_states = None
    mtp_per_depth_h = list(output[1:]) if len(output) > 1 else None
    scaling_factor = get_mtp_loss_scaling_factor(self.model)

Contributor


The getattr extractions on lines 920-923 are computed unconditionally, then fully overwritten when output is a tuple (lines 925-929). This isn't a bug (all the getattr calls are safe on tuples), but the ordering is misleading — a reader might think the pre-tuple values feed into the tuple branch.

Consider checking isinstance(output, tuple) first:

def forward(self, output, labels: torch.Tensor) -> torch.Tensor:
    if isinstance(output, tuple):
        logits = output[0]
        hidden_states = None
        mtp_per_depth_h = list(output[1:]) if len(output) > 1 else None
        scaling_factor = get_mtp_loss_scaling_factor(self.model)
    else:
        logits = getattr(output, "logits", output)
        hidden_states = get_final_hidden_states(output)
        mtp_per_depth_h = getattr(output, "mtp_per_depth_h", None)
        scaling_factor = getattr(output, "mtp_loss_scaling_factor", get_mtp_loss_scaling_factor(self.model))
    ...

Contributor

@claude claude Bot left a comment


Review Summary

Solid PR — well-structured separation between the model-agnostic MTP scaffold (common/mtp/) and the DSV4-specific implementation, with thorough test coverage on the DSV4 side.

What looks good

  • State dict adapter: Handles all three key formats (native mtp.{k}.*, legacy layers.{N+k}.*, prefixed model.layers.{N+k}.*) and correctly drops MTP keys when num_nextn_predict_layers=0. Round-trip tests cover FP8 dequantize, expert aggregation, and the fusion-only fallback.
  • PP refactoring: Replacing hardcoded model_type == "deepseek_v4" checks in functional.py with model-provided hooks (get_pipeline_stage_metas, customize_pipeline_stage_modules) is a clean generalization.
  • Rotary sharing: object.__setattr__ for _rotary_emb / _rotary_emb_compress correctly avoids polluting state_dict — tested explicitly.
  • MoE parallelizer update: _iter_transformer_and_mtp_blocks and _get_moe_module cleanly extend EP/FSDP/checkpointing to MTP blocks without duplicating the traversal logic.

Minor observations (non-blocking)

Left two inline comments:

  1. PipelineCausalLMLoss.forward ordering — the isinstance(output, tuple) check should come first to avoid computing values that are immediately overwritten.
  2. Generic MTPModule test coverage: MTPModule, roll_tensor, and parse_mtp_layer_pattern in common/mtp/ have no direct tests. They're exercised indirectly through DSV4 tests and presumably covered by the sibling Nemotron-V3 PR (#2161).

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
@HuiyingLi
Contributor

/ok to test a4963b2
