
Commit 84cb6e8

adil-a and NeMo Bot authored and committed
feat: passthrough if inputs_embeds passed into nanov3 (#1261)
* fix

Signed-off-by: adil-a <adil.asif2000@hotmail.com>

* rename

Signed-off-by: adil-a <adil.asif2000@hotmail.com>

* unit tests

Signed-off-by: adil-a <adil.asif2000@hotmail.com>

---------

Signed-off-by: adil-a <adil.asif2000@hotmail.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
1 parent df68a08 commit 84cb6e8

2 files changed: 107 additions & 5 deletions

File tree

nemo_automodel/components/models/nemotron_v3/model.py

Lines changed: 12 additions & 5 deletions
@@ -99,25 +99,32 @@ def __init__(

     def forward(
         self,
-        input_ids: torch.LongTensor,
+        input_ids: torch.LongTensor | None = None,
         *,
         attention_mask: torch.Tensor | None = None,
         causal_mask_mapping: dict[str, torch.Tensor] | None = None,
+        inputs_embeds: torch.Tensor | None = None,
         **kwargs: Any,
     ) -> torch.Tensor:
         """Forward pass through the model.

         Args:
-            input_ids: Input token IDs [batch_size, seq_len]
+            input_ids: Input token IDs [batch_size, seq_len] (optional)
             attention_mask: 2D padding mask [batch_size, seq_len] (1=real, 0=padding)
             causal_mask_mapping: Dict with precomputed 4D causal masks for attention layers
+            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_size] (optional)
             **kwargs: Additional arguments (ignored)

         Returns:
             Hidden states tensor [batch_size, seq_len, hidden_size]
         """
         # Get embeddings
-        hidden_states = self.embed_tokens(input_ids)
+        if inputs_embeds is None:
+            if input_ids is None:
+                raise ValueError("input_ids must be provided if inputs_embeds is not provided")
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = inputs_embeds

         # TODO: attention mask currently does not work. A default causal mask is applied.
@@ -244,7 +251,7 @@ def __init__(

     def forward(
         self,
-        input_ids: torch.LongTensor,
+        input_ids: torch.LongTensor | None = None,
         *,
         attention_mask: torch.Tensor | None = None,
         causal_mask_mapping: dict[str, torch.Tensor] | None = None,
@@ -253,7 +260,7 @@ def forward(
         """Forward pass with optional loss computation.

         Args:
-            input_ids: Input token IDs [batch_size, seq_len]
+            input_ids: Input token IDs [batch_size, seq_len] (optional)
             attention_mask: 2D padding mask [batch_size, seq_len]
             causal_mask_mapping: Dict with precomputed 4D causal masks
             **kwargs: Additional arguments
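
In effect, the change gives the model two mutually exclusive entry points. A minimal usage sketch (assuming a `config` exposing `hidden_size` and `vocab_size` and a `backend` value, as in the test fixtures below):

import torch
from nemo_automodel.components.models.nemotron_v3.model import NemotronV3Model

model = NemotronV3Model(config, backend=backend).to(torch.bfloat16)
batch_size, seq_len = 2, 8

# Path 1: pass token IDs; the model embeds them via self.embed_tokens.
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
hidden = model(input_ids)

# Path 2: pass precomputed embeddings; embed_tokens is skipped entirely.
inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size, dtype=torch.bfloat16)
hidden = model(inputs_embeds=inputs_embeds)

# Passing neither raises:
# ValueError: input_ids must be provided if inputs_embeds is not provided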

tests/unit_tests/models/nemotron_v3/test_nemotron_v3_model.py

Lines changed: 95 additions & 0 deletions
@@ -187,6 +187,76 @@ def test_model_forward_with_causal_mask_mapping(self, config, backend):

         assert output.shape == (batch_size, seq_len, config.hidden_size)

+    def test_model_forward_with_inputs_embeds(self, config, backend):
+        """Test model forward pass with inputs_embeds instead of input_ids."""
+        from nemo_automodel.components.models.nemotron_v3.model import NemotronV3Model
+
+        model = NemotronV3Model(config, backend=backend)
+        model = model.to(torch.bfloat16)
+
+        batch_size, seq_len = 2, 8
+        inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size, dtype=torch.bfloat16)
+
+        output = model(inputs_embeds=inputs_embeds)
+
+        assert output.shape == (batch_size, seq_len, config.hidden_size)
+
+    def test_model_forward_inputs_embeds_bypasses_embedding(self, config, backend):
+        """Test that inputs_embeds bypasses the embedding layer."""
+        from nemo_automodel.components.models.nemotron_v3.model import NemotronV3Model
+
+        model = NemotronV3Model(config, backend=backend)
+        model = model.to(torch.bfloat16)
+
+        batch_size, seq_len = 2, 8
+        inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size, dtype=torch.bfloat16)
+
+        # Should work even with input_ids=None (the default)
+        output = model(input_ids=None, inputs_embeds=inputs_embeds)
+
+        assert output.shape == (batch_size, seq_len, config.hidden_size)
+
+    def test_model_forward_inputs_embeds_takes_precedence(self, config, backend):
+        """Test that inputs_embeds takes precedence over input_ids when both provided."""
+        from nemo_automodel.components.models.nemotron_v3.model import NemotronV3Model
+
+        model = NemotronV3Model(config, backend=backend)
+        model = model.to(torch.bfloat16)
+
+        batch_size, seq_len = 2, 8
+        input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
+        inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size, dtype=torch.bfloat16)
+
+        # When both are provided, inputs_embeds should be used (input_ids ignored)
+        output = model(input_ids, inputs_embeds=inputs_embeds)
+
+        assert output.shape == (batch_size, seq_len, config.hidden_size)
+
+    def test_model_forward_no_input_ids_no_inputs_embeds_raises(self, config, backend):
+        """Test that ValueError is raised when neither input_ids nor inputs_embeds is provided."""
+        from nemo_automodel.components.models.nemotron_v3.model import NemotronV3Model
+
+        model = NemotronV3Model(config, backend=backend)
+        model = model.to(torch.bfloat16)
+
+        with pytest.raises(ValueError, match="input_ids must be provided if inputs_embeds is not provided"):
+            model(input_ids=None)
+
+    def test_model_forward_inputs_embeds_with_mask(self, config, backend):
+        """Test model forward pass with inputs_embeds and attention mask."""
+        from nemo_automodel.components.models.nemotron_v3.model import NemotronV3Model
+
+        model = NemotronV3Model(config, backend=backend)
+        model = model.to(torch.bfloat16)
+
+        batch_size, seq_len = 2, 8
+        inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size, dtype=torch.bfloat16)
+        attention_mask = torch.ones(batch_size, seq_len)
+
+        output = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+
+        assert output.shape == (batch_size, seq_len, config.hidden_size)
+
     def test_model_moe_config_creation(self, config, backend):
         """Test that model creates MoE config correctly."""
         from nemo_automodel.components.models.nemotron_v3.model import NemotronV3Model
@@ -307,6 +377,31 @@ def test_causal_lm_forward_float32_logits(self, config, backend):

         assert logits.dtype == torch.float32

+    def test_causal_lm_forward_with_inputs_embeds(self, config, backend):
+        """Test causal LM forward pass with inputs_embeds."""
+        from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM
+
+        model = NemotronHForCausalLM(config, backend=backend)
+        model = model.to(torch.bfloat16)
+
+        batch_size, seq_len = 2, 8
+        inputs_embeds = torch.randn(batch_size, seq_len, config.hidden_size, dtype=torch.bfloat16)
+
+        logits = model(inputs_embeds=inputs_embeds)
+
+        assert logits.shape == (batch_size, seq_len, config.vocab_size)
+        assert logits.dtype == torch.float32
+
+    def test_causal_lm_forward_no_input_ids_no_inputs_embeds_raises(self, config, backend):
+        """Test that ValueError is raised when neither input_ids nor inputs_embeds is provided."""
+        from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM
+
+        model = NemotronHForCausalLM(config, backend=backend)
+        model = model.to(torch.bfloat16)
+
+        with pytest.raises(ValueError, match="input_ids must be provided if inputs_embeds is not provided"):
+            model()
+
     def test_causal_lm_from_config(self, config, backend):
         """Test from_config classmethod."""
         from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM
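
One scenario this passthrough enables (a hedged sketch, not part of this commit): soft-prompt style training, where trainable embeddings are concatenated with token embeddings before the forward pass. Here `config` and `backend` are assumed to be the same fixtures used in the tests above, and `num_virtual_tokens` is an illustrative choice.

import torch
from nemo_automodel.components.models.nemotron_v3.model import NemotronV3Model

model = NemotronV3Model(config, backend=backend).to(torch.bfloat16)

# A small trainable prompt of shape [num_virtual_tokens, hidden_size].
num_virtual_tokens = 4
soft_prompt = torch.nn.Parameter(
    torch.randn(num_virtual_tokens, config.hidden_size, dtype=torch.bfloat16)
)

input_ids = torch.randint(0, config.vocab_size, (2, 8))
token_embeds = model.embed_tokens(input_ids)

# Prepend the trainable prompt to every sequence in the batch.
inputs_embeds = torch.cat(
    [soft_prompt.unsqueeze(0).expand(token_embeds.size(0), -1, -1), token_embeds],
    dim=1,
)

hidden = model(inputs_embeds=inputs_embeds)  # gradients can flow back into soft_prompt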
