Skip to content

Commit 7ce58da

Browse files
pzelasko and nasretdinovr
authored and committed
SpeechLM2 SALM: load ckpt faster, with less GPU memory (NVIDIA-NeMo#14113)
* Skip loading pretrained ASR weights with released speechlm2 ckpts Signed-off-by: Piotr Żelasko <petezor@gmail.com> * Decrease max memory needed to load SALM model Signed-off-by: Piotr Żelasko <petezor@gmail.com> * Configurable dtype in SALM eval scripts Signed-off-by: Piotr Żelasko <petezor@gmail.com> * fix tests Signed-off-by: Piotr Żelasko <petezor@gmail.com> --------- Signed-off-by: Piotr Żelasko <petezor@gmail.com>
1 parent 64429eb commit 7ce58da

10 files changed

Lines changed: 147 additions & 36 deletions

File tree

examples/speechlm2/salm_eval.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class SalmEvalConfig:
4747
verbose: bool = True
4848
use_normalizer: Optional[str] = "english" # "english", "basic", or "none" / "None"
4949
device: str = "cuda"
50+
dtype: str = "bfloat16"
5051
extra_eos_tokens: Optional[list[str]] = None
5152
system_prompt: Optional[str] = None
5253
user_prompt: Optional[str] = None
@@ -56,10 +57,7 @@ class SalmEvalConfig:
5657
def main(cfg: SalmEvalConfig):
5758
logging.info(f'Hydra config:\n{OmegaConf.to_yaml(cfg)}')
5859

59-
with torch.device(cfg.device):
60-
torch.set_default_dtype(torch.bfloat16)
61-
model = SALM.from_pretrained(cfg.pretrained_name).eval().to(torch.bfloat16).to(cfg.device)
62-
torch.set_default_dtype(torch.float32)
60+
model = SALM.from_pretrained(cfg.pretrained_name).eval().to(getattr(torch, cfg.dtype)).to(cfg.device)
6361

6462
cuts = guess_parse_cutset(cfg.inputs).sort_by_duration()
6563
dloader = torch.utils.data.DataLoader(

examples/speechlm2/salm_generate.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class SalmEvalConfig:
4242
output_manifest: str = "generations.jsonl"
4343
verbose: bool = True
4444
device: str = "cuda"
45+
dtype: str = "bfloat16"
4546
extra_eos_tokens: Optional[list[str]] = None
4647
system_prompt: Optional[str] = None
4748
user_prompt: Optional[str] = None
@@ -51,10 +52,7 @@ class SalmEvalConfig:
5152
def main(cfg: SalmEvalConfig):
5253
logging.info(f"Hydra config:\n{OmegaConf.to_yaml(cfg)}")
5354

54-
with torch.device(cfg.device):
55-
torch.set_default_dtype(torch.bfloat16)
56-
model = SALM.from_pretrained(cfg.pretrained_name).eval().to(torch.bfloat16).to(cfg.device)
57-
torch.set_default_dtype(torch.float32)
55+
model = SALM.from_pretrained(cfg.pretrained_name).eval().to(getattr(torch, cfg.dtype)).to(cfg.device)
5856

5957
conversations = (
6058
guess_parse_cutset(cfg.inputs)

examples/speechlm2/to_hf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ class HfExportConfig:
3535
# Path where we should save the HuggingFace Hub compatible checkpoint
3636
output_dir: str
3737

38+
# Dtype used for stored parameters
39+
dtype: str = "bfloat16"
40+
3841

3942
def load_checkpoint(model: torch.nn.Module, checkpoint_path: str):
4043
if Path(checkpoint_path).is_dir():
@@ -60,6 +63,7 @@ def main(cfg: HfExportConfig):
6063
cls = import_class_by_path(cfg.class_path)
6164
model = cls(OmegaConf.to_container(model_cfg, resolve=True))
6265
load_checkpoint(model, cfg.ckpt_path)
66+
model = model.to(getattr(torch, cfg.dtype))
6367
model.save_pretrained(cfg.output_dir)
6468

6569

nemo/collections/speechlm2/models/duplex_s2s_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(self, cfg: dict) -> None:
7373
maybe_install_lora(self)
7474

7575
# Load the pretrained ASR model.
76-
setup_speech_encoder(self)
76+
setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights)
7777

7878
self.embed_audio_tokens = torch.nn.ModuleList(
7979
[

nemo/collections/speechlm2/models/duplex_s2s_speech_decoder_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(self, cfg: dict) -> None:
7474
maybe_install_lora(self)
7575

7676
# Load the pretrained streaming ASR model and copy its parameters into the audio perception module.
77-
setup_speech_encoder(self)
77+
setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights)
7878

7979
self.speech_generation = TransformerARSpeechDecoder(
8080
speech_decoder_parms=OmegaConf.to_container(self.cfg.speech_decoder),

nemo/collections/speechlm2/models/salm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, cfg) -> None:
6767
maybe_install_lora(self)
6868

6969
# Load the pretrained streaming ASR model and copy its parameters into the audio perception module.
70-
setup_speech_encoder(self)
70+
setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights)
7171

7272
self._use_fsdp = False
7373
self._use_tp = False

nemo/collections/speechlm2/parts/pretrained.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,19 @@ def setup_audio_codec(model: torch.nn.Module):
8383
del model.audio_codec.discriminator # free up some memory
8484

8585

86-
def setup_speech_encoder(model: torch.nn.Module):
86+
def setup_speech_encoder(model: torch.nn.Module, pretrained_weights: bool = True):
8787
"""
8888
Sets up an ``AudioPerceptionModule``, initializing its ``encoder`` and ``preprocessor``
8989
with a pretrained NeMo ``ASRModel``.
9090
The result is assigned to ``model.perception`` attribute and is trainable.
9191
"""
92-
asr = load_pretrained_nemo(ASRModel, model.cfg.pretrained_asr).eval()
93-
with open_dict(model.cfg):
94-
model.cfg.perception.preprocessor = asr.cfg.preprocessor
95-
model.cfg.perception.encoder = asr.cfg.encoder
96-
model.cfg.perception.output_dim = model.llm.config.hidden_size
97-
model.perception = AudioPerceptionModule(model.cfg.perception).train()
98-
model.perception.load_state_dict(asr.state_dict(), strict=False)
92+
if pretrained_weights:
93+
asr = load_pretrained_nemo(ASRModel, model.cfg.pretrained_asr).eval()
94+
with open_dict(model.cfg):
95+
model.cfg.perception.preprocessor = asr.cfg.preprocessor
96+
model.cfg.perception.encoder = asr.cfg.encoder
97+
model.cfg.perception.output_dim = model.llm.config.hidden_size
98+
model.perception = AudioPerceptionModule(model.cfg.perception).train()
99+
model.perception.load_state_dict(asr.state_dict(), strict=False)
100+
else:
101+
model.perception = AudioPerceptionModule(model.cfg.perception).train()

tests/collections/speechlm2/test_duplex_s2s.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,49 @@ def model():
5454
"audio_loss_weight": 1,
5555
"text_loss_weight": 3,
5656
"perception": {
57-
"_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
58-
"modality_adapter": {
57+
"target": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
58+
"output_dim": 2048,
59+
"encoder": {
5960
"_target_": "nemo.collections.asr.modules.ConformerEncoder",
60-
"feat_in": 512,
61+
"att_context_size": [-1, -1],
62+
"causal_downsampling": False,
63+
"conv_context_size": None,
64+
"conv_kernel_size": 9,
65+
"conv_norm_type": "batch_norm",
66+
"d_model": 1024,
67+
"dropout": 0.1,
68+
"dropout_att": 0.1,
69+
"dropout_emb": 0.0,
70+
"dropout_pre_encoder": 0.1,
71+
"feat_in": 128,
6172
"feat_out": -1,
62-
"n_layers": 1,
63-
"d_model": 512,
64-
"subsampling_factor": 1,
73+
"ff_expansion_factor": 4,
74+
"n_heads": 8,
75+
"n_layers": 2,
76+
"pos_emb_max_len": 5000,
77+
"self_attention_model": "rel_pos",
78+
"subsampling": "dw_striding",
79+
"subsampling_conv_channels": 256,
80+
"subsampling_factor": 8,
81+
},
82+
"modality_adapter": {
83+
"_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector",
84+
"d_model": 1024,
85+
},
86+
"preprocessor": {
87+
"_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
88+
"dither": 1e-05,
89+
"features": 128,
90+
"frame_splicing": 1,
91+
"log": True,
92+
"n_fft": 512,
93+
"normalize": "per_feature",
94+
"pad_to": 0,
95+
"pad_value": 0.0,
96+
"sample_rate": 16000,
97+
"window": "hann",
98+
"window_size": 0.025,
99+
"window_stride": 0.01,
65100
},
66101
},
67102
"optimizer": {"_target_": "torch.optim.AdamW"},
@@ -177,13 +212,13 @@ def test_s2s_offline_generation(model):
177212
assert isinstance(ans["text"][0], str)
178213

179214
gen_text = ans["tokens_text"]
180-
assert gen_text.shape == (1, 14)
215+
assert gen_text.shape == (1, 13)
181216
assert gen_text.dtype == torch.long
182217
assert (gen_text >= 0).all()
183218
assert (gen_text < model.text_vocab_size).all()
184219

185220
gen_audio_codes = ans["tokens_audio"]
186-
assert gen_audio_codes.shape == (1, 14, 8)
221+
assert gen_audio_codes.shape == (1, 13, 8)
187222
assert gen_audio_codes.dtype == torch.long
188223
assert (gen_audio_codes >= 0).all()
189224
assert (gen_audio_codes < model.speech_vocab_size).all()

tests/collections/speechlm2/test_duplex_s2s_speech_decoder.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,49 @@ def model():
5454
"audio_loss_weight": 1,
5555
"text_loss_weight": 3,
5656
"perception": {
57-
"_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
58-
"modality_adapter": {
57+
"target": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
58+
"output_dim": 2048,
59+
"encoder": {
5960
"_target_": "nemo.collections.asr.modules.ConformerEncoder",
60-
"feat_in": 512,
61+
"att_context_size": [-1, -1],
62+
"causal_downsampling": False,
63+
"conv_context_size": None,
64+
"conv_kernel_size": 9,
65+
"conv_norm_type": "batch_norm",
66+
"d_model": 1024,
67+
"dropout": 0.1,
68+
"dropout_att": 0.1,
69+
"dropout_emb": 0.0,
70+
"dropout_pre_encoder": 0.1,
71+
"feat_in": 128,
6172
"feat_out": -1,
62-
"n_layers": 1,
63-
"d_model": 512,
64-
"subsampling_factor": 1,
73+
"ff_expansion_factor": 4,
74+
"n_heads": 8,
75+
"n_layers": 2,
76+
"pos_emb_max_len": 5000,
77+
"self_attention_model": "rel_pos",
78+
"subsampling": "dw_striding",
79+
"subsampling_conv_channels": 256,
80+
"subsampling_factor": 8,
81+
},
82+
"modality_adapter": {
83+
"_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector",
84+
"d_model": 1024,
85+
},
86+
"preprocessor": {
87+
"_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
88+
"dither": 1e-05,
89+
"features": 128,
90+
"frame_splicing": 1,
91+
"log": True,
92+
"n_fft": 512,
93+
"normalize": "per_feature",
94+
"pad_to": 0,
95+
"pad_value": 0.0,
96+
"sample_rate": 16000,
97+
"window": "hann",
98+
"window_size": 0.025,
99+
"window_stride": 0.01,
65100
},
66101
},
67102
"speech_decoder": {
@@ -164,13 +199,13 @@ def test_s2s_speech_decoder_offline_generation(model):
164199
assert isinstance(ans["text"][0], str)
165200

166201
gen_text = ans["tokens_text"]
167-
assert gen_text.shape == (1, 14)
202+
assert gen_text.shape == (1, 13)
168203
assert gen_text.dtype == torch.long
169204
assert (gen_text >= 0).all()
170205
assert (gen_text < model.text_vocab_size).all()
171206

172207
gen_audio_codes = ans["tokens_audio"]
173-
assert gen_audio_codes.shape == (1, 14, 8)
208+
assert gen_audio_codes.shape == (1, 13, 8)
174209
assert gen_audio_codes.dtype == torch.long
175210
assert (gen_audio_codes >= 0).all()
176211
assert (gen_audio_codes < model.speech_vocab_size).all()

tests/collections/speechlm2/test_salm.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,50 @@ def model():
5757
"prompt_format": PROMPT,
5858
"audio_locator_tag": AUDIO_LOCATOR_TAG,
5959
"perception": {
60-
"_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
60+
"target": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
6161
"output_dim": 2048,
62+
"encoder": {
63+
"_target_": "nemo.collections.asr.modules.ConformerEncoder",
64+
"att_context_size": [-1, -1],
65+
"causal_downsampling": False,
66+
"conv_context_size": None,
67+
"conv_kernel_size": 9,
68+
"conv_norm_type": "batch_norm",
69+
"d_model": 1024,
70+
"dropout": 0.1,
71+
"dropout_att": 0.1,
72+
"dropout_emb": 0.0,
73+
"dropout_pre_encoder": 0.1,
74+
"feat_in": 128,
75+
"feat_out": -1,
76+
"ff_expansion_factor": 4,
77+
"n_heads": 8,
78+
"n_layers": 2,
79+
"pos_emb_max_len": 5000,
80+
"self_attention_model": "rel_pos",
81+
"subsampling": "dw_striding",
82+
"subsampling_conv_channels": 256,
83+
"subsampling_factor": 8,
84+
},
6285
"modality_adapter": {
6386
"_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector",
6487
"d_model": 1024,
6588
},
89+
"preprocessor": {
90+
"_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
91+
"dither": 1e-05,
92+
"features": 128,
93+
"frame_splicing": 1,
94+
"log": True,
95+
"n_fft": 512,
96+
"normalize": "per_feature",
97+
"pad_to": 0,
98+
"pad_value": 0.0,
99+
"sample_rate": 16000,
100+
"window": "hann",
101+
"window_size": 0.025,
102+
"window_stride": 0.01,
103+
},
66104
},
67105
"optimizer": {"_target_": "torch.optim.AdamW"},
68106
}

0 commit comments

Comments (0)