
Commit 7c5a5b3

Add DeepSeek V4 MTP pipeline support
- Keep DSV4 MTP, HC head, and rotary-compress dependencies on the right PP stages
- Propagate shifted MTP embeddings through PP stages to the final stage
- Allow DSV4 MTP to consume precomputed embeddings
- Wire PP loss handling for MTP auxiliary CE with configurable scaling
- Enable MTP in the DSV4 flash finetuning recipe with scaling factor 0.1
- Add unit coverage for DSV4 MTP PP behavior and stage metadata

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
1 parent: c985123

10 files changed: 490 additions & 307 deletions


examples/llm_finetune/deepseek_v4/deepseek_v4_flash_hellaswag.yaml

Lines changed: 4 additions & 1 deletion
@@ -61,9 +61,12 @@ model:
 _target_: nemo_automodel.components.models.deepseek_v4.config.DeepseekV4Config.from_pretrained
 pretrained_model_name_or_path: deepseek-ai/DeepSeek-V4-Flash
 name_or_path: deepseek-ai/DeepSeek-V4-Flash
-num_nextn_predict_layers: 0
+num_nextn_predict_layers: 1
 trust_remote_code: false
 load_base_model: true
+# DeepSeek-V4 uses 0.3 for most pretraining, then 0.1 during LR decay.
+# Keep finetuning/RL conservative unless explicitly reproducing pretraining.
+mtp_loss_scaling_factor: 0.1
 backend:
 _target_: nemo_automodel.components.models.common.BackendConfig
 attn: sdpa

nemo_automodel/components/models/common/mtp/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -27,13 +27,15 @@
 from nemo_automodel.components.models.common.mtp.mtp import (
     MTPConfig,
     MTPModule,
+    get_mtp_loss_scaling_factor,
     parse_mtp_layer_pattern,
     roll_tensor,
 )

 __all__ = [
     "MTPConfig",
     "MTPModule",
+    "get_mtp_loss_scaling_factor",
     "parse_mtp_layer_pattern",
     "roll_tensor",
 ]

nemo_automodel/components/models/common/mtp/mtp.py

Lines changed: 8 additions & 0 deletions
@@ -86,6 +86,14 @@ def roll_tensor(t: torch.Tensor, shifts: int = -1, dim: int = -1) -> torch.Tensor:
     return rolled


+def get_mtp_loss_scaling_factor(model: nn.Module, default: float = 0.1) -> float:
+    """Return the model's configured MTP auxiliary-loss scaling factor."""
+    mtp_config = getattr(model, "mtp_config", None)
+    if mtp_config is not None:
+        return float(getattr(mtp_config, "loss_scaling_factor", default))
+    return default
+
+
 @dataclass
 class MTPConfig:
     """Runtime configuration for the MTP block.

nemo_automodel/components/models/deepseek_v4/model.py

Lines changed: 119 additions & 10 deletions
@@ -599,15 +599,104 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    def customize_pipeline_stage_modules(
+        self,
+        module_names_per_stage: list[list[str]],
+        *,
+        layers_prefix: str,
+        text_model: nn.Module | None = None,
+    ) -> list[list[str]]:
+        """Keep DSV4 non-layer PP dependencies with the stages that need them."""
+
+        text_model = text_model or self.model
+        stage_modules = [list(modules) for modules in module_names_per_stage]
+
+        def append_once(modules: list[str], fqn: str) -> None:
+            if fqn not in modules:
+                modules.append(fqn)
+
+        if getattr(text_model, "rotary_emb_compress", None) is not None:
+            for modules in stage_modules:
+                append_once(modules, f"{layers_prefix}rotary_emb_compress")
+        if getattr(text_model, "hc_head", None) is not None:
+            append_once(stage_modules[-1], f"{layers_prefix}hc_head")
+        if self.mtp is not None:
+            append_once(stage_modules[-1], "mtp")
+
+        return stage_modules
+
+    def get_pipeline_stage_metas(
+        self,
+        *,
+        is_first: bool,
+        microbatch_size: int,
+        seq_len: int,
+        dtype: torch.dtype,
+    ) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
+        """Return PP input/output meta tensors for DSV4's HC and MTP contract."""
+
+        hidden_shape = (microbatch_size, seq_len, self.config.hidden_size)
+        hc_hidden_shape = (microbatch_size, seq_len, self.config.hc_mult, self.config.hidden_size)
+        mtp_depth = int(getattr(self.mtp_config, "num_layers", 0) or 0)
+
+        def meta(shape: tuple[int, ...]) -> torch.Tensor:
+            return torch.empty(*shape, device="meta", dtype=dtype)
+
+        def append_mtp_metas(primary: torch.Tensor) -> tuple[torch.Tensor, ...]:
+            mtp_metas = (meta(hidden_shape) for _ in range(mtp_depth))
+            return (primary, *mtp_metas)
+
+        if is_first:
+            inputs_meta = (torch.empty(microbatch_size, seq_len, device="meta", dtype=torch.long),)
+        else:
+            inputs_meta = append_mtp_metas(meta(hc_hidden_shape if self.config.hc_mult > 1 else hidden_shape))
+
+        if self.lm_head is not None:
+            output_meta = meta((microbatch_size, seq_len, self.config.vocab_size))
+        elif getattr(self.model, "norm", None) is not None:
+            output_meta = meta(hidden_shape)
+        else:
+            output_meta = meta(hc_hidden_shape if self.config.hc_mult > 1 else hidden_shape)
+
+        return inputs_meta, append_mtp_metas(output_meta)
+
+    def _is_pipeline_parallel_stage(self) -> bool:
+        if self.lm_head is None:
+            return True
+        if getattr(self.model, "embed_tokens", None) is None:
+            return True
+        try:
+            return len(self.model.layers) != int(self.config.num_hidden_layers)
+        except TypeError:
+            return False
+
+    def _build_mtp_embed_inputs_for_pp(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        if getattr(self.model, "embed_tokens", None) is None:
+            raise ValueError("First PP stage must own embed_tokens to build MTP embeddings")
+        if input_ids.dtype not in (torch.int32, torch.int64, torch.long):
+            raise ValueError("First PP stage must receive token ids to build MTP embeddings")
+
+        from nemo_automodel.components.models.common.mtp import roll_tensor  # noqa: PLC0415
+
+        cur_input_ids = input_ids
+        embeds = []
+        for _ in range(self.mtp_config.num_layers):
+            cur_input_ids = roll_tensor(cur_input_ids, shifts=-1, dim=-1)
+            embeds.append(self.model.embed_tokens(cur_input_ids))
+        return tuple(embeds)
+
     def forward(
         self,
         input_ids: torch.Tensor,
-        *,
+        *mtp_embed_inputs: torch.Tensor,
         position_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
         padding_mask: torch.Tensor | None = None,
         **attn_kwargs: Any,
-    ) -> "DeepseekV4CausalLMOutput":
+    ) -> "DeepseekV4CausalLMOutput" | tuple[torch.Tensor, ...] | torch.Tensor:
+        is_pp_stage = self._is_pipeline_parallel_stage()
+        pp_mtp_enabled = is_pp_stage and self.mtp_config.enabled
+
         thd_mode = "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd"
         if thd_mode:
             input_ids, position_ids, padding_mask, attn_kwargs = squeeze_input_for_thd(
@@ -633,8 +722,15 @@ def forward(
         if thd_mode:
             logits = logits.unsqueeze(0)

+        if pp_mtp_enabled and self.lm_head is None:
+            if not mtp_embed_inputs:
+                mtp_embed_inputs = self._build_mtp_embed_inputs_for_pp(input_ids)
+            return (logits, *mtp_embed_inputs)
+
         mtp_per_depth_h = None
         if use_mtp:
+            if is_pp_stage and not mtp_embed_inputs:
+                raise ValueError("Final PP stage requires propagated MTP embeddings")
             # MTP consumes the pre-final-head HC stream [B, S, hc_mult, hidden]
             # and returns collapsed per-depth [B, S, hidden] tensors for CE.
             seq_len = hidden_states.shape[1]
@@ -650,14 +746,27 @@ def forward(
                 batch_size=batch_size,
                 sliding_window=sliding_window,
             )
-            mtp_per_depth_h = self.mtp(
-                input_ids=input_ids,
-                hidden_states=mtp_hc_hidden,
-                embed_fn=self.model.embed_tokens,
-                position_ids=position_ids,
-                attention_mask=mtp_attn_mask,
-                padding_mask=padding_mask,
-            )
+            mtp_kwargs = {
+                "hidden_states": mtp_hc_hidden,
+                "position_ids": position_ids,
+                "attention_mask": mtp_attn_mask,
+                "padding_mask": padding_mask,
+            }
+            if mtp_embed_inputs:
+                mtp_kwargs["embed_inputs"] = tuple(mtp_embed_inputs)
+            else:
+                mtp_kwargs["input_ids"] = input_ids
+                mtp_kwargs["embed_fn"] = self.model.embed_tokens
+            mtp_per_depth_h = self.mtp(**mtp_kwargs)
+        elif pp_mtp_enabled and self.lm_head is not None:
+            mtp_per_depth_h = [hidden_states.new_empty(hidden_states.shape) for _ in range(self.mtp_config.num_layers)]
+
+        if is_pp_stage:
+            if pp_mtp_enabled:
+                if self.training and self.mtp is None:
+                    raise ValueError("Final PP stage has MTP enabled but does not own the MTP module")
+                return (logits, *mtp_per_depth_h)
+            return logits

         return DeepseekV4CausalLMOutput(
             logits=logits,
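For orientation only, here is a sketch of how a loss wrapper might consume the final-stage tuple (logits, *mtp_per_depth_h). The function, the per-depth label rolling, and the scale argument are illustrative assumptions rather than the loss code this commit adds, and masking of the positions that wrap around is omitted:

import torch
import torch.nn.functional as F


def combined_mtp_loss(logits, mtp_per_depth_h, lm_head_weight, labels, scale=0.1):
    # Main next-token CE on the final-stage logits.
    main = F.cross_entropy(logits.flatten(0, 1), labels.flatten())
    aux_terms = []
    for depth, h in enumerate(mtp_per_depth_h):
        # Project each collapsed per-depth hidden stream with the shared LM head.
        depth_logits = h @ lm_head_weight.t()
        # Assumed target shift: depth d predicts one extra step ahead (wrap-around not masked here).
        depth_labels = torch.roll(labels, shifts=-(depth + 1), dims=-1)
        aux_terms.append(F.cross_entropy(depth_logits.flatten(0, 1), depth_labels.flatten()))
    if not aux_terms:
        return main
    return main + scale * torch.stack(aux_terms).mean()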

nemo_automodel/components/models/deepseek_v4/mtp.py

Lines changed: 16 additions & 5 deletions
@@ -225,17 +225,28 @@ def num_depths(self) -> int:

     def forward(
         self,
-        input_ids: torch.LongTensor,
         hidden_states: torch.Tensor,
-        embed_fn,
+        input_ids: torch.LongTensor | None = None,
+        embed_fn=None,
+        embed_inputs: tuple[torch.Tensor, ...] | list[torch.Tensor] | None = None,
         position_ids: torch.LongTensor | None = None,
         **block_kwargs,
     ) -> list[torch.Tensor]:
         per_depth_h: list[torch.Tensor] = []
         cur_input_ids = input_ids
-        for block in self.layers:
-            cur_input_ids = roll_tensor(cur_input_ids, shifts=-1, dim=-1)
-            decoder_input = embed_fn(cur_input_ids)
+        if embed_inputs is not None and len(embed_inputs) != len(self.layers):
+            raise ValueError(
+                f"Expected {len(self.layers)} MTP embedding tensors, got {len(embed_inputs)}"
+            )
+        if embed_inputs is None and (cur_input_ids is None or embed_fn is None):
+            raise ValueError("MTP requires either embed_inputs or both input_ids and embed_fn")
+
+        for depth, block in enumerate(self.layers):
+            if embed_inputs is None:
+                cur_input_ids = roll_tensor(cur_input_ids, shifts=-1, dim=-1)
+                decoder_input = embed_fn(cur_input_ids)
+            else:
+                decoder_input = embed_inputs[depth]
             kwargs = dict(block_kwargs)
             if position_ids is not None:
                 kwargs["position_ids"] = position_ids

nemo_automodel/components/utils/model_utils.py

Lines changed: 24 additions & 0 deletions
@@ -125,6 +125,30 @@ def filter_forward_kwargs(model: nn.Module, kwargs: dict) -> dict:
     return filtered


+def get_lm_head_module(model: nn.Module) -> nn.Module | None:
+    """Return the model's LM head module, if one can be found."""
+    if hasattr(model, "get_output_embeddings"):
+        lm_head = model.get_output_embeddings()
+        if lm_head is not None:
+            return lm_head
+    for name, module in model.named_modules():
+        if (name == "lm_head" or name.endswith(".lm_head")) and hasattr(module, "weight"):
+            return module
+    return None
+
+
+def get_lm_head_weight(model: nn.Module) -> torch.Tensor:
+    """Return the model's LM-head weight, materializing DTensor weights when needed."""
+    lm_head = get_lm_head_module(model)
+    if lm_head is not None:
+        weight = lm_head.weight
+        return weight.full_tensor() if hasattr(weight, "full_tensor") else weight
+    for name, param in model.named_parameters(remove_duplicate=False):
+        if "lm_head" in name and name.endswith(".weight"):
+            return param.full_tensor() if hasattr(param, "full_tensor") else param
+    raise ValueError("lm_head.weight not found in model")
+
+
 def _get_logical_numel(param) -> int:
     """Return the logical number of elements for a parameter,
     accounting for quantized (packed) storage.
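A minimal usage sketch of the two new helpers; TinyLM is a made-up module used only for illustration:

import torch.nn as nn

from nemo_automodel.components.utils.model_utils import get_lm_head_module, get_lm_head_weight


class TinyLM(nn.Module):
    def __init__(self, vocab: int = 8, hidden: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


model = TinyLM()
assert get_lm_head_module(model) is model.lm_head  # found via the named_modules scan
assert get_lm_head_weight(model).shape == (8, 4)   # plain tensor here; DTensor weights are gathered via full_tensor()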
