
Commit d47f8ae

Jared Wilber (jwilber) authored and chtruong814 committed
Add fix for evo2 generate/inference (NVIDIA-NeMo#14027)
* Add fix for evo2 generate/inference
* Add Farhad's suggested refactor
* lint
* Remove unused and truism code
* add inference_context conditional check to use new ops
* add hyena operator tests
* lint
* remove unused code
* add doc strings for flake8
* Apply isort and black reformatting
* remove unnecessary assignments
* invoke original forward for non-inference calls
* Apply isort and black reformatting
* Fix reset issue
* Apply isort and black reformatting
* add docstring
* Apply isort and black reformatting
* Remove test
* Add tests for hyena operator
* Apply isort and black reformatting
* lint
* simplify context manager in tests
* remove unused import in test
* Add env vars for test
* mark tests gpu only
* Apply isort and black reformatting

---------

Signed-off-by: Jared Wilber <jwilber@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: jwilber <jwilber@users.noreply.github.com>
Co-authored-by: Jared Wilber <jwilber@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: jwilber <jwilber@users.noreply.github.com>
Co-authored-by: Charlie Truong <chtruong@nvidia.com>
1 parent 80d19fd commit d47f8ae

7 files changed

Lines changed: 1016 additions & 30 deletions


nemo/collections/llm/api.py

Lines changed: 3 additions & 0 deletions
@@ -1049,6 +1049,7 @@ def generate(
     text_only: bool = False,
     output_path: Optional[AnyPath] = None,
     enable_flash_decode: bool = True,
+    **kwargs,
 ) -> list[Union["InferenceRequest", str]]:
     """
     Generates text using a NeMo LLM model.
@@ -1116,6 +1117,7 @@ def generate(
         output_path (Optional[Union[Path, str]], optional): The path to save the generated text or test dataset
             predictions. Defaults to None.
         enable_flash_decode (bool, optional): Whether to enable flash decode. Defaults to True.
+        **kwargs: Additional keyword arguments passed to setup_model_and_tokenizer.

     Returns:
         list[Union["InferenceRequest", str]]: A list of generated text,
@@ -1139,6 +1141,7 @@ def generate(
         params_dtype=params_dtype,
         inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
         enable_flash_decode=enable_flash_decode,
+        **kwargs,
     )

     max_seq_length = inference_params.num_tokens_to_generate + max(len(mcore_tokenizer.tokenize(p)) for p in inputs)
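
For orientation, a hedged caller-side sketch of what the pass-through enables: any keyword that generate() does not consume itself is now forwarded to setup_model_and_tokenizer(). Only text_only, output_path, and enable_flash_decode are visible in the hunks above; the checkpoint path, prompts, trainer, and the forwarded keyword below are illustrative assumptions and should be checked against the full signatures.

# Hypothetical usage sketch, not part of this commit.
from nemo.collections import llm

outputs = llm.generate(
    path="/checkpoints/evo2-7b",   # illustrative checkpoint path
    prompts=["GATTACA"],           # illustrative prompt
    trainer=trainer,               # a configured trainer, assumed to be set up elsewhere
    text_only=True,
    enable_flash_decode=True,
    # Anything below is not consumed by generate() itself; it is forwarded via **kwargs
    # to setup_model_and_tokenizer. The keyword name is an assumption; pass whatever
    # extra arguments that function actually accepts in your NeMo version.
    inference_max_seq_length=8192,
)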

nemo/collections/llm/gpt/model/hyena.py

Lines changed: 18 additions & 1 deletion
@@ -23,6 +23,7 @@

 import torch
 from megatron.core import parallel_state
+from megatron.core.inference.contexts import StaticInferenceContext
 from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
 from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
 from megatron.core.transformer.enums import AttnBackend
@@ -38,6 +39,18 @@
 from nemo.utils import logging


+class HyenaInferenceContext(StaticInferenceContext):
+    """Hyena-specific inference context."""
+
+    def reset(self):
+        """Reset the inference context."""
+        super().reset()  # standard state reset for GPT models
+        for key in dir(self):
+            # Remove all of the state that we add in hyena.py
+            if "filter_state_dict" in key:
+                delattr(self, key)
+
+
 class HyenaModel(GPTModel):
     """
     This is a wrapper around the MCoreHyenaModel to allow for inference. Our model follows the same API
@@ -88,9 +101,11 @@ def get_inference_wrapper(
             inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
             padded_vocab_size=vocab_size,
             inference_max_seq_length=inference_max_seq_length,
+            inference_max_requests=1,
         )

-        model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config)
+        inference_context = HyenaInferenceContext.from_config(inference_wrapper_config)
+        model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config, inference_context)
         return model_inference_wrapper

     def forward(
@@ -467,6 +482,7 @@ class Hyena7bARCLongContextConfig(Hyena7bConfig):
     due to constraintes from large TP size for training."""

     ffn_hidden_size: int = 11264
+    seq_len_interpolation_factor: float = 128


 @dataclass
@@ -475,6 +491,7 @@ class Hyena40bARCLongContextConfig(Hyena40bConfig):
     due to constraintes from large TP size for training."""

     ffn_hidden_size: int = 22528
+    seq_len_interpolation_factor: float = 128


 @io.model_importer(HyenaModel, "pytorch")
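
A small sketch of what the new context class adds on top of StaticInferenceContext: between generation requests, reset() also drops any attributes whose names contain "filter_state_dict", which is where the Hyena layers cache their convolution state. The attribute name and the config variable below are illustrative assumptions.

# Illustrative only: how HyenaInferenceContext.reset() clears Hyena-specific cache state.
context = HyenaInferenceContext.from_config(inference_wrapper_config)  # config built as in get_inference_wrapper

# During decoding the Hyena layers stash per-layer caches on the context
# (the exact attribute name here is hypothetical):
context.layer3_filter_state_dict = {"fir_state": None, "iir_state": None}

context.reset()  # StaticInferenceContext.reset() plus delattr of every *filter_state_dict* attribute
assert not hasattr(context, "layer3_filter_state_dict")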
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
+# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2024 Arc Institute. All rights reserved.
+# Copyright (c) 2024 Michael Poli. All rights reserved.
+# Copyright (c) 2024 Stanford University. All rights reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+
+
+def adjust_filter_shape_for_broadcast(u, h):
+    """
+    Adjust filter shape for broadcasting compatibility with input tensor.
+    """
+    h = h.squeeze()  # Standardize to [D, L] from [1, D, L] and [D, 1, L]
+
+    # Case: u: [B, D, L], k_f: [D, L]
+    if len(u.shape) > len(h.shape):
+        h = h.unsqueeze(0)
+
+    # Case: u: [B, D1, D2, L], k_f: [B, D, L]
+    if len(u.shape) > 3:
+        h = h.unsqueeze(1)
+    return h
+
+
+def fftconv_func(*, u, k, D):
+    """
+    Compute fast Fourier transform convolution with bias addition.
+
+    This function performs convolution using FFT for efficient computation of long sequences.
+    The convolution is computed in the frequency domain and then transformed back to the time domain.
+    """
+    seqlen = u.shape[-1]
+    fft_size = 2 * seqlen
+
+    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
+    k_f = adjust_filter_shape_for_broadcast(u, k_f)
+    k = k.squeeze()
+
+    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
+
+    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]
+
+    return y + u * D.unsqueeze(-1)
+
+
+def parallel_fir(
+    *,
+    u,  # B L D
+    weight,
+    bias,
+    L,
+    gated_bias,
+    fir_length,
+    compute_state,
+):
+    """Compute parallel finite impulse response filtering with optional state computation."""
+    L = u.shape[1]
+    u = rearrange(u, "b l d -> b d l")
+
+    if fir_length >= 128:
+        with torch.autocast("cuda"):
+            z = fftconv_func(
+                u=u.to(torch.float32),
+                k=weight[:, :, :L].to(torch.float32),
+                D=bias,
+            ).to(dtype=u.dtype)
+    else:
+        z = F.conv1d(
+            u.to(torch.float32),
+            weight.to(torch.float32),
+            bias=None,
+            stride=1,
+            padding=fir_length - 1,
+            groups=u.shape[1],  # always set to D, regardless of filter grouping
+        )[..., :L]
+
+        z = z.to(u.dtype)
+
+        if bias is not None:
+            if gated_bias:
+                z = z + bias[None, :, None] * u
+            else:
+                z = z + bias[None, :, None]
+
+    fir_state = None
+    if compute_state:
+        fir_state = u[..., -fir_length + 1 :]
+    return z, fir_state
+
+
+def parallel_iir(*, z_pre, h, D, L, poles, t, hidden_size, compute_state):
+    """Compute the output state of the short convolutional filter."""
+    fft_size = 2 * L
+    x1, x2, v = z_pre.split([hidden_size, hidden_size, hidden_size], dim=1)
+
+    x1v = x1 * v
+
+    H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
+    X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
+    X = X_s[..., : H.shape[-1]]
+    if len(z_pre.shape) > 3:
+        H = H.unsqueeze(1)
+    y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
+    y = y.to(dtype=x1v.dtype)
+    y = (y + x1v * D.unsqueeze(-1)) * x2
+
+    iir_state = None
+    if compute_state:
+        iir_state = prefill_via_modal_fft(
+            x1v=x1v,
+            X_s=X_s,
+            L=L,
+            t=t,
+            poles=poles,
+        )
+
+    return y.permute(0, 2, 1), iir_state
+
+
+def step_fir(*, u, fir_state, weight, bias=None, gated_bias=False, flip_filter=False):
+    """Steps forward FIR filters in the architecture.
+    FIR filters generally include truncated convolutions in Hyena with an explicit or
+    hybrid time-domain parametrization:
+    * Short FIR filters in Hyena featurizers
+    * Short and medium FIR filters in Hyena operators
+    Note:
+        `fir_state` contains the last FIR filter length - 1 elements of `u`: `u_(L-2), u_{L-1), ...`
+        We assume dimensions of `short_filter_weight` to be `[d, 1, short_filter_len]`.
+    """
+    weight = weight.squeeze()

+    cache_size = fir_state.shape[-1]
+    filter_length = weight.shape[-1]
+    if flip_filter:
+        weight = weight.flip(-1)
+        weight = weight[..., -cache_size - 1 :].unsqueeze(0)
+    else:
+        weight = weight[..., : cache_size + 1].unsqueeze(0)
+
+    input_dtype = u.dtype
+    weight = weight.to(torch.float32)
+    u = u.to(torch.float32)
+    fir_state = fir_state.to(torch.float32)
+    bias = bias.to(torch.float32) if bias is not None else None
+
+    h0, h = weight[..., -1], weight[..., :-1]
+    y = h0 * u + torch.sum(fir_state * h, dim=-1)
+
+    if bias is not None:
+        if gated_bias:
+            y = y + bias * u
+        else:
+            y = y + bias
+
+    # Update the state
+    if cache_size < filter_length - 1:
+        fir_state = torch.cat([fir_state, u[..., None]], dim=-1)
+    else:
+        fir_state = torch.roll(fir_state, -1, dims=2)
+        fir_state[..., -1] = u
+
+    return y.to(input_dtype), fir_state
+
+
+def step_iir(*, x2, x1, v, D, residues, poles, iir_state):
+    """Steps forward IIR filters in the architecture."""
+    x1v = x1 * v
+    poles = torch.exp(poles)  # poles arg contains log_poles
+    poles = poles[..., 0][None]  # squeeze dummy seqlen dim and add dummy batch dim
+    residues = residues[None]  # add dummy batch dim
+    iir_state = poles * iir_state + x1v[..., None]
+
+    res_state = torch.sum(residues * iir_state, dim=-1)
+    y = x2 * (res_state + D * x1v)
+    return y, iir_state
+
+
+def prefill_via_modal_fft(*, x1v, L, poles, t, X_s):
+    """
+    Compute the IIR state via a single FFT
+    """
+    # When the model has a long convolution derived from a recurrence in modal form and prefill_style is "fft",
+    # we split the filter into poles and residues and reuse FFT computation on the input.
+    bs = x1v.shape[0]
+    fft_size = 2 * L
+    state_s = (poles.to(torch.float32) * t).exp()
+    state_S = torch.fft.fft(state_s, n=fft_size).repeat(bs, 1, 1, 1)  # B, D, state_dim, 2 * L
+    state = torch.fft.ifft(X_s[..., None, :] * state_S, n=fft_size)
+    return state[..., L - 1].to(dtype=torch.float32)
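
To make the FFT path concrete, here is a minimal numerical sketch (run with the functions above in scope; the new module's path is not shown in this view, so import them from wherever it lands). fftconv_func zero-pads to 2*L so the circular FFT product reduces to a causal linear convolution plus a per-channel skip term D * u, which can be checked against a direct depthwise conv1d.

# Minimal sketch: verify that fftconv_func matches an explicit causal depthwise convolution.
import torch
import torch.nn.functional as F

B, C, L = 2, 4, 32
u = torch.randn(B, C, L)   # input sequences, one per channel
k = torch.randn(C, L)      # one length-L causal filter per channel
D = torch.randn(C)         # per-channel skip weight ("bias" in parallel_fir)

y_fft = fftconv_func(u=u, k=k, D=D)

# Reference: causal convolution via conv1d (cross-correlation with the flipped filter,
# left-padded so y[t] depends only on u[: t + 1]), plus the same skip term.
y_ref = F.conv1d(u, k.flip(-1).unsqueeze(1), groups=C, padding=L - 1)[..., :L]
y_ref = y_ref + u * D.unsqueeze(-1)

print(torch.allclose(y_fft, y_ref, atol=1e-3))  # expected: True, up to float32 FFT error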

nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py

Lines changed: 38 additions & 8 deletions
@@ -35,13 +35,12 @@
 from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig
 from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import (
     B2BCausalConv1dModule,
-    ParallelCausalDepthwiseConv1d,
+    ParallelCausalDepthwiseConv1dWithState,
     ParallelHyenaOperator,
     ParallelShortHyenaOperator,
     divide,
 )

-
 logger = logging.getLogger(__name__)

 try:
@@ -160,7 +159,8 @@ def __init__(

         hyena_proj_groups = self.proj_groups if not self.grouped_attention else 1
         grouped_proj_size = self.hidden_size_per_partition // hyena_proj_groups
-        self.hyena_proj_conv = ParallelCausalDepthwiseConv1d(
+
+        self.hyena_proj_conv = ParallelCausalDepthwiseConv1dWithState(
             self.hidden_size_per_partition + 2 * grouped_proj_size,
             self.transformer_config,
             self.hyena_config,
@@ -179,7 +179,7 @@ def __init__(
             self.transformer_config,
             self.hyena_config,
             self.transformer_config.init_method,
-            short_conv_class=ParallelCausalDepthwiseConv1d,
+            short_conv_class=ParallelCausalDepthwiseConv1dWithState,
             use_fast_causal_conv=self.fast_conv_mixer,
             use_conv_bias=self.transformer_config.use_short_conv_bias,
         )
@@ -280,18 +280,48 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True
             _proj_use_cp = True
         else:
             _proj_use_cp = False
-        features, _ = self._maybe_use_fp8(self.dense_projection, x)
+        # Handle padding for FP8 if enabled
+        if self.transformer_config.vortex_style_fp8:
+
+            def pad_to_multiple(x, multiple=16):
+                """Pad tensor to make sequence length divisible by multiple."""
+                seq_len = x.size(0)
+                if seq_len % multiple == 0:
+                    return x
+
+                pad_len = multiple - (seq_len % multiple)
+                pad_tensor = torch.zeros(pad_len, *x.shape[1:], device=x.device, dtype=x.dtype)
+                return torch.cat([x, pad_tensor], dim=0)
+
+            # Direct padding without rearrange
+            L = x.shape[0]
+            x = pad_to_multiple(x)
+            features, _ = self._maybe_use_fp8(self.dense_projection, x)
+
+            # Slice back to original sequence length if padding was added
+
+            if features.shape[0] > L:
+                features = features[:L, :, :]
+        else:
+            features, _ = self.dense_projection(x)
         features = rearrange(features, "l b d -> b d l").contiguous()

-        if self.use_b2b_causal_conv1d and self.operator_type in ["hyena_short_conv", "hyena_medium_conv"]:
+        if (
+            self.use_b2b_causal_conv1d
+            and self.operator_type in ["hyena_short_conv", "hyena_medium_conv"]
+            and inference_context is not None
+        ):
+            # todo: support inference_context for b2b_kernel
             # Use the B2BCausalConv1dModule wrapper with the existing weights from the original model
             z = self.b2b_kernel(features, _use_cp=_proj_use_cp)
         else:
-            features = self.hyena_proj_conv(features, _use_cp=_proj_use_cp)  # [B, D, L]
+            features = self.hyena_proj_conv(
+                features, _use_cp=_proj_use_cp, inference_context=inference_context
+            )  # [B, D, L]
             x1, x2, v = rearrange(features, "b (g dg p) l -> b (g dg) p l", p=3, g=self.num_groups_per_tp_rank).unbind(
                 dim=2
             )
-            z = self.mixer(x1, x2, v, _hyena_use_cp=_proj_use_cp)
+            z = self.mixer(x1, x2, v, _hyena_use_cp=_proj_use_cp, inference_context=inference_context)

         z = rearrange(z, "b d l -> l b d").contiguous()
         y, bias = self.dense(z)