@@ -599,15 +599,104 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
+    def customize_pipeline_stage_modules(
+        self,
+        module_names_per_stage: list[list[str]],
+        *,
+        layers_prefix: str,
+        text_model: nn.Module | None = None,
+    ) -> list[list[str]]:
+        """Keep DSV4 non-layer PP dependencies with the stages that need them."""
+
+        text_model = text_model or self.model
+        stage_modules = [list(modules) for modules in module_names_per_stage]
+
+        def append_once(modules: list[str], fqn: str) -> None:
+            if fqn not in modules:
+                modules.append(fqn)
+
+        if getattr(text_model, "rotary_emb_compress", None) is not None:
+            for modules in stage_modules:
+                append_once(modules, f"{layers_prefix}rotary_emb_compress")
+        if getattr(text_model, "hc_head", None) is not None:
+            append_once(stage_modules[-1], f"{layers_prefix}hc_head")
+        if self.mtp is not None:
+            append_once(stage_modules[-1], "mtp")
+
+        return stage_modules
+
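
For orientation, a minimal usage sketch of the hook above. The two-stage split, the `model.` prefix, and the bound variable `causal_lm` are hypothetical, not taken from an actual NeMo Automodel config:

```python
# Hypothetical two-stage split; only customize_pipeline_stage_modules
# comes from the patch above.
stages = [
    ["model.embed_tokens", "model.layers.0", "model.layers.1"],
    ["model.layers.2", "model.layers.3", "model.norm", "lm_head"],
]
stages = causal_lm.customize_pipeline_stage_modules(stages, layers_prefix="model.")
# If the text model owns rotary_emb_compress, it is now appended to *every*
# stage; hc_head and mtp (when present) are pinned to the final stage only.
```
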
+    def get_pipeline_stage_metas(
+        self,
+        *,
+        is_first: bool,
+        microbatch_size: int,
+        seq_len: int,
+        dtype: torch.dtype,
+    ) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
+        """Return PP input/output meta tensors for DSV4's HC and MTP contract."""
+
+        hidden_shape = (microbatch_size, seq_len, self.config.hidden_size)
+        hc_hidden_shape = (microbatch_size, seq_len, self.config.hc_mult, self.config.hidden_size)
+        mtp_depth = int(getattr(self.mtp_config, "num_layers", 0) or 0)
+
+        def meta(shape: tuple[int, ...]) -> torch.Tensor:
+            return torch.empty(*shape, device="meta", dtype=dtype)
+
+        def append_mtp_metas(primary: torch.Tensor) -> tuple[torch.Tensor, ...]:
+            mtp_metas = (meta(hidden_shape) for _ in range(mtp_depth))
+            return (primary, *mtp_metas)
+
+        if is_first:
+            inputs_meta = (torch.empty(microbatch_size, seq_len, device="meta", dtype=torch.long),)
+        else:
+            inputs_meta = append_mtp_metas(meta(hc_hidden_shape if self.config.hc_mult > 1 else hidden_shape))
+
+        if self.lm_head is not None:
+            output_meta = meta((microbatch_size, seq_len, self.config.vocab_size))
+        elif getattr(self.model, "norm", None) is not None:
+            output_meta = meta(hidden_shape)
+        else:
+            output_meta = meta(hc_hidden_shape if self.config.hc_mult > 1 else hidden_shape)
+
+        return inputs_meta, append_mtp_metas(output_meta)
+
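
A sketch of the resulting meta contract for a non-first stage; the concrete sizes are made up and `causal_lm` is again a stand-in:

```python
import torch

# Assuming hc_mult > 1 and mtp_config.num_layers == 1 for this stage.
inputs_meta, outputs_meta = causal_lm.get_pipeline_stage_metas(
    is_first=False, microbatch_size=2, seq_len=128, dtype=torch.bfloat16
)
# inputs_meta : ([2, 128, hc_mult, hidden], [2, 128, hidden])  # HC stream + 1 MTP slot
# outputs_meta: (stage-dependent primary,   [2, 128, hidden])  # mirrors the MTP slots
# A first stage instead receives a single [2, 128] int64 token-id tensor.
```
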
+    def _is_pipeline_parallel_stage(self) -> bool:
+        if self.lm_head is None:
+            return True
+        if getattr(self.model, "embed_tokens", None) is None:
+            return True
+        try:
+            return len(self.model.layers) != int(self.config.num_hidden_layers)
+        except TypeError:
+            return False
+
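
The heuristic can be exercised on stand-in objects. A runnable sketch, assuming the enclosing class is named `DeepseekV4ForCausalLM` (inferred from `DeepseekV4CausalLMOutput`; the name is not shown in this hunk):

```python
from types import SimpleNamespace

import torch.nn as nn

def make_stage(n_layers: int, total: int, embed: bool, head: bool):
    """Stand-in exposing just the attributes the heuristic reads."""
    model = SimpleNamespace(
        embed_tokens=nn.Embedding(8, 8) if embed else None,
        layers=[nn.Identity()] * n_layers,
    )
    return SimpleNamespace(
        lm_head=nn.Linear(8, 8) if head else None,
        model=model,
        config=SimpleNamespace(num_hidden_layers=total),
    )

is_stage = DeepseekV4ForCausalLM._is_pipeline_parallel_stage
assert is_stage(make_stage(4, 4, embed=True, head=True)) is False  # whole model
assert is_stage(make_stage(2, 4, embed=True, head=True)) is True   # layer slice only
assert is_stage(make_stage(4, 4, embed=False, head=True)) is True  # no embed_tokens
```
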
+    def _build_mtp_embed_inputs_for_pp(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        if getattr(self.model, "embed_tokens", None) is None:
+            raise ValueError("First PP stage must own embed_tokens to build MTP embeddings")
+        if input_ids.dtype not in (torch.int32, torch.int64, torch.long):
+            raise ValueError("First PP stage must receive token ids to build MTP embeddings")
+
+        from nemo_automodel.components.models.common.mtp import roll_tensor  # noqa: PLC0415
+
+        cur_input_ids = input_ids
+        embeds = []
+        for _ in range(self.mtp_config.num_layers):
+            cur_input_ids = roll_tensor(cur_input_ids, shifts=-1, dim=-1)
+            embeds.append(self.model.embed_tokens(cur_input_ids))
+        return tuple(embeds)
+
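
The per-depth shift can be pictured with `torch.roll`; a sketch assuming `roll_tensor` behaves roll-like here (whether it zero-fills the wrapped tail is not visible in this hunk):

```python
import torch

ids = torch.tensor([[10, 11, 12, 13]])
depth1 = torch.roll(ids, shifts=-1, dims=-1)     # tensor([[11, 12, 13, 10]])
depth2 = torch.roll(depth1, shifts=-1, dims=-1)  # tensor([[12, 13, 10, 11]])
# Depth k embeds tokens k steps ahead of the current position, which is
# exactly the target each extra MTP head predicts.
```
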
     def forward(
         self,
         input_ids: torch.Tensor,
-        *,
+        *mtp_embed_inputs: torch.Tensor,
         position_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
         padding_mask: torch.Tensor | None = None,
         **attn_kwargs: Any,
-    ) -> "DeepseekV4CausalLMOutput":
+    ) -> "DeepseekV4CausalLMOutput" | tuple[torch.Tensor, ...] | torch.Tensor:
+        is_pp_stage = self._is_pipeline_parallel_stage()
+        pp_mtp_enabled = is_pp_stage and self.mtp_config.enabled
+
         thd_mode = "qkv_format" in attn_kwargs and attn_kwargs["qkv_format"] == "thd"
         if thd_mode:
             input_ids, position_ids, padding_mask, attn_kwargs = squeeze_input_for_thd(
@@ -633,8 +722,15 @@ def forward(
         if thd_mode:
             logits = logits.unsqueeze(0)
 
+        if pp_mtp_enabled and self.lm_head is None:
+            if not mtp_embed_inputs:
+                mtp_embed_inputs = self._build_mtp_embed_inputs_for_pp(input_ids)
+            return (logits, *mtp_embed_inputs)
+
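
A toy picture of the handoff this early return produces (shapes hypothetical; `logits` here is the un-projected stream, since this branch only runs when `lm_head` is None):

```python
import torch

B, S, H, DEPTH = 2, 8, 16, 1  # made-up sizes
stream = torch.zeros(B, S, H)
mtp_embeds = tuple(torch.zeros(B, S, H) for _ in range(DEPTH))
stage_output = (stream, *mtp_embeds)

# The next stage's forward(input_ids, *mtp_embed_inputs, ...) unpacks this as:
nxt_input, *nxt_mtp = stage_output
assert nxt_input.shape == (B, S, H) and len(nxt_mtp) == DEPTH
```
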
         mtp_per_depth_h = None
         if use_mtp:
+            if is_pp_stage and not mtp_embed_inputs:
+                raise ValueError("Final PP stage requires propagated MTP embeddings")
             # MTP consumes the pre-final-head HC stream [B, S, hc_mult, hidden]
             # and returns collapsed per-depth [B, S, hidden] tensors for CE.
             seq_len = hidden_states.shape[1]
@@ -650,14 +746,27 @@
                 batch_size=batch_size,
                 sliding_window=sliding_window,
             )
-            mtp_per_depth_h = self.mtp(
-                input_ids=input_ids,
-                hidden_states=mtp_hc_hidden,
-                embed_fn=self.model.embed_tokens,
-                position_ids=position_ids,
-                attention_mask=mtp_attn_mask,
-                padding_mask=padding_mask,
-            )
+            mtp_kwargs = {
+                "hidden_states": mtp_hc_hidden,
+                "position_ids": position_ids,
+                "attention_mask": mtp_attn_mask,
+                "padding_mask": padding_mask,
+            }
+            if mtp_embed_inputs:
+                mtp_kwargs["embed_inputs"] = tuple(mtp_embed_inputs)
+            else:
+                mtp_kwargs["input_ids"] = input_ids
+                mtp_kwargs["embed_fn"] = self.model.embed_tokens
+            mtp_per_depth_h = self.mtp(**mtp_kwargs)
+        elif pp_mtp_enabled and self.lm_head is not None:
+            mtp_per_depth_h = [hidden_states.new_empty(hidden_states.shape) for _ in range(self.mtp_config.num_layers)]
+
+        if is_pp_stage:
+            if pp_mtp_enabled:
+                if self.training and self.mtp is None:
+                    raise ValueError("Final PP stage has MTP enabled but does not own the MTP module")
+                return (logits, *mtp_per_depth_h)
+            return logits
 
         return DeepseekV4CausalLMOutput(
             logits=logits,
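
Taken together, the hooks let a PP splitter size its send/recv buffers without a real forward pass. A hypothetical driver loop; `causal_lm`, `initial_split`, and `build_stages` are stand-ins, not NeMo Automodel APIs:

```python
import torch

split = causal_lm.customize_pipeline_stage_modules(initial_split, layers_prefix="model.")
for rank, stage_module in enumerate(build_stages(split)):  # build_stages: assumed helper
    inputs_meta, outputs_meta = stage_module.get_pipeline_stage_metas(
        is_first=(rank == 0), microbatch_size=1, seq_len=4096, dtype=torch.bfloat16
    )
    # meta-device tensors carry only shape/dtype, so no activation memory is spent
```
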