
Commit 660ed94

pzelasko and claude authored
fix(nemotron-v3): support THD with inputs_embeds instead of input_ids (#2185)
fix(thd): support inputs_embeds-only callers in NemotronHForCausalLM

`SALMAutomodel` and other multimodal callers feed the LLM through `inputs_embeds` (audio frames spliced into the token stream have no integer ID) and leave `input_ids=None`. Two bugs surfaced when running that path under `qkv_format="thd"`:

1. `squeeze_input_for_thd` did `input_ids.squeeze(0)` unconditionally and crashed with `AttributeError: 'NoneType' object has no attribute 'squeeze'`. Add the same `is-not-None` guard the helper already uses for `padding_mask`; document `None` as a valid value.

2. `NemotronHForCausalLM.forward` did `logits = logits.unsqueeze(0)` whenever `is_thd`, producing `[1, 1, T, V]` for the `inputs_embeds` path because `NemotronHModel.forward` already restores the batch dim (the `squeezed_for_thd` branch). Restrict the outer unsqueeze to the case where the inner forward returned 2D logits; the standard `input_ids` path still satisfies that.

Tests:

- `TestSqueezeInputForThd` (5 cases) covers the helper-level contract: standard `input_ids` path, `input_ids=None` path, `padding_mask=None` composition, 3D `[1, T, H]` embedding-via-`input_ids` slot path, and `cu_seqlens_padded` filtering.
- `TestNemotronHForCausalLM::test_causal_lm_thd_*` (2 cases) covers the outer logits-shape contract: `inputs_embeds`-only stays `[1, T, V]` (no double-unsqueeze), and `input_ids`-only still gets the batch dim re-added. The inner forward is stubbed via a tiny `nn.Module` because THD shapes only run end-to-end on TE/GPU.

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
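For orientation, here is a minimal sketch of the call pattern the fix targets, assembled from the inputs the new tests construct; `llm` stands for any `NemotronHForCausalLM` instance and is an assumption of the sketch, not code from this commit:

```python
import torch

# Hypothetical inputs_embeds-only THD caller (e.g. a speech/multimodal wrapper):
# spliced audio frames have no integer token IDs, so input_ids stays None.
seq_len, hidden = 8, 64
inputs_embeds = torch.randn(1, seq_len, hidden)                      # [1, T, H], placeholder batch dim
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)  # [1, T]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32)
max_seqlen = torch.tensor(seq_len, dtype=torch.int32)

# output = llm(                        # llm: a NemotronHForCausalLM (assumed, not built here)
#     inputs_embeds=inputs_embeds,     # input_ids is deliberately left as None
#     position_ids=position_ids,
#     cu_seqlens=cu_seqlens,
#     cu_seqlens_padded=cu_seqlens,
#     max_seqlen_q=max_seqlen,
#     max_seqlen_kv=max_seqlen,
#     qkv_format="thd",
# )
# Before this commit: bug 1 raised AttributeError in squeeze_input_for_thd, and
# bug 2 would have produced [1, 1, T, V] logits; after it, output.logits is [1, T, V].
```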
1 parent 13ea298 commit 660ed94

4 files changed

Lines changed: 192 additions & 4 deletions


nemo_automodel/components/models/nemotron_v3/model.py

Lines changed: 8 additions & 1 deletion
@@ -418,7 +418,14 @@ def forward(
                 shift_labels.view(-1),
             )

-        if is_thd:
+        # Restore the batch dim for THD only when the inner forward returned
+        # 2D logits. When the caller feeds the model via ``inputs_embeds``
+        # (shape ``[1, T, H]``), ``NemotronHModel.forward`` squeezes to
+        # ``[T, H]`` for the layer stack and unsqueezes back to ``[1, T, H]``
+        # before returning (see the ``squeezed_for_thd`` branch); the lm_head
+        # then yields ``[1, T, V]`` already and a second unsqueeze here would
+        # produce a spurious ``[1, 1, T, V]``.
+        if is_thd and logits.dim() == 2:
             logits = logits.unsqueeze(0)

         return CausalLMOutputWithPast(
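As a standalone illustration of the guard above (a sketch with made-up shapes, not the repository code), the dimensionality check is what separates the two paths:

```python
import torch

def restore_batch_dim(logits: torch.Tensor, is_thd: bool) -> torch.Tensor:
    # Mirror of the guarded unsqueeze: only 2D [T, V] logits (input_ids path)
    # still need the batch dim; [1, T, V] (inputs_embeds path) is left alone.
    if is_thd and logits.dim() == 2:
        logits = logits.unsqueeze(0)
    return logits

T, V = 8, 32
assert restore_batch_dim(torch.zeros(T, V), is_thd=True).shape == (1, T, V)     # batch dim re-added
assert restore_batch_dim(torch.zeros(1, T, V), is_thd=True).shape == (1, T, V)  # untouched, no [1, 1, T, V]
```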

nemo_automodel/components/utils/model_utils.py

Lines changed: 9 additions & 3 deletions
@@ -386,8 +386,13 @@ def squeeze_input_for_thd(input_ids, position_ids, padding_mask, attn_kwargs, se
         3. Converts max_seqlen from tensor to scalar if needed

     Args:
-        input_ids (torch.Tensor): Input token IDs with shape [1, total_tokens] or
-            [1, total_tokens, hidden_dim]. The first dimension will be squeezed.
+        input_ids (torch.Tensor or None): Input token IDs with shape [1, total_tokens]
+            or [1, total_tokens, hidden_dim]. The first dimension will be squeezed.
+            ``None`` is permitted when the caller is feeding the model via
+            ``inputs_embeds`` instead — embeddings are squeezed inside the model
+            forward (the ``squeezed_for_thd`` branch in ``NemotronHModel.forward``
+            and analogous code paths), so this helper has nothing to squeeze and
+            simply returns ``None`` for the ``input_ids`` slot.
         position_ids (torch.Tensor): Position IDs with shape [1, total_tokens].
             The first dimension will be squeezed.
         padding_mask (torch.Tensor): Padding mask with shape [1, total_tokens].
@@ -435,7 +440,8 @@ def squeeze_input_for_thd(input_ids, position_ids, padding_mask, attn_kwargs, se
     This function modifies attn_kwargs in-place. If you need to preserve the original
     dictionary, pass a copy.
     """
-    input_ids = input_ids.squeeze(0)
+    if input_ids is not None:
+        input_ids = input_ids.squeeze(0)
     position_ids = position_ids.squeeze(0)
     if isinstance(padding_mask, torch.Tensor):
         padding_mask = padding_mask.squeeze(0)
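The ``squeezed_for_thd`` round-trip that the new docstring text refers to can be sketched in isolation (illustrative only; the variable names here are not taken from the repository):

```python
import torch

# When only inputs_embeds is provided, the model forward itself drops the
# placeholder batch dim for the THD layer stack and restores it before
# returning, which is why squeeze_input_for_thd can simply pass None through.
inputs_embeds = torch.randn(1, 8, 64)                      # [1, T, H] from the caller

squeezed_for_thd = inputs_embeds.dim() == 3 and inputs_embeds.shape[0] == 1
hidden_states = inputs_embeds.squeeze(0) if squeezed_for_thd else inputs_embeds
assert hidden_states.shape == (8, 64)                      # [T, H] seen by the layers

if squeezed_for_thd:
    hidden_states = hidden_states.unsqueeze(0)             # back to [1, T, H] for the lm_head
assert hidden_states.shape == (1, 8, 64)
```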

tests/unit_tests/models/nemotron_v3/test_nemotron_v3_model.py

Lines changed: 85 additions & 0 deletions
@@ -405,6 +405,91 @@ def test_causal_lm_forward_no_input_ids_no_inputs_embeds_raises(self, config, ba
         with pytest.raises(ValueError, match="input_ids must be provided if inputs_embeds is not provided"):
             model()

+    def _build_stub_inner_model(self, hidden):
+        """Tiny ``nn.Module`` whose forward returns a fixed tensor. Lets us
+        replace ``NemotronHForCausalLM.model`` (an ``nn.Module``) without
+        tripping ``nn.Module.__setattr__``'s child-module type check."""
+
+        class _StubInner(torch.nn.Module):
+            def forward(self, *args, **kwargs):
+                return hidden
+
+        return _StubInner()
+
+    def test_causal_lm_thd_inputs_embeds_does_not_double_unsqueeze(self, config, backend):
+        """Regression test: in THD mode with ``inputs_embeds``-only inputs, the
+        outer ``NemotronHForCausalLM.forward`` used to double-unsqueeze the
+        logits to ``[1, 1, T, V]``. The inner ``NemotronHModel.forward`` already
+        restores the batch dim (``squeezed_for_thd`` branch), so the outer must
+        only re-add it when the inner returned 2D logits.
+
+        We bypass the attention stack (which needs TE/GPU for THD shapes) by
+        replacing ``model.model`` with a stub that returns a fixed 3D tensor —
+        the same shape the real inner forward returns when it took the
+        ``inputs_embeds`` → squeeze → unsqueeze round-trip.
+        """
+        from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM
+
+        model = NemotronHForCausalLM(config, backend=backend)
+        model = model.to(torch.bfloat16)
+
+        seq_len = 8
+        # Stand in for the inner forward that unsqueezed back to [1, T, H].
+        stub_hidden = torch.randn(1, seq_len, config.hidden_size, dtype=torch.bfloat16)
+        model.model = self._build_stub_inner_model(stub_hidden)
+
+        inputs_embeds = torch.randn(1, seq_len, config.hidden_size, dtype=torch.bfloat16)
+        position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32)
+        max_seqlen = torch.tensor(seq_len, dtype=torch.int32)
+
+        output = model(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            cu_seqlens=cu_seqlens,
+            cu_seqlens_padded=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_kv=max_seqlen,
+            qkv_format="thd",
+        )
+
+        assert output.logits.shape == (1, seq_len, config.vocab_size)
+        assert output.logits.dim() == 3
+
+    def test_causal_lm_thd_input_ids_unsqueezes_2d_logits(self, config, backend):
+        """The original THD path (``input_ids`` only, no ``inputs_embeds``) is
+        the one most callers use; the unsqueeze fix must not regress it. The
+        inner forward returns ``[T, H]`` (2D, never went through the
+        ``squeezed_for_thd`` round-trip because ``embed_tokens(input_ids[T])``
+        is already 2D), so the outer still has to add the batch dim."""
+        from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM
+
+        model = NemotronHForCausalLM(config, backend=backend)
+        model = model.to(torch.bfloat16)
+
+        seq_len = 8
+        # Stand in for the inner forward that returned 2D hidden_states.
+        stub_hidden = torch.randn(seq_len, config.hidden_size, dtype=torch.bfloat16)
+        model.model = self._build_stub_inner_model(stub_hidden)
+
+        input_ids = torch.randint(0, config.vocab_size, (1, seq_len))
+        position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
+        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32)
+        max_seqlen = torch.tensor(seq_len, dtype=torch.int32)
+
+        output = model(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlens=cu_seqlens,
+            cu_seqlens_padded=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_kv=max_seqlen,
+            qkv_format="thd",
+        )
+
+        assert output.logits.shape == (1, seq_len, config.vocab_size)
+        assert output.logits.dim() == 3
+
     def test_causal_lm_from_config(self, config, backend):
         """Test from_config classmethod."""
         from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM

tests/unit_tests/utils/test_model_utils.py

Lines changed: 90 additions & 0 deletions
@@ -453,3 +453,93 @@ def forward(self, input_ids, **kwargs):
         filtered = model_utils.filter_forward_kwargs(model, batch)

         assert filtered == batch
+
+
+class TestSqueezeInputForThd:
+    """``squeeze_input_for_thd`` strips the placeholder batch dim (``[1, T, ...] -> [T, ...]``)
+    before THD attention/Mamba kernels see the inputs. The contract has to handle
+    ``input_ids=None`` because callers feeding the model via ``inputs_embeds`` only
+    (multimodal LMs, speech-language models) leave ``input_ids`` unset; the embeddings
+    are squeezed inside the model forward instead.
+    """
+
+    def _attn_kwargs(self, *, padded: bool = False, with_max_seqlen: bool = True):
+        """Build the kwargs dict in the canonical [1, num_seqs+1] layout, padded
+        with the ``-1000`` sentinel that ``squeeze_input_for_thd`` filters out."""
+        kwargs: dict = {
+            "cu_seqlens": torch.tensor([[0, 3, 5, -1000]], dtype=torch.int32),
+        }
+        if padded:
+            kwargs["cu_seqlens_padded"] = torch.tensor([[0, 4, 6, -1000]], dtype=torch.int32)
+        if with_max_seqlen:
+            kwargs["max_seqlen"] = torch.tensor([3])
+        return kwargs
+
+    def test_squeezes_input_ids_when_provided(self):
+        input_ids = torch.tensor([[1, 2, 3, 4, 5]])
+        position_ids = torch.tensor([[0, 1, 2, 0, 1]])
+        padding_mask = torch.tensor([[False, False, False, False, False]])
+        kwargs = self._attn_kwargs()
+
+        ids, pos, mask, kw = model_utils.squeeze_input_for_thd(input_ids, position_ids, padding_mask, kwargs)
+
+        assert ids.shape == (5,)
+        assert pos.shape == (5,)
+        assert mask.shape == (5,)
+        # Sentinel filtered out and dtype/shape preserved.
+        assert kw["cu_seqlens"].tolist() == [0, 3, 5]
+        assert kw["cu_seqlens"].dtype == torch.int32
+        # max_seqlen tensor → Python int.
+        assert kw["max_seqlen"] == 3
+        assert isinstance(kw["max_seqlen"], int)
+
+    def test_accepts_input_ids_none_for_inputs_embeds_callers(self):
+        """The bug: prior code did ``input_ids.squeeze(0)`` unconditionally and
+        crashed when the caller used ``inputs_embeds`` only. The fix returns
+        ``None`` for the ``input_ids`` slot and squeezes everything else."""
+        position_ids = torch.tensor([[0, 1, 2, 0, 1]])
+        padding_mask = torch.tensor([[False, False, False, False, False]])
+        kwargs = self._attn_kwargs()
+
+        ids, pos, mask, kw = model_utils.squeeze_input_for_thd(None, position_ids, padding_mask, kwargs)
+
+        assert ids is None
+        assert pos.shape == (5,)
+        assert mask.shape == (5,)
+        assert kw["cu_seqlens"].tolist() == [0, 3, 5]
+        assert kw["max_seqlen"] == 3
+
+    def test_padding_mask_none_is_passed_through(self):
+        """Existing behavior: ``padding_mask`` may be ``None`` (unmasked path).
+        The new ``input_ids=None`` branch must compose with this."""
+        position_ids = torch.tensor([[0, 1, 2, 3, 4]])
+        kwargs = self._attn_kwargs(with_max_seqlen=False)
+
+        ids, pos, mask, kw = model_utils.squeeze_input_for_thd(None, position_ids, None, kwargs)
+
+        assert ids is None
+        assert mask is None
+        assert pos.shape == (5,)
+
+    def test_3d_inputs_embeds_via_input_ids_slot_still_works(self):
+        """Belt-and-braces: the docstring claims ``input_ids`` may carry a 3D
+        ``[1, T, H]`` embedding tensor. Squeezing dim 0 of that yields ``[T, H]``."""
+        embeds = torch.randn(1, 5, 16)
+        position_ids = torch.tensor([[0, 1, 2, 3, 4]])
+        kwargs = self._attn_kwargs(with_max_seqlen=False)
+
+        ids, pos, _mask, _kw = model_utils.squeeze_input_for_thd(embeds, position_ids, None, kwargs)
+
+        assert ids.shape == (5, 16)
+        assert pos.shape == (5,)
+
+    def test_cu_seqlens_padded_filtered_alongside_cu_seqlens(self):
+        """Both ``cu_seqlens`` and ``cu_seqlens_padded`` (CP path) get the
+        sentinel filter — the bug fix must not regress this."""
+        position_ids = torch.tensor([[0, 1, 2, 0, 1]])
+        kwargs = self._attn_kwargs(padded=True, with_max_seqlen=False)
+
+        _ids, _pos, _mask, kw = model_utils.squeeze_input_for_thd(None, position_ids, None, kwargs)
+
+        assert kw["cu_seqlens"].tolist() == [0, 3, 5]
+        assert kw["cu_seqlens_padded"].tolist() == [0, 4, 6]
