diff --git a/examples/asr/conf/fastconformer/transformer_stacking_tdt_bpe.yaml b/examples/asr/conf/fastconformer/transformer_stacking_tdt_bpe.yaml
new file mode 100644
index 000000000000..0863a29cc139
--- /dev/null
+++ b/examples/asr/conf/fastconformer/transformer_stacking_tdt_bpe.yaml
@@ -0,0 +1,185 @@
+name: "Transformer-Stacking-TDT-BPE"
+
+model:
+  sample_rate: 16000
+  compute_eval_loss: false
+  log_prediction: true
+  rnnt_reduction: 'mean_volume'
+  skip_nan_grad: false
+
+  model_defaults:
+    enc_hidden: ${model.encoder.d_model}
+    pred_hidden: 640
+    joint_hidden: 640
+    tdt_durations: [0, 1, 2, 3, 4]
+    num_tdt_durations: 5
+
+  train_ds:
+    manifest_filepath: ???
+    sample_rate: ${model.sample_rate}
+    batch_size: 16
+    shuffle: true
+    num_workers: 8
+    pin_memory: true
+    max_duration: 20
+    min_duration: 0.1
+    is_tarred: false
+    tarred_audio_filepaths: null
+    shuffle_n: 2048
+    bucketing_strategy: "fully_randomized"
+    bucketing_batch_size: null
+
+  validation_ds:
+    manifest_filepath: ???
+    sample_rate: ${model.sample_rate}
+    batch_size: 16
+    shuffle: false
+    use_start_end_token: false
+    num_workers: 8
+    pin_memory: true
+
+  test_ds:
+    manifest_filepath: null
+    sample_rate: ${model.sample_rate}
+    batch_size: 16
+    shuffle: false
+    use_start_end_token: false
+    num_workers: 8
+    pin_memory: true
+
+  tokenizer:
+    dir: ???
+    type: bpe
+
+  preprocessor:
+    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
+    sample_rate: ${model.sample_rate}
+    normalize: "per_feature"
+    window_size: 0.025
+    window_stride: 0.01
+    window: "hann"
+    features: 128
+    n_fft: 512
+    frame_splicing: 1
+    dither: 0.00001
+    pad_to: 0
+
+  spec_augment:
+    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
+    freq_masks: 2
+    time_masks: 10
+    freq_width: 27
+    time_width: 0.05
+
+  encoder:
+    _target_: nemo.collections.asr.modules.transformer_encoder.TransformerEncoder
+    feat_in: ${model.preprocessor.features}
+    d_model: 1280
+    n_heads: 16
+    n_layers: 32
+    drop_rate: 0.1
+    qkv_bias: false
+    qk_norm: true
+    ff_expansion: 4
+    subsampling_factor: 8
+
+  decoder:
+    _target_: nemo.collections.asr.modules.RNNTDecoder
+    normalization_mode: null
+    random_state_sampling: false
+    blank_as_pad: true
+
+    prednet:
+      pred_hidden: ${model.model_defaults.pred_hidden}
+      pred_rnn_layers: 1
+      t_max: null
+      dropout: 0.2
+
+  joint:
+    _target_: nemo.collections.asr.modules.RNNTJoint
+    log_softmax: null
+    preserve_memory: false
+    fuse_loss_wer: false
+    fused_batch_size: 4
+    # TDT joints need one extra output per duration bin in model_defaults.tdt_durations.
+    num_extra_outputs: ${model.model_defaults.num_tdt_durations}
+
+    jointnet:
+      joint_hidden: ${model.model_defaults.joint_hidden}
+      activation: "relu"
+      dropout: 0.2
+
+  decoding:
+    strategy: "greedy_batch"
+    model_type: "tdt"
+    durations: ${model.model_defaults.tdt_durations}
+
+    greedy:
+      max_symbols: 10
+      use_cuda_graph_decoder: true
+
+    beam:
+      beam_size: 2
+      return_best_hypothesis: false
+      score_norm: true
+      tsd_max_sym_exp: 50
+      alsd_max_target_len: 2.0
+
+  loss:
+    loss_name: "tdt"
+    tdt_kwargs:
+      fastemit_lambda: 0.0
+      clamp: -1.0
+      durations: ${model.model_defaults.tdt_durations}
+      sigma: 0.02
+      omega: 0.1
+
+  optim:
+    name: adamw
+    lr: 5e-4
+    betas: [0.9, 0.95]
+    weight_decay: 1e-2
+
+    sched:
+      name: CosineAnnealing
+      warmup_steps: 25000
+      warmup_ratio: null
+      min_lr: 5e-5
+
+trainer:
+  devices: -1
+  num_nodes: 1
+  max_epochs: 500
+  max_steps: -1
+  val_check_interval: 1.0
+  accelerator: auto
+  strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy
+    gradient_as_bucket_view: true
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  precision: bf16
+  log_every_n_steps: 10
+  enable_progress_bar: true
+  num_sanity_val_steps: 0
+  check_val_every_n_epoch: 1
+  sync_batchnorm: true
+  enable_checkpointing: false
+  logger: false
+  benchmark: false
+
+exp_manager:
+  exp_dir: null
+  name: ${name}
+  create_tensorboard_logger: false
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    monitor: "val_wer"
+    mode: "min"
+    save_top_k: 5
+    always_save_nemo: true
+  resume_from_checkpoint: null
+  resume_if_exists: false
+  resume_ignore_no_checkpoint: false
+  create_wandb_logger: false
+  wandb_logger_kwargs:
+    name: null
+    project: null
diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py
index 7259d077809e..b6e79b7451b1 100644
--- a/nemo/collections/asr/modules/__init__.py
+++ b/nemo/collections/asr/modules/__init__.py
@@ -52,6 +52,7 @@
     RandomBlockMasking,
     RandomProjectionVectorQuantizer,
 )
+from nemo.collections.asr.modules.transformer_encoder import TransformerEncoder  # noqa: F401
 
 __all__ = [
     'AudioToMelSpectrogramPreprocessor',
@@ -83,4 +84,5 @@
     'MultiSoftmaxDecoder',
     'RandomBlockMasking',
     'RandomProjectionVectorQuantizer',
+    'TransformerEncoder',
 ]
diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py
new file mode 100644
index 000000000000..6f110cb11cf1
--- /dev/null
+++ b/nemo/collections/asr/modules/transformer_encoder.py
@@ -0,0 +1,232 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+flex_attention_compiled = torch.compile(flex_attention, dynamic=True)
+
+
+@dataclass
+class TransformerEncoderConfig:
+    feat_in: int = 80
+    d_model: int = 512
+    n_heads: int = 8
+    n_layers: int = 17
+    drop_rate: float = 0.1
+    qkv_bias: bool = False
+    qk_norm: bool = False
+    ff_expansion: float = 4.0
+    subsampling_factor: int = 4
+    # Attention mode — currently only "full" is supported.
+    # Future: "causal", "lookahead", "local", "sliding_window"
+    attn_mode: str = "full"
+
+
+def _make_padding_mod(lengths):
+    """Mask out padding positions based on per-sample lengths."""
+
+    def pad_mask(b, h, q_idx, kv_idx):
+        return kv_idx < lengths[b]
+
+    return pad_mask
+
+
+class FeatureStacking(nn.Module):
+    """Stacks consecutive input frames and projects to model dimension.
+
+    Reduces the temporal resolution by ``subsampling_factor`` while increasing
+    the feature dimension proportionally, then linearly projects back to
+    ``feat_out``.
+
+    Args:
+        subsampling_factor: Number of consecutive frames to stack.
+        feat_in: Input feature dimension (e.g. number of mel bins).
+        feat_out: Output feature dimension (model hidden size).
+ """ + + def __init__(self, subsampling_factor: int, feat_in: int, feat_out: int): + super().__init__() + self.subsampling_factor = subsampling_factor + self.proj = nn.Linear(subsampling_factor * feat_in, feat_out, bias=False) + + def compute_num_out_frames(self, in_frames): + return (in_frames + self.subsampling_factor - 1) // self.subsampling_factor + + def forward(self, x, lengths): + """ + Args: + x: (B, C, T) — input features (channels-first from preprocessor). + lengths: (B,) — valid lengths per sample. + Returns: + x: (B, T', feat_out) — stacked and projected features. + lengths: (B,) — updated lengths after subsampling. + """ + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + b, t, c = x.size() + pad_size = (self.subsampling_factor - (t % self.subsampling_factor)) % self.subsampling_factor + if pad_size > 0: + x = nn.functional.pad(x, (0, 0, 0, pad_size)) + t_new = (t + pad_size) // self.subsampling_factor + x = x.reshape(b, t_new, c * self.subsampling_factor) + x = self.proj(x) + lengths = self.compute_num_out_frames(lengths) + return x, lengths + + +class FeedForward(nn.Module): + def __init__(self, cfg: TransformerEncoderConfig): + super().__init__() + ff_hidden = int(cfg.ff_expansion * cfg.d_model) + self.net = nn.Sequential( + nn.Linear(cfg.d_model, ff_hidden), + nn.GELU(), + nn.Dropout(cfg.drop_rate), + nn.Linear(ff_hidden, cfg.d_model), + nn.Dropout(cfg.drop_rate), + ) + + def forward(self, x): + return self.net(x) + + +class MultiHeadAttention(nn.Module): + def __init__(self, cfg: TransformerEncoderConfig): + super().__init__() + self.n_heads = cfg.n_heads + self.head_dim = cfg.d_model // cfg.n_heads + self.d_model = cfg.d_model + + self.w_qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=cfg.qkv_bias) + self.out_proj = nn.Linear(cfg.d_model, cfg.d_model) + + self.qk_norm = cfg.qk_norm + if cfg.qk_norm: + self.q_norm = nn.LayerNorm(self.head_dim) + self.k_norm = nn.LayerNorm(self.head_dim) + + def forward(self, x, block_mask=None): + B, T, _ = x.shape + H, D = self.n_heads, self.head_dim + + qkv = self.w_qkv(x).view(B, T, 3, H, D).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + if self.qk_norm: + q = self.q_norm(q).to(v.dtype) + k = self.k_norm(k).to(v.dtype) + + out = flex_attention_compiled(q, k, v, block_mask=block_mask) + out = out.transpose(1, 2).contiguous().view(B, T, self.d_model) + return self.out_proj(out) + + +class TransformerBlock(nn.Module): + def __init__(self, cfg: TransformerEncoderConfig): + super().__init__() + self.norm1 = nn.LayerNorm(cfg.d_model) + self.attn = MultiHeadAttention(cfg) + self.drop = nn.Dropout(cfg.drop_rate) + self.norm2 = nn.LayerNorm(cfg.d_model) + self.ffn = FeedForward(cfg) + + def forward(self, x, block_mask=None): + x = x + self.drop(self.attn(self.norm1(x), block_mask=block_mask)) + x = x + self.drop(self.ffn(self.norm2(x))) + return x + + +class TransformerEncoder(nn.Module): + """Pre-norm Transformer encoder for ASR. + + Architecture: FeatureStacking -> EmbedScale -> LayerNorm -> N x TransformerBlock -> FinalNorm + + Uses PyTorch FlexAttention for attention computation. On CUDA, mask functions + are compiled into fused Triton kernels with block-sparse optimization. On CPU, + FlexAttention falls back to an unfused implementation automatically. + + Args: + feat_in: Input feature dimension (number of mel bins). + d_model: Transformer hidden dimension. + n_heads: Number of attention heads. + n_layers: Number of transformer blocks. + drop_rate: Dropout probability. + qkv_bias: Whether to use bias in Q/K/V projections. 
+ qk_norm: Whether to apply per-head LayerNorm to Q and K before the dot product. + ff_expansion: Feed-forward expansion factor (float to support sub-1x for MoE). + subsampling_factor: Frame stacking factor for the pre-encoder. + attn_mode: Attention pattern — currently only "full" (bidirectional) is supported. + """ + + def __init__( + self, + feat_in: int = 80, + d_model: int = 512, + n_heads: int = 8, + n_layers: int = 17, + drop_rate: float = 0.1, + qkv_bias: bool = False, + qk_norm: bool = False, + ff_expansion: float = 4.0, + subsampling_factor: int = 4, + attn_mode: str = "full", + ): + super().__init__() + if attn_mode != "full": + raise ValueError(f"attn_mode='{attn_mode}' is not yet supported. Currently only 'full' is available.") + + cfg = TransformerEncoderConfig( + feat_in=feat_in, + d_model=d_model, + n_heads=n_heads, + n_layers=n_layers, + drop_rate=drop_rate, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + ff_expansion=ff_expansion, + subsampling_factor=subsampling_factor, + attn_mode=attn_mode, + ) + self.d_model = d_model + + self.pre_encode = FeatureStacking(subsampling_factor, feat_in, d_model) + self.embed_norm = nn.LayerNorm(d_model) + self.layers = nn.ModuleList([TransformerBlock(cfg) for _ in range(n_layers)]) + self.final_norm = nn.LayerNorm(d_model) + + def forward(self, audio_signal, length): + """ + Args: + audio_signal: (B, C, T) — mel spectrogram from preprocessor. + length: (B,) — valid frame counts per sample. + Returns: + x: (B, D, T') — encoded representation (channels-first). + length: (B,) — output lengths after subsampling. + """ + x, length = self.pre_encode(audio_signal, length) + + x = self.embed_norm(x) + + B, T, _ = x.shape + block_mask = create_block_mask(_make_padding_mod(length), B=B, H=1, Q_LEN=T, KV_LEN=T, device=x.device) + + for layer in self.layers: + x = layer(x, block_mask=block_mask) + + x = self.final_norm(x) + x = x.transpose(1, 2) # (B, T, D) -> (B, D, T) + return x, length diff --git a/tests/collections/asr/test_transformer_encoder.py b/tests/collections/asr/test_transformer_encoder.py new file mode 100644 index 000000000000..64ec3cf8a945 --- /dev/null +++ b/tests/collections/asr/test_transformer_encoder.py @@ -0,0 +1,265 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+import torch
+
+from nemo.collections.asr.modules.transformer_encoder import (
+    FeatureStacking,
+    TransformerEncoder,
+    TransformerEncoderConfig,
+)
+
+
+class TestTransformerEncoderConfig:
+    @pytest.mark.unit
+    def test_default_config(self):
+        cfg = TransformerEncoderConfig()
+        assert cfg.feat_in == 80
+        assert cfg.d_model == 512
+        assert cfg.n_heads == 8
+        assert cfg.n_layers == 17
+        assert cfg.drop_rate == 0.1
+        assert cfg.qkv_bias is False
+        assert cfg.qk_norm is False
+        assert cfg.ff_expansion == 4.0
+        assert cfg.subsampling_factor == 4
+        assert cfg.attn_mode == "full"
+
+    @pytest.mark.unit
+    def test_custom_config(self):
+        cfg = TransformerEncoderConfig(feat_in=128, d_model=1280, n_heads=16, n_layers=32, qk_norm=True)
+        assert cfg.feat_in == 128
+        assert cfg.d_model == 1280
+        assert cfg.n_heads == 16
+        assert cfg.n_layers == 32
+        assert cfg.qk_norm is True
+
+
+class TestFeatureStacking:
+    @pytest.mark.unit
+    @pytest.mark.parametrize("subsampling_factor", [2, 4, 8])
+    def test_output_shape(self, subsampling_factor):
+        B, C, T = 2, 80, 400
+        stacking = FeatureStacking(subsampling_factor=subsampling_factor, feat_in=C, feat_out=256)
+        x = torch.randn(B, C, T)
+        lengths = torch.tensor([400, 300])
+
+        out, out_lengths = stacking(x, lengths)
+        expected_t = stacking.compute_num_out_frames(T)
+        assert out.shape == (B, expected_t, 256)
+        assert out_lengths[0].item() == expected_t
+
+    @pytest.mark.unit
+    def test_padding_when_not_divisible(self):
+        B, C, T = 1, 80, 401
+        subsampling_factor = 4
+        stacking = FeatureStacking(subsampling_factor=subsampling_factor, feat_in=C, feat_out=256)
+        x = torch.randn(B, C, T)
+        lengths = torch.tensor([401])
+
+        out, out_lengths = stacking(x, lengths)
+        expected_t = stacking.compute_num_out_frames(T)
+        assert out.shape == (B, expected_t, 256)
+
+    @pytest.mark.unit
+    def test_length_shorter_than_batch(self):
+        """Output length must be ceil(sample_length / factor), not dependent on batch T."""
+        B, C, T = 2, 80, 403
+        subsampling_factor = 4
+        stacking = FeatureStacking(subsampling_factor=subsampling_factor, feat_in=C, feat_out=256)
+        x = torch.randn(B, C, T)
+        lengths = torch.tensor([401, 397])
+
+        _, out_lengths = stacking(x, lengths)
+        assert out_lengths[0].item() == stacking.compute_num_out_frames(401)
+        assert out_lengths[1].item() == stacking.compute_num_out_frames(397)
+
+    @pytest.mark.unit
+    def test_no_padding_when_divisible(self):
+        B, C, T = 1, 80, 400
+        stacking = FeatureStacking(subsampling_factor=4, feat_in=C, feat_out=256)
+        x = torch.randn(B, C, T)
+        lengths = torch.tensor([400])
+
+        out, out_lengths = stacking(x, lengths)
+        assert out.shape == (B, stacking.compute_num_out_frames(T), 256)
+        assert out_lengths[0].item() == stacking.compute_num_out_frames(T)
+
+
+class TestTransformerEncoder:
+    @pytest.mark.unit
+    def test_model_creation(self):
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2)
+        total_params = sum(p.numel() for p in model.parameters())
+        assert total_params > 0
+        assert len(model.layers) == 2
+
+    @pytest.mark.unit
+    def test_model_creation_with_qk_norm(self):
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, qk_norm=True)
+        attn = model.layers[0].attn
+        assert hasattr(attn, 'q_norm')
+        assert hasattr(attn, 'k_norm')
+
+    @pytest.mark.unit
+    def test_model_creation_without_qk_norm(self):
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, qk_norm=False)
+        attn = model.layers[0].attn
+        assert not hasattr(attn, 'q_norm')
+        assert not hasattr(attn, 'k_norm')
+
+    @pytest.mark.unit
+    def test_invalid_attn_mode(self):
+        with pytest.raises(ValueError, match="not yet supported"):
+            TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, attn_mode="causal")
+
+    @pytest.mark.unit
+    def test_forward_cpu(self):
+        """Forward pass on CPU uses unfused FlexAttention fallback."""
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, subsampling_factor=4)
+        model.eval()
+
+        B, C, T = 2, 80, 400
+        x = torch.randn(B, C, T)
+        lengths = torch.tensor([400, 300])
+
+        with torch.no_grad():
+            out, out_lengths = model(x, lengths)
+
+        assert out.shape == (B, 64, T // 4)
+        assert out_lengths[0].item() == T // 4
+        assert out_lengths[1].item() == 300 // 4
+        assert not torch.isnan(out).any()
+
+    @pytest.mark.unit
+    def test_forward_cpu_with_qk_norm(self):
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, qk_norm=True)
+        model.eval()
+
+        x = torch.randn(1, 80, 200)
+        lengths = torch.tensor([200])
+
+        with torch.no_grad():
+            out, _ = model(x, lengths)
+
+        assert out.shape == (1, 64, 50)
+        assert not torch.isnan(out).any()
+
+    @pytest.mark.run_only_on('GPU')
+    def test_forward_basic(self):
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0, subsampling_factor=4)
+        model = model.cuda().to(torch.bfloat16)
+
+        B, C, T = 2, 80, 400
+        x = torch.randn(B, C, T, device='cuda', dtype=torch.bfloat16)
+        lengths = torch.tensor([400, 300], device='cuda')
+
+        model.eval()
+        with torch.no_grad():
+            out, out_lengths = model(x, lengths)
+
+        assert out.shape == (B, 64, T // 4)
+        assert out_lengths[0].item() == T // 4
+        assert out_lengths[1].item() == 300 // 4
+        assert not torch.isnan(out).any()
+
+    @pytest.mark.run_only_on('GPU')
+    def test_forward_with_qk_norm(self):
+        model = TransformerEncoder(
+            feat_in=128, d_model=128, n_heads=8, n_layers=2, drop_rate=0.0, qk_norm=True, subsampling_factor=8
+        )
+        model = model.cuda().to(torch.bfloat16)
+
+        B, C, T = 2, 128, 800
+        x = torch.randn(B, C, T, device='cuda', dtype=torch.bfloat16)
+        lengths = torch.tensor([800, 640], device='cuda')
+
+        model.eval()
+        with torch.no_grad():
+            out, out_lengths = model(x, lengths)
+
+        assert out.shape == (B, 128, T // 8)
+        assert not torch.isnan(out).any()
+
+    @pytest.mark.run_only_on('GPU')
+    def test_forward_output_channels_first(self):
+        """Verify output is (B, D, T) channels-first as expected by downstream decoders."""
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=1, drop_rate=0.0)
+        model = model.cuda().to(torch.bfloat16)
+
+        x = torch.randn(1, 80, 200, device='cuda', dtype=torch.bfloat16)
+        lengths = torch.tensor([200], device='cuda')
+
+        model.eval()
+        with torch.no_grad():
+            out, _ = model(x, lengths)
+
+        assert out.shape[1] == 64  # D dimension
+        assert out.shape[2] == 200 // 4  # T dimension
+
+    @pytest.mark.run_only_on('GPU')
+    def test_eval_deterministic(self):
+        """In eval mode with no dropout, repeated forward passes should produce identical output."""
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0)
+        model = model.cuda().to(torch.bfloat16).eval()
+
+        x = torch.randn(1, 80, 200, device='cuda', dtype=torch.bfloat16)
+        lengths = torch.tensor([200], device='cuda')
+
+        with torch.no_grad():
+            out1, _ = model(x, lengths)
+            out2, _ = model(x, lengths)
+
+        assert torch.allclose(out1, out2, atol=1e-6)
+
+    @pytest.mark.run_only_on('GPU')
+    def test_padding_does_not_affect_valid_output(self):
+        """Padding frames should not change the encoded output at valid positions."""
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0)
+        model = model.cuda().to(torch.bfloat16).eval()
+
+        T_valid = 200
+        x_short = torch.randn(1, 80, T_valid, device='cuda', dtype=torch.bfloat16)
+        lengths_short = torch.tensor([T_valid], device='cuda')
+
+        T_padded = 400
+        x_long = torch.zeros(1, 80, T_padded, device='cuda', dtype=torch.bfloat16)
+        x_long[:, :, :T_valid] = x_short
+        lengths_long = torch.tensor([T_valid], device='cuda')
+
+        with torch.no_grad():
+            out_short, len_short = model(x_short, lengths_short)
+            out_long, len_long = model(x_long, lengths_long)
+
+        assert len_short[0].item() == len_long[0].item()
+        valid_t = len_short[0].item()
+        # bf16 + different block mask shapes cause small numerical differences in Triton kernels
+        assert torch.allclose(out_short[:, :, :valid_t], out_long[:, :, :valid_t], atol=5e-2)
+
+    @pytest.mark.run_only_on('GPU')
+    def test_backward_pass(self):
+        model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, drop_rate=0.0)
+        model = model.cuda().to(torch.bfloat16).train()
+
+        x = torch.randn(2, 80, 200, device='cuda', dtype=torch.bfloat16)
+        lengths = torch.tensor([200, 160], device='cuda')
+
+        out, out_lengths = model(x, lengths)
+        loss = out.sum()
+        loss.backward()
+
+        for name, param in model.named_parameters():
+            assert param.grad is not None, f"No gradient for {name}"
+            assert not torch.isnan(param.grad).any(), f"NaN gradient for {name}"