Commit c985123

Add pipeline hooks for custom stage metadata
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
1 parent cbbb216 commit c985123

1 file changed: 36 additions & 39 deletions

nemo_automodel/components/distributed/pipelining/functional.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import copy
+import inspect
 import logging
 import math
 import os
@@ -42,6 +43,15 @@
 logger = logging.getLogger(__name__)
 
 
+def _get_optional_hook(module: object, name: str) -> Callable | None:
+    try:
+        inspect.getattr_static(module, name)
+    except AttributeError:
+        return None
+    hook = getattr(module, name)
+    return hook if callable(hook) else None
+
+
 class ParallelizeFnProtocol(Protocol):
     def __call__(
         self,
@@ -284,52 +294,43 @@ def _precompute_stage_shapes(
     """
     hidden_size, vocab_size = _get_hidden_and_vocab_size(model_config)
 
-    # DeepSeek V4 preserves an extra hc_mult axis between blocks, so inter-stage
-    # hidden state is [mb, seq, hc_mult, dim] until the last (norm) stage folds
-    # it back to [mb, seq, dim].
-    is_v4 = (
-        getattr(model_config, "model_type", None) == "deepseek_v4"
-        or getattr(getattr(model_config, "text_config", None), "model_type", None) == "deepseek_v4"
-    )
-    hc_mult = int(getattr(model_config, "hc_mult", 1) or 1) if is_v4 else 1
-
     for stage in stages:
         # Infer the computation dtype from the stage's parameters
         try:
             model_dtype = next(stage.submod.parameters()).dtype
         except StopIteration:
             model_dtype = torch.bfloat16
 
-        inner_submod = getattr(stage.submod, "model", stage.submod)
-        stage_has_norm = getattr(inner_submod, "norm", None) is not None
+        get_stage_metas = _get_optional_hook(stage.submod, "get_pipeline_stage_metas")
+        if get_stage_metas is not None:
+            stage.inputs_meta, outputs_meta = get_stage_metas(
+                is_first=stage.is_first,
+                microbatch_size=microbatch_size,
+                seq_len=seq_len,
+                dtype=model_dtype,
+            )
+            stage._configure_outputs_meta(outputs_meta)
+            continue
 
         # --- inputs_meta ---
         if stage.is_first:
             # First stage receives input_ids: [mb, seq_len] int64
             stage.inputs_meta = (torch.empty(microbatch_size, seq_len, device="meta", dtype=torch.long),)
         else:
-            if hc_mult > 1:
-                stage.inputs_meta = (
-                    torch.empty(microbatch_size, seq_len, hc_mult, hidden_size, device="meta", dtype=model_dtype),
-                )
-            else:
-                stage.inputs_meta = (
-                    torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype),
-                )
+            stage.inputs_meta = (
+                torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype),
+            )
 
         # --- outputs_meta ---
         has_lm_head = hasattr(stage.submod, "lm_head") and stage.submod.lm_head is not None
         if has_lm_head:
             # Last stage with lm_head produces logits: [mb, seq_len, vocab_size]
-            outputs_meta = (torch.empty(microbatch_size, seq_len, vocab_size, device="meta", dtype=model_dtype),)
-        elif hc_mult > 1 and not stage_has_norm:
-            # V4 mid-pipeline: tensor still carries the hc_mult axis.
-            outputs_meta = (
-                torch.empty(microbatch_size, seq_len, hc_mult, hidden_size, device="meta", dtype=model_dtype),
+            primary_output_meta = torch.empty(
+                microbatch_size, seq_len, vocab_size, device="meta", dtype=model_dtype
             )
         else:
-            # Standard intermediate stage (or V4 final-norm stage without lm_head).
-            outputs_meta = (torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype),)
+            primary_output_meta = torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype)
+        outputs_meta = (primary_output_meta,)
         stage._configure_outputs_meta(outputs_meta)
 
     logger.info(
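
With this hunk, a stage submodule can take over shape inference by exposing a get_pipeline_stage_metas method: the hook receives is_first, microbatch_size, seq_len, and dtype as keyword arguments and returns an (inputs_meta, outputs_meta) pair of meta-tensor tuples. The sketch below is not part of the commit; the class name and the extra hc_mult axis it carries between stages are illustrative assumptions, modelled on the DeepSeek V4 handling that this commit removes from the generic path.

# Illustrative sketch only: everything except the hook name and its call
# signature is a hypothetical example, not code from this repository.
import torch
import torch.nn as nn


class CustomStageModule(nn.Module):
    def __init__(self, hidden_size: int, hc_mult: int, is_last: bool):
        super().__init__()
        self.hidden_size = hidden_size
        self.hc_mult = hc_mult  # extra axis carried between pipeline stages
        self.is_last = is_last  # the last stage folds the axis back to [mb, seq, dim]

    def get_pipeline_stage_metas(self, *, is_first, microbatch_size, seq_len, dtype):
        def hidden_meta(keep_mult: bool) -> torch.Tensor:
            shape = (microbatch_size, seq_len)
            shape += (self.hc_mult, self.hidden_size) if keep_mult else (self.hidden_size,)
            return torch.empty(*shape, device="meta", dtype=dtype)

        if is_first:
            # The first stage still receives token ids: [mb, seq_len] int64.
            inputs_meta = (torch.empty(microbatch_size, seq_len, device="meta", dtype=torch.long),)
        else:
            inputs_meta = (hidden_meta(keep_mult=True),)

        # Mid-pipeline stages keep the extra axis; the final stage drops it.
        outputs_meta = (hidden_meta(keep_mult=not self.is_last),)
        return inputs_meta, outputs_meta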
@@ -475,12 +476,6 @@ def split_model_into_stages(
     else:
         lm_head_fqn = "lm_head"
 
-    # DeepSeek V4: model carries an extra compressor-rotary module on every stage
-    # and an HC head on the last stage; both must survive PP module pruning.
-    is_v4_keep = getattr(getattr(model, "config", None), "model_type", None) == "deepseek_v4"
-    has_rotary_emb_compress = is_v4_keep and hasattr(text_model, "rotary_emb_compress")
-    has_hc_head = is_v4_keep and hasattr(text_model, "hc_head")
-
     # Auto-generate module split if not provided
     if module_names_per_stage is None:
         module_names_per_stage = generate_hf_model_fqn_per_model_part(
@@ -495,13 +490,15 @@ def split_model_into_stages(
             lm_head_fqn=lm_head_fqn,
         )
 
-    # V4 post-processing: keep the compressor rotary on every stage and the
-    # HC head on the last stage so the V4 PP forward can run end-to-end.
-    if has_rotary_emb_compress:
-        for stage_modules in module_names_per_stage:
-            stage_modules.append(f"{layers_prefix}rotary_emb_compress")
-    if has_hc_head:
-        module_names_per_stage[-1].append(f"{layers_prefix}hc_head")
+    customize_stage_modules = _get_optional_hook(model, "customize_pipeline_stage_modules")
+    if customize_stage_modules is not None:
+        custom_module_names = customize_stage_modules(
+            module_names_per_stage,
+            layers_prefix=layers_prefix,
+            text_model=text_model,
+        )
+        if custom_module_names is not None:
+            module_names_per_stage = custom_module_names
 
     def _build_stage_from_modules(
         stage_idx: int, module_names: list[str], num_stages: int
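
Likewise, a model can adjust the stage-module split by defining customize_pipeline_stage_modules. The hook receives the auto-generated module_names_per_stage list plus layers_prefix and text_model as keyword arguments, and may either mutate the list in place and return None or return a replacement list. The sketch below is not part of the commit; it reproduces the removed DeepSeek V4 behaviour (a shared rotary module on every stage, an extra head on the last stage) as one possible implementation.

# Illustrative sketch only: the method name and call signature come from the
# diff above; the class name and module FQNs are assumptions modelled on the
# DeepSeek V4 branch removed by this commit.
from torch import nn


class CustomPipelineModel(nn.Module):
    def customize_pipeline_stage_modules(
        self,
        module_names_per_stage: list[list[str]],
        *,
        layers_prefix: str,
        text_model: nn.Module,
    ) -> list[list[str]] | None:
        # Keep a module that every stage needs (e.g. an auxiliary rotary embedding).
        if hasattr(text_model, "rotary_emb_compress"):
            for stage_modules in module_names_per_stage:
                stage_modules.append(f"{layers_prefix}rotary_emb_compress")
        # Keep an extra head only on the last stage.
        if hasattr(text_model, "hc_head"):
            module_names_per_stage[-1].append(f"{layers_prefix}hc_head")
        # Returning None keeps the in-place edits; returning a new list replaces it.
        return None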
