 # limitations under the License.
 
 import copy
+import inspect
 import logging
 import math
 import os
@@ ... @@
 logger = logging.getLogger(__name__)
 
 
+def _get_optional_hook(module: object, name: str) -> Callable | None:
+    try:
+        inspect.getattr_static(module, name)
+    except AttributeError:
+        return None
+    hook = getattr(module, name)
+    return hook if callable(hook) else None
+
+
 class ParallelizeFnProtocol(Protocol):
     def __call__(
         self,
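Usage note: _get_optional_hook looks the attribute up statically first, so a module that simply does not define the hook falls through to None without invoking any custom __getattr__ logic, and a same-named attribute that is not callable is likewise treated as "no hook". A minimal illustration with made-up module classes (names are hypothetical, not part of this change):

import torch.nn as nn


class PlainBlock(nn.Module):
    # Defines no pipeline hooks at all.
    pass


class BlockWithHook(nn.Module):
    # Opts in by defining the hook as a regular method.
    def get_pipeline_stage_metas(self, **kwargs):
        raise NotImplementedError


assert _get_optional_hook(PlainBlock(), "get_pipeline_stage_metas") is None
assert _get_optional_hook(BlockWithHook(), "get_pipeline_stage_metas") is not None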
@@ -284,52 +294,43 @@ def _precompute_stage_shapes(
     """
     hidden_size, vocab_size = _get_hidden_and_vocab_size(model_config)
 
-    # DeepSeek V4 preserves an extra hc_mult axis between blocks, so inter-stage
-    # hidden state is [mb, seq, hc_mult, dim] until the last (norm) stage folds
-    # it back to [mb, seq, dim].
-    is_v4 = (
-        getattr(model_config, "model_type", None) == "deepseek_v4"
-        or getattr(getattr(model_config, "text_config", None), "model_type", None) == "deepseek_v4"
-    )
-    hc_mult = int(getattr(model_config, "hc_mult", 1) or 1) if is_v4 else 1
-
     for stage in stages:
         # Infer the computation dtype from the stage's parameters
         try:
             model_dtype = next(stage.submod.parameters()).dtype
         except StopIteration:
             model_dtype = torch.bfloat16
 
-        inner_submod = getattr(stage.submod, "model", stage.submod)
-        stage_has_norm = getattr(inner_submod, "norm", None) is not None
+        get_stage_metas = _get_optional_hook(stage.submod, "get_pipeline_stage_metas")
+        if get_stage_metas is not None:
+            stage.inputs_meta, outputs_meta = get_stage_metas(
+                is_first=stage.is_first,
+                microbatch_size=microbatch_size,
+                seq_len=seq_len,
+                dtype=model_dtype,
+            )
+            stage._configure_outputs_meta(outputs_meta)
+            continue
 
         # --- inputs_meta ---
         if stage.is_first:
             # First stage receives input_ids: [mb, seq_len] int64
             stage.inputs_meta = (torch.empty(microbatch_size, seq_len, device="meta", dtype=torch.long),)
         else:
-            if hc_mult > 1:
-                stage.inputs_meta = (
-                    torch.empty(microbatch_size, seq_len, hc_mult, hidden_size, device="meta", dtype=model_dtype),
-                )
-            else:
-                stage.inputs_meta = (
-                    torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype),
-                )
+            stage.inputs_meta = (
+                torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype),
+            )
 
         # --- outputs_meta ---
         has_lm_head = hasattr(stage.submod, "lm_head") and stage.submod.lm_head is not None
         if has_lm_head:
             # Last stage with lm_head produces logits: [mb, seq_len, vocab_size]
-            outputs_meta = (torch.empty(microbatch_size, seq_len, vocab_size, device="meta", dtype=model_dtype),)
-        elif hc_mult > 1 and not stage_has_norm:
-            # V4 mid-pipeline: tensor still carries the hc_mult axis.
-            outputs_meta = (
-                torch.empty(microbatch_size, seq_len, hc_mult, hidden_size, device="meta", dtype=model_dtype),
+            primary_output_meta = torch.empty(
+                microbatch_size, seq_len, vocab_size, device="meta", dtype=model_dtype
             )
         else:
-            # Standard intermediate stage (or V4 final-norm stage without lm_head).
-            outputs_meta = (torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype),)
+            primary_output_meta = torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype)
+        outputs_meta = (primary_output_meta,)
         stage._configure_outputs_meta(outputs_meta)
 
     logger.info(
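For models whose inter-stage tensors do not match the default [mb, seq, hidden] shape, the stage submodule can now implement get_pipeline_stage_metas itself instead of relying on model-specific branches in this function. A minimal sketch, assuming a hypothetical stage module that carries an extra hc_mult axis between stages; the keyword arguments and the (inputs_meta, outputs_meta) return pair follow the call site above, everything else is illustrative:

import torch
import torch.nn as nn


class HypotheticalStageModule(nn.Module):
    def __init__(self, hidden_size: int, hc_mult: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.hc_mult = hc_mult

    def get_pipeline_stage_metas(self, *, is_first: bool, microbatch_size: int, seq_len: int, dtype: torch.dtype):
        # First stage still consumes token ids; later stages receive the widened hidden state.
        if is_first:
            inputs_meta = (torch.empty(microbatch_size, seq_len, device="meta", dtype=torch.long),)
        else:
            inputs_meta = (
                torch.empty(microbatch_size, seq_len, self.hc_mult, self.hidden_size, device="meta", dtype=dtype),
            )
        # The output keeps the extra axis so the next stage's inputs_meta matches it.
        outputs_meta = (
            torch.empty(microbatch_size, seq_len, self.hc_mult, self.hidden_size, device="meta", dtype=dtype),
        )
        return inputs_meta, outputs_meta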
@@ -475,12 +476,6 @@ def split_model_into_stages(
     else:
         lm_head_fqn = "lm_head"
 
-    # DeepSeek V4: model carries an extra compressor-rotary module on every stage
-    # and an HC head on the last stage; both must survive PP module pruning.
-    is_v4_keep = getattr(getattr(model, "config", None), "model_type", None) == "deepseek_v4"
-    has_rotary_emb_compress = is_v4_keep and hasattr(text_model, "rotary_emb_compress")
-    has_hc_head = is_v4_keep and hasattr(text_model, "hc_head")
-
     # Auto-generate module split if not provided
     if module_names_per_stage is None:
         module_names_per_stage = generate_hf_model_fqn_per_model_part(
@@ -495,13 +490,15 @@ def split_model_into_stages(
             lm_head_fqn=lm_head_fqn,
         )
 
-    # V4 post-processing: keep the compressor rotary on every stage and the
-    # HC head on the last stage so the V4 PP forward can run end-to-end.
-    if has_rotary_emb_compress:
-        for stage_modules in module_names_per_stage:
-            stage_modules.append(f"{layers_prefix}rotary_emb_compress")
-    if has_hc_head:
-        module_names_per_stage[-1].append(f"{layers_prefix}hc_head")
+    customize_stage_modules = _get_optional_hook(model, "customize_pipeline_stage_modules")
+    if customize_stage_modules is not None:
+        custom_module_names = customize_stage_modules(
+            module_names_per_stage,
+            layers_prefix=layers_prefix,
+            text_model=text_model,
+        )
+        if custom_module_names is not None:
+            module_names_per_stage = custom_module_names
 
     def _build_stage_from_modules(
         stage_idx: int, module_names: list[str], num_stages: int
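Correspondingly, per-stage module retention can now be customized by the model through customize_pipeline_stage_modules rather than model-type checks in this function. A minimal sketch on a hypothetical model class, reproducing what the removed DeepSeek-V4 branch did; the positional argument and the layers_prefix/text_model keywords follow the call site above, and returning None keeps the auto-generated split unchanged:

import torch.nn as nn


class HypotheticalModelForCausalLM(nn.Module):
    def customize_pipeline_stage_modules(self, module_names_per_stage, *, layers_prefix, text_model):
        # Keep an auxiliary rotary module on every stage and an extra head on the last stage
        # so those submodules survive pipeline-parallel module pruning.
        if hasattr(text_model, "rotary_emb_compress"):
            for stage_modules in module_names_per_stage:
                stage_modules.append(f"{layers_prefix}rotary_emb_compress")
        if hasattr(text_model, "hc_head"):
            module_names_per_stage[-1].append(f"{layers_prefix}hc_head")
        return module_names_per_stage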