
Commit 95cde23

HuiyingLi committed:

style(deepseek-v4): apply ruff formatting

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

1 parent 7c5a5b3, commit 95cde23

2 files changed, 9 additions and 11 deletions


nemo_automodel/components/distributed/pipelining/functional.py

Lines changed: 2 additions & 6 deletions
@@ -317,17 +317,13 @@ def _precompute_stage_shapes(
         # First stage receives input_ids: [mb, seq_len] int64
         stage.inputs_meta = (torch.empty(microbatch_size, seq_len, device="meta", dtype=torch.long),)
     else:
-        stage.inputs_meta = (
-            torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype),
-        )
+        stage.inputs_meta = (torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype),)
 
     # --- outputs_meta ---
     has_lm_head = hasattr(stage.submod, "lm_head") and stage.submod.lm_head is not None
     if has_lm_head:
         # Last stage with lm_head produces logits: [mb, seq_len, vocab_size]
-        primary_output_meta = torch.empty(
-            microbatch_size, seq_len, vocab_size, device="meta", dtype=model_dtype
-        )
+        primary_output_meta = torch.empty(microbatch_size, seq_len, vocab_size, device="meta", dtype=model_dtype)
     else:
         primary_output_meta = torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype)
     outputs_meta = (primary_output_meta,)
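
The reflowed lines above rely on PyTorch meta tensors to record each stage's input/output shapes without allocating memory. Below is a minimal, self-contained sketch of that pattern; the sizes and dtype are illustrative placeholders, not values taken from the repo:

import torch

# Illustrative sizes; in _precompute_stage_shapes these come from the model/config.
microbatch_size, seq_len, hidden_size, vocab_size = 2, 128, 1024, 32000
model_dtype = torch.bfloat16  # placeholder dtype

# Tensors on the "meta" device carry only shape and dtype, with no storage,
# which is enough for a pipeline stage to pre-plan its send/recv buffers.
first_stage_input = torch.empty(microbatch_size, seq_len, device="meta", dtype=torch.long)
later_stage_input = torch.empty(microbatch_size, seq_len, hidden_size, device="meta", dtype=model_dtype)
logits_output = torch.empty(microbatch_size, seq_len, vocab_size, device="meta", dtype=model_dtype)

print(first_stage_input.shape, later_stage_input.dtype, logits_output.device)  # no memory allocated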

nemo_automodel/components/models/deepseek_v4/mtp.py

Lines changed: 7 additions & 5 deletions
@@ -141,7 +141,9 @@ def forward(
 
         if position_ids is None:
             seq_len = embed_input.shape[1]
-            position_ids = torch.arange(seq_len, device=embed_input.device).unsqueeze(0).expand(embed_input.shape[0], -1)
+            position_ids = (
+                torch.arange(seq_len, device=embed_input.device).unsqueeze(0).expand(embed_input.shape[0], -1)
+            )
         position_embeddings = self._rotary_emb(embed_input, position_ids)
         position_embeddings_compress = self._rotary_emb_compress(embed_input, position_ids)
 
@@ -235,9 +237,7 @@ def forward(
         per_depth_h: list[torch.Tensor] = []
         cur_input_ids = input_ids
         if embed_inputs is not None and len(embed_inputs) != len(self.layers):
-            raise ValueError(
-                f"Expected {len(self.layers)} MTP embedding tensors, got {len(embed_inputs)}"
-            )
+            raise ValueError(f"Expected {len(self.layers)} MTP embedding tensors, got {len(embed_inputs)}")
         if embed_inputs is None and (cur_input_ids is None or embed_fn is None):
             raise ValueError("MTP requires either embed_inputs or both input_ids and embed_fn")
 
@@ -263,7 +263,9 @@ def forward(
 def build_mtp_config_from_hf(config, *, loss_scaling_factor: float = 0.1) -> MTPConfig:
     """Build an MTPConfig from a DeepseekV4Config."""
     num_layers = int(getattr(config, "num_nextn_predict_layers", 0) or 0)
-    return MTPConfig(num_layers=num_layers, layer_pattern="*" if num_layers > 0 else "", loss_scaling_factor=loss_scaling_factor)
+    return MTPConfig(
+        num_layers=num_layers, layer_pattern="*" if num_layers > 0 else "", loss_scaling_factor=loss_scaling_factor
+    )
 
 
 def build_deepseek_v4_mtp(
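
For context on the first hunk: the statement that ruff wrapped builds default position ids as a single 0..seq_len-1 row broadcast across the batch. A self-contained sketch of just that expression, with a placeholder tensor standing in for embed_input:

import torch

# Stand-in for embed_input: [batch, seq_len, hidden]
embed_input = torch.zeros(4, 16, 8)
seq_len = embed_input.shape[1]

# Same expression as in mtp.py's forward(), isolated: one row of positions
# 0..seq_len-1, expanded (broadcast, no copy) over the batch dimension.
position_ids = (
    torch.arange(seq_len, device=embed_input.device).unsqueeze(0).expand(embed_input.shape[0], -1)
)

assert position_ids.shape == (4, 16)
print(position_ids[0][:5])  # tensor([0, 1, 2, 3, 4])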
