
Add Transformer Encoder for ASR #15661

Open

nithinraok wants to merge 6 commits into main from transformer_asr_pr

Conversation

@nithinraok
Member

@nithinraok nithinraok commented May 4, 2026

Important

The Update branch button must only be pressed on very rare occasions.
An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do?

Adds a new ASR transformer encoder with frame stacking, configurable subsampling, FlexAttention-based full attention, and optional QK Norm.

Collection: ASR

Changelog

  • Add TransformerEncoder in nemo/collections/asr/modules/transformer_encoder.py as a lightweight pre-norm transformer encoder for ASR.
  • Add TransformerEncoderConfig dataclass to capture encoder hyperparameters such as d_model, n_heads, n_layers, qkv_bias, qk_norm, and
    subsampling_factor.
  • Add FeatureStacking pre-encoder module to stack consecutive frames, reduce sequence length, and project stacked features into the model
    dimension.
  • Add transformer building blocks: FeedForward, MultiHeadAttention, and TransformerBlock.
  • Implement full-attention masking with PyTorch FlexAttention and padding-aware block masks (see the sketch after this list).
  • Add optional per-head QK normalization before attention score computation.
  • Restrict the initial PR scope to attn_mode="full" and raise an error for unsupported future modes.
  • Export TransformerEncoder from nemo/collections/asr/modules/__init__.py for config-based instantiation.
  • Add example config in examples/asr/conf/fastconformer/transformer_stacking_tdt_bpe.yaml showing how to train an RNNT/TDT model with the new
    encoder.
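
For reference, here is a minimal sketch of how a padding-aware block mask for full attention can be built with torch.nn.attention.flex_attention (PyTorch >= 2.5). The helper name and tensor shapes are illustrative assumptions; the actual mask construction in this PR may differ.

  import torch
  from torch.nn.attention.flex_attention import create_block_mask, flex_attention

  def build_padding_block_mask(lengths: torch.Tensor, seq_len: int, device: str = "cuda"):
      """Hypothetical helper: keep only (query, key) pairs inside each sequence's valid length."""

      def mask_mod(b, h, q_idx, kv_idx):
          return (q_idx < lengths[b]) & (kv_idx < lengths[b])

      # B is the batch size; H=None broadcasts the same mask across attention heads.
      return create_block_mask(mask_mod, B=lengths.shape[0], H=None,
                               Q_LEN=seq_len, KV_LEN=seq_len, device=device)

  # q, k, v: (B, n_heads, T', head_dim) after subsampling
  # out = flex_attention(q, k, v, block_mask=build_padding_block_mask(lengths, q.shape[2]))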

Usage

  import torch

  from nemo.collections.asr.modules import TransformerEncoder

  encoder = TransformerEncoder(
      feat_in=128,
      d_model=512,
      n_heads=8,
      n_layers=12,
      drop_rate=0.1,
      qkv_bias=False,
      qk_norm=True,
      subsampling_factor=4,
      attn_mode="full",
  )

  audio_signal = torch.randn(2, 128, 400)   # (B, C, T)
  lengths = torch.tensor([400, 360])        # valid frame lengths

  encoded, encoded_lengths = encoder(audio_signal, lengths)

  print(encoded.shape)         # (B, D, T')
  print(encoded_lengths)       # ceil(length / subsampling_factor)

You can also instantiate it from Hydra config with:

  encoder:
    _target_: nemo.collections.asr.modules.transformer_encoder.TransformerEncoder
    feat_in: ${model.preprocessor.features}
    d_model: 1280
    n_heads: 16
    n_layers: 32
    drop_rate: 0.1
    qkv_bias: false
    qk_norm: true
    subsampling_factor: 8
    attn_mode: full

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

nithinraok added 2 commits May 4, 2026 08:37
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented May 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions github-actions Bot added the ASR label May 4, 2026
@nithinraok nithinraok requested a review from pzelasko May 4, 2026 15:51
self.head_dim = cfg.d_model // cfg.n_heads
self.d_model = cfg.d_model

self.w_q = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.qkv_bias)
Collaborator

let's merge these 3 lines into w_qkv to batch 3 ops in one

Member Author

This was done to enable the QK-Norm option, but yes, I can fuse the projection and unbind again for such ops.
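
For context, a minimal self-contained sketch of the fused-projection variant being discussed, with the per-head QK norm applied after splitting heads. The names (w_qkv, etc.) and shapes are illustrative; the PR currently keeps w_q, w_k, and w_v as separate Linear layers.

  import torch
  from torch import nn

  d_model, n_heads, qkv_bias, qk_norm = 512, 8, False, True
  head_dim = d_model // n_heads

  w_qkv = nn.Linear(d_model, 3 * d_model, bias=qkv_bias)   # one matmul instead of three
  q_norm = nn.LayerNorm(head_dim)
  k_norm = nn.LayerNorm(head_dim)

  x = torch.randn(2, 100, d_model)                         # (B, T, D)
  b, t, _ = x.shape
  qkv = w_qkv(x).reshape(b, t, 3, n_heads, head_dim)
  q, k, v = qkv.unbind(dim=2)                              # each (B, T, n_heads, head_dim)
  if qk_norm:
      q, k = q_norm(q), k_norm(k)                          # per-head QK norm still works after the unbind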

Comment thread nemo/collections/asr/modules/transformer_encoder.py Outdated
Comment thread nemo/collections/asr/modules/transformer_encoder.py Outdated
self.qk_norm = cfg.qk_norm
if cfg.qk_norm:
self.q_norm = nn.LayerNorm(self.head_dim)
self.k_norm = nn.LayerNorm(self.head_dim)
Collaborator

what about RMSNorm?

Member Author

I haven't validated runs with RMSNorm yet, so I will keep this for future work.
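
If RMSNorm is tried later, the change would be localized to these two norms. A sketch, assuming torch.nn.RMSNorm (PyTorch >= 2.4) and a hypothetical use_rms_norm config flag that is not part of this PR:

  # Hypothetical future variant; the PR keeps nn.LayerNorm for the validated runs.
  norm_cls = nn.RMSNorm if getattr(cfg, "use_rms_norm", False) else nn.LayerNorm
  self.q_norm = norm_cls(self.head_dim)
  self.k_norm = norm_cls(self.head_dim)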

x = nn.functional.pad(x, (0, 0, 0, pad_size))
t_new = (t + pad_size) // self.subsampling_factor
x = x.reshape(b, t_new, c * self.subsampling_factor)
x = self.proj(x)
Collaborator

Do we want to add some norm before the projection to center the activations? RMSNorm could be reasonable, although then we might need to change the zero-padding above to something that doesn't skew the norm too much.

Member Author

Thanks for the suggestion. I will need to check the effect by running experiments, so for now I will defer this to a later PR.
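
For context, a self-contained sketch of the frame stacking shown in the diff above, with the suggested norm-before-projection left as an optional, hypothetical step (this PR defers that change):

  import torch
  from torch import nn

  # Illustrative standalone version of the stacking step; variable names and the optional
  # pre-projection norm are assumptions, not the PR's final code.
  subsampling_factor, feat_in, d_model = 4, 128, 512
  proj = nn.Linear(feat_in * subsampling_factor, d_model)
  pre_norm = None   # e.g. nn.RMSNorm(feat_in * subsampling_factor) if the suggestion is adopted

  x = torch.randn(2, 397, feat_in)                         # (B, T, C)
  b, t, c = x.shape
  pad_size = (subsampling_factor - t % subsampling_factor) % subsampling_factor
  x = nn.functional.pad(x, (0, 0, 0, pad_size))            # zero-pad T up to a multiple of the factor
  t_new = (t + pad_size) // subsampling_factor
  x = x.reshape(b, t_new, c * subsampling_factor)          # stack consecutive frames along channels
  if pre_norm is not None:
      x = pre_norm(x)                                      # optional centering before the projection
  x = proj(x)                                              # (B, ceil(T / factor), d_model)
  print(x.shape)                                           # torch.Size([2, 100, 512])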

length: (B,) — output lengths after subsampling.
"""
x, length = self.pre_encode(audio_signal, length)
x = x * (self.d_model**0.5)
Collaborator

why re-scale just before norm?

Member Author

Added to prevent NaNs during training, similar to the fix we applied in FastConformer during scaling.

I need to check again whether this can be removed, since it may have no effect after I added a LayerNorm after the encoder for additional stability. Will revisit.

Collaborator

Where would you get NaN from if you have embed_norm just below? I don't get it.

Member Author

So, a couple of things:

  1. I was getting NaNs in the training loss, so I applied a bunch of techniques to see which could fix it (embedding scaling, layer norm after the last layer, and padding); after that the model converged well.
  2. Then I didn't go back and change each one to see exactly which contributed most to smoother training. I think it was the post LayerNorm, but I haven't validated that.

I don't think this has much effect, but I didn't check without it.

Member Author

@nithinraok nithinraok May 5, 2026

Good point.

So I disabled it, re-ran with an already trained ckpt, and here are the results:

with:

librispeech_test_clean_manifest: WER=0.0177
librispeech_test_other_manifest: WER=0.0359

commenting it out:

librispeech_test_clean_manifest: WER=0.0177
librispeech_test_other_manifest: WER=0.0359

You are correct; it looks like the pre-norm before attention is removing the added scale.

Member Author

So I believe it's the post LayerNorm that helped with the NaNs issue.
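
A quick way to see why the pre-norm absorbs the scale: LayerNorm subtracts the per-feature mean and divides by the standard deviation, so multiplying its input by a positive constant leaves the output (essentially) unchanged. A minimal check:

  import torch
  from torch import nn

  d_model = 512
  ln = nn.LayerNorm(d_model)
  x = torch.randn(2, 100, d_model)

  # Scaling the input by sqrt(d_model) before a LayerNorm does not change its output.
  print(torch.allclose(ln(x), ln(x * d_model**0.5), atol=1e-5))   # True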

Comment thread tests/collections/asr/test_transformer_encoder.py Outdated
Comment thread tests/collections/asr/test_transformer_encoder.py
Comment thread tests/collections/asr/test_transformer_encoder.py
Comment thread tests/collections/asr/test_transformer_encoder.py
Collaborator

@pzelasko pzelasko left a comment

nice work!

def __init__(self, cfg: TransformerEncoderConfig):
super().__init__()
self.net = nn.Sequential(
nn.Linear(cfg.d_model, 4 * cfg.d_model),
Collaborator

@tango4j tango4j May 5, 2026

This should never be hardcoded, since this is a very frequently customized variable.
cfg.ff_expansion should be added, replacing the hardcoded 4.

Also, I think we should add this property to TransformerEncoderConfig:

  @property
  def ff_hidden_size(self) -> int:
      return int(self.ff_expansion * self.d_model)

and make this line
nn.Linear(cfg.d_model, cfg.ff_hidden_size)

Member Author

I don't think adding a property to the dataclass is a good idea; however, I added the ff_expansion parameter.

Collaborator

No need to add the property function, but wrapping the variable with int() is necessary.
And ff_expansion should be a float.
Maybe a one-line sanity check on the "ffn_hidden_size" variable would make errors easier to track.

Member Author

Yes, added now, please check.

Collaborator

Checked the change and ff_hidden variable. Looks good.

Member Author

Please approve if all looks good from your end
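
For reference, a minimal sketch of what the ff_expansion-based feed-forward block discussed above could look like. Field names, defaults, and the activation are illustrative assumptions; see the merged file for the actual version.

  from dataclasses import dataclass

  import torch
  from torch import nn

  @dataclass
  class TransformerEncoderConfig:          # only the fields relevant here; defaults are illustrative
      d_model: int = 512
      ff_expansion: float = 4.0
      drop_rate: float = 0.1

  class FeedForward(nn.Module):
      def __init__(self, cfg: TransformerEncoderConfig):
          super().__init__()
          ff_hidden = int(cfg.ff_expansion * cfg.d_model)   # int() since ff_expansion is a float
          assert ff_hidden > 0, f"invalid feed-forward hidden size: {ff_hidden}"
          self.net = nn.Sequential(
              nn.Linear(cfg.d_model, ff_hidden),
              nn.GELU(),                                    # activation assumed, not confirmed by the PR
              nn.Dropout(cfg.drop_rate),
              nn.Linear(ff_hidden, cfg.d_model),
          )

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          return self.net(x)

  ff = FeedForward(TransformerEncoderConfig())
  print(ff(torch.randn(2, 100, 512)).shape)                 # torch.Size([2, 100, 512])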

Comment thread examples/asr/conf/fastconformer/transformer_stacking_tdt_bpe.yaml
Comment thread nemo/collections/asr/modules/transformer_encoder.py
Comment thread nemo/collections/asr/modules/transformer_encoder.py
Comment thread nemo/collections/asr/modules/transformer_encoder.py
Comment thread nemo/collections/asr/modules/transformer_encoder.py
Comment thread nemo/collections/asr/modules/transformer_encoder.py
length: (B,) — output lengths after subsampling.
"""
x, length = self.pre_encode(audio_signal, length)
x = x * (self.d_model**0.5)
Collaborator

Since the layer norm comes right after this, this line has no effect.
If it is needed, shouldn't it come after x = self.embed_norm(x)?

Member Author

Added to prevent NaNs during training, similar to the fix we applied in FastConformer during scaling.

I need to check again whether this can be removed, since it may have no effect after I added a LayerNorm after the encoder for additional stability. Will revisit.

Member Author

Good point.

So I disabled it, re-ran with an already trained ckpt, and here are the results:

with:

librispeech_test_clean_manifest: WER=0.0177
librispeech_test_other_manifest: WER=0.0359

commenting it out:

librispeech_test_clean_manifest: WER=0.0177
librispeech_test_other_manifest: WER=0.0359

You are correct; it looks like the pre-norm before attention is removing the added scale.

Collaborator

Very nice. This was a concern for me, as it broke the symmetry with other TF encoder implementations.

nithinraok added 2 commits May 5, 2026 00:49
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
@nithinraok
Member Author

/ok to test ef9547a

@github-actions
Contributor

github-actions Bot commented May 5, 2026

[🤖]: Hi @nithinraok 👋,

We wanted to let you know that a CICD pipeline for this PR just finished successfully.

So it might be time to merge this PR or get some approvals.

nithinraok added 2 commits May 5, 2026 12:14
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Collaborator

@pzelasko pzelasko left a comment

Thanks @nithinraok ! It looks good from my side, let's wait for @tango4j to approve and LGTM

@tango4j
Collaborator

tango4j commented May 6, 2026

This doesn't need to happen in this PR, but to ensure standardization, later on we should probably verify that this implementation is 100% compatible (transferable weights and identical results) with standard baselines like the Hugging Face Transformer.

If there are any custom scaling or structural tweaks, it would be great to make them optional flags. This keeps our baseline universally compatible while still allowing for optimizations when needed.

I think we should merge this after getting approval from @KunalDhawan and @stevehuang52 too.
