Add Transformer Encoder for ASR #15661
Conversation
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
self.head_dim = cfg.d_model // cfg.n_heads
self.d_model = cfg.d_model

self.w_q = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.qkv_bias)
Let's merge these 3 lines into w_qkv to batch 3 ops in one.

This was done to enable the QK-Norm option, but yes, I can unbind again for such ops.
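A minimal sketch of the fused projection being suggested; the class and shapes below are illustrative, not the PR's exact code, and per-head QK-Norm can still be applied after the unbind:

```python
import torch
import torch.nn as nn


class FusedQKV(nn.Module):
    """Sketch: one batched projection, then unbind into q/k/v."""

    def __init__(self, d_model: int, n_heads: int, qkv_bias: bool = False):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Single GEMM instead of three separate w_q / w_k / w_v projections.
        self.w_qkv = nn.Linear(d_model, 3 * d_model, bias=qkv_bias)

    def forward(self, x: torch.Tensor):
        b, t, _ = x.shape
        qkv = self.w_qkv(x)                                   # (B, T, 3 * d_model)
        qkv = qkv.view(b, t, 3, self.n_heads, self.head_dim)  # (B, T, 3, H, Dh)
        q, k, v = qkv.unbind(dim=2)                           # each (B, T, H, Dh)
        # Per-head QK-Norm (if enabled) would be applied to q and k here.
        return q, k, v
```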
self.qk_norm = cfg.qk_norm
if cfg.qk_norm:
    self.q_norm = nn.LayerNorm(self.head_dim)
    self.k_norm = nn.LayerNorm(self.head_dim)
I haven't validated runs with RMSNorm yet, so will keep this for future work.
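Since RMSNorm is mentioned as future work, a hedged sketch of how the norm type could later be made configurable; norm_type is a hypothetical knob that does not exist in this PR, and nn.RMSNorm requires a reasonably recent PyTorch (>= 2.4):

```python
import torch
import torch.nn as nn


class QKNorm(nn.Module):
    """Sketch: per-head Q/K normalization with a configurable norm type."""

    def __init__(self, head_dim: int, norm_type: str = "layernorm"):
        super().__init__()
        # Swap LayerNorm for RMSNorm without touching the attention code.
        norm_cls = nn.RMSNorm if norm_type == "rmsnorm" else nn.LayerNorm
        self.q_norm = norm_cls(head_dim)
        self.k_norm = norm_cls(head_dim)

    def forward(self, q: torch.Tensor, k: torch.Tensor):
        # q, k: (B, n_heads, T, head_dim); normalize over the last axis only.
        return self.q_norm(q), self.k_norm(k)
```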
x = nn.functional.pad(x, (0, 0, 0, pad_size))
t_new = (t + pad_size) // self.subsampling_factor
x = x.reshape(b, t_new, c * self.subsampling_factor)
x = self.proj(x)
Do we want to add some norm before the projection to center the activations? RMSNorm could be reasonable, although then we might need to change the zero-padding above to something that doesn't skew the norm too much.

Thanks for the suggestion. I will need to check the effect of this by running experiments, so for now I will defer it to a later PR.
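To make the pad → stack → project flow above easier to follow, a self-contained sketch of the same subsampling step; the class name and signature are illustrative (the PR's module is FeatureStacking), and the commented-out norm marks the deferred idea from this thread, not something the PR adds:

```python
import torch
import torch.nn as nn


class FrameStacking(nn.Module):
    """Sketch of pad -> stack -> project subsampling, as in the diff above."""

    def __init__(self, feat_in: int, d_model: int, subsampling_factor: int = 4):
        super().__init__()
        self.subsampling_factor = subsampling_factor
        # Deferred idea from the review: a norm before the projection, e.g.
        # self.pre_proj_norm = nn.RMSNorm(feat_in * subsampling_factor)
        self.proj = nn.Linear(feat_in * subsampling_factor, d_model)

    def forward(self, x: torch.Tensor, length: torch.Tensor):
        b, t, c = x.shape
        # Zero-pad time so it divides evenly by the subsampling factor.
        pad_size = (self.subsampling_factor - t % self.subsampling_factor) % self.subsampling_factor
        x = nn.functional.pad(x, (0, 0, 0, pad_size))
        t_new = (t + pad_size) // self.subsampling_factor
        # Stack consecutive frames along the feature axis, then project to d_model.
        x = x.reshape(b, t_new, c * self.subsampling_factor)
        x = self.proj(x)
        # Output lengths shrink by the same factor (ceil division).
        length = (length + self.subsampling_factor - 1) // self.subsampling_factor
        return x, length
```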
length: (B,) — output lengths after subsampling.
"""
x, length = self.pre_encode(audio_signal, length)
x = x * (self.d_model**0.5)
Why re-scale just before the norm?

Added to prevent NaNs during training, similar to the fix we applied in FastConformer during scaling.
I need to check again whether this can be removed, or has no effect, now that I added a LayerNorm after the encoder for additional stability. Will revisit.

Where would you get NaNs from if you have embed_norm just below? I don't get it.

A couple of things:
- I was getting NaNs in the training loss, so I applied a bunch of techniques to see which would fix it (embedding scaling, a LayerNorm after the last layer, and padding), after which the model converged well.
- I then didn't go back and toggle each change to see exactly which one contributed most to smoother training. I think the post LayerNorm did, but I haven't validated that.
I don't think this one has an effect, but I didn't check without it.

Good point. I disabled it and re-ran evaluation on an already-trained checkpoint; here are the results:

with:
librispeech_test_clean_manifest: WER=0.0177
librispeech_test_other_manifest: WER=0.0359

commenting it out:
librispeech_test_clean_manifest: WER=0.0177
librispeech_test_other_manifest: WER=0.0359

You are correct: it looks like the pre-norm in attention is removing the added scale. So I believe it's the post LayerNorm that fixed the NaN issue.
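For completeness, a minimal check of the point above: a LayerNorm placed right after the scaling removes any constant multiplicative factor, so multiplying by sqrt(d_model) before embed_norm has (almost) no effect on the output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 512
embed_norm = nn.LayerNorm(d_model)
x = torch.randn(2, 10, d_model)

# LayerNorm subtracts the mean over the feature dimension and divides by the
# standard deviation, so a constant scale factor cancels out (up to eps).
with_scale = embed_norm(x * d_model**0.5)
without_scale = embed_norm(x)
print(torch.allclose(with_scale, without_scale, atol=1e-4))  # True
```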
def __init__(self, cfg: TransformerEncoderConfig):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(cfg.d_model, 4 * cfg.d_model),
This should never be hardcoded, since this is a very frequently customized variable.
cfg.ff_expansion should be added, replacing the hardcoded 4.
Also, I think we should add this function to TransformerEncoderConfig:

    @property
    def ff_hidden_size(self) -> int:
        return int(self.ff_expansion * self.d_model)

and make this line:

    nn.Linear(cfg.d_model, cfg.ff_hidden_size)

I don't think adding a property to the dataclass is a good idea; however, I added the ff_expansion parameter.

You don't need to add the property function, but wrapping the variable with int() is necessary, and ff_expansion should be a float. Maybe one line of a sanity check on the "ffn_hidden_size" variable would make errors in the code easier to track.

Yes, added now, please check.

Checked the change and the ff_hidden variable. Looks good.

Please approve if all looks good from your end.
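For reference, a hedged sketch of how the agreed change could look, with ff_expansion stored as a float, the hidden size wrapped in int(), and a one-line sanity check; the trimmed-down config and the GELU activation are assumptions for illustration, not necessarily what the PR does:

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class TransformerEncoderConfig:  # trimmed to the fields relevant here
    d_model: int = 512
    ff_expansion: float = 4.0  # float, as requested in the review


class FeedForward(nn.Module):
    def __init__(self, cfg: TransformerEncoderConfig):
        super().__init__()
        # int() wrap plus a simple sanity check, per the review comments above.
        ff_hidden = int(cfg.ff_expansion * cfg.d_model)
        assert ff_hidden > 0, f"Invalid ff_hidden={ff_hidden} (ff_expansion={cfg.ff_expansion})"
        self.net = nn.Sequential(
            nn.Linear(cfg.d_model, ff_hidden),
            nn.GELU(),  # activation is an assumption; the PR's choice may differ
            nn.Linear(ff_hidden, cfg.d_model),
        )

    def forward(self, x):
        return self.net(x)
```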
length: (B,) — output lengths after subsampling.
"""
x, length = self.pre_encode(audio_signal, length)
x = x * (self.d_model**0.5)
Since the layer norm comes right after this, this line has no effect. If it is needed, shouldn't it come after x = self.embed_norm(x)?

Added to prevent NaNs during training, similar to the fix we applied in FastConformer during scaling.
I need to check again whether this can be removed, or has no effect, now that I added a LayerNorm after the encoder for additional stability. Will revisit.

Good point. I disabled it and re-ran evaluation on an already-trained checkpoint; here are the results:

with:
librispeech_test_clean_manifest: WER=0.0177
librispeech_test_other_manifest: WER=0.0359

commenting it out:
librispeech_test_clean_manifest: WER=0.0177
librispeech_test_other_manifest: WER=0.0359

You are correct: it looks like the pre-norm in attention is removing the added scale.

Very nice. This was a concern for me about breaking the symmetry with other TF encoder implementations.
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
/ok to test ef9547a

[🤖]: Hi @nithinraok 👋, We wanted to let you know that a CICD pipeline for this PR just finished successfully. So it might be time to merge this PR or get some approvals.
Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
pzelasko left a comment:

Thanks @nithinraok! It looks good from my side, let's wait for @tango4j to approve and LGTM.
This doesn't need to happen in this PR, but to ensure standardization we should later verify that this implementation is 100% compatible (transferable weights and identical results) with standard baselines like the Hugging Face Transformer. If there are any custom scaling or structural tweaks, it would be great to make them optional flags. This keeps our baseline universally compatible while still allowing for optimizations when needed. I think we should merge this after getting approval from @KunalDhawan and @stevehuang52 too.
Important
The Update branch button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do ?
Adds a new ASR transformer encoder with frame stacking, configurable subsampling, FlexAttention-based full attention, and optional QK-Norm.
Collection: ASR
Changelog
- nemo/collections/asr/modules/transformer_encoder.py as a lightweight pre-norm transformer encoder for ASR.
- TransformerEncoderConfig dataclass to capture encoder hyperparameters such as d_model, n_heads, n_layers, qkv_bias, qk_norm, and subsampling_factor.
- FeatureStacking pre-encoder module to stack consecutive frames, reduce sequence length, and project stacked features into the model dimension.
- FeedForward, MultiHeadAttention, and TransformerBlock building blocks.
- Exported in nemo/collections/asr/modules/__init__.py for config-based instantiation.
- examples/asr/conf/fastconformer/transformer_stacking_tdt_bpe.yaml showing how to train an RNNT/TDT model with the new encoder.
Usage
You can also instantiate it from Hydra config with:
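A minimal sketch of what that could look like, assuming the module path from the changelog and a flat-kwargs constructor; the _target_ class name, the constructor signature, and the field values are placeholders, not the PR's confirmed API:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Placeholder values; the _target_ class name and constructor signature are assumptions.
encoder_cfg = OmegaConf.create(
    {
        "_target_": "nemo.collections.asr.modules.transformer_encoder.TransformerEncoder",
        "d_model": 512,
        "n_heads": 8,
        "n_layers": 12,
        "qkv_bias": False,
        "qk_norm": True,
        "subsampling_factor": 8,
    }
)
encoder = instantiate(encoder_cfg)
```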
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.
Additional Information