Skip to content
15 changes: 13 additions & 2 deletions examples/tts/magpietts_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met
metrics.get('wer_gt_audio_cumulative', ''),
metrics.get('utmosv2_avg', ''),
metrics.get('total_gen_audio_seconds', ''),
metrics.get('frechet_codec_distance', ''),
]
with open(csv_path, "a") as f:
f.write(",".join(str(v) for v in values) + "\n")
Expand Down Expand Up @@ -203,7 +204,7 @@ def run_inference_and_evaluation(
"wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg,"
"ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate,"
"ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative,"
"utmosv2_avg,total_gen_audio_seconds"
"utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance"
)

for dataset in datasets:
Expand Down Expand Up @@ -244,13 +245,14 @@ def run_inference_and_evaluation(
f"Dataset length mismatch: {len(test_dataset)} vs {len(manifest_records)} manifest records"
)

rtf_metrics_list, _ = runner.run_inference_on_dataset(
rtf_metrics_list, _, codec_file_paths = runner.run_inference_on_dataset(
dataset=test_dataset,
output_dir=repeat_audio_dir,
manifest_records=manifest_records,
audio_base_dir=meta['audio_dir'],
save_cross_attention_maps=True,
save_context_audio=(repeat_idx == 0), # Only save context audio once
save_predicted_codes=eval_config.with_fcd, # Code files are only needed for FCD computation
)

# Compute mean RTF metrics
Expand All @@ -268,6 +270,8 @@ def run_inference_and_evaluation(
asr_model_name=eval_config.asr_model_name,
language=language,
with_utmosv2=eval_config.with_utmosv2,
with_fcd=eval_config.with_fcd,
codec_model_path=eval_config.codec_model_path,
)

metrics, filewise_metrics = evaluate_generated_audio_dir(
Expand All @@ -294,6 +298,10 @@ def run_inference_and_evaluation(
violin_path = Path(eval_dir) / f"{dataset}_violin_{repeat_idx}.png"
create_violin_plot(filewise_metrics, violin_plot_metrics, violin_path)

# Delete temporary predicted codes files
for codec_file_path in codec_file_paths:
os.remove(codec_file_path)

if skip_evaluation or not metrics_all_repeats:
continue

Expand Down Expand Up @@ -511,6 +519,7 @@ def create_argument_parser() -> argparse.ArgumentParser:
nargs='*',
default=['cer', 'pred_context_ssim', 'utmosv2'],
)
eval_group.add_argument('--disable_fcd', action='store_true', help="Disable Frechet Codec Distance computation")

# Quality targets (for CI/CD)
target_group = parser.add_argument_group('Quality Targets')
Expand Down Expand Up @@ -580,6 +589,8 @@ def main():
sv_model=args.sv_model,
asr_model_name=args.asr_model_name,
with_utmosv2=not args.disable_utmosv2,
with_fcd=not args.disable_fcd,
codec_model_path=args.codecmodel_path if not args.disable_fcd else None,
)

cer, ssim = None, None
Expand Down
193 changes: 193 additions & 0 deletions nemo/collections/tts/metrics/frechet_codec_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import numpy as np
import torch
from einops import rearrange
from torch import Tensor, nn
from torchmetrics.image.fid import FrechetInceptionDistance

from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.collections.tts.models import AudioCodecModel
from nemo.utils import logging


class CodecEmbedder(nn.Module):
    """
    Wraps an audio codec so that codec codes can be mapped to dequantized
    codec embeddings.

    Exposes the interface expected by `torchmetrics.image.fid` for a custom
    feature extractor: a callable module plus a `num_features` property.
    """

    def __init__(self, codec: AudioCodecModel):
        super().__init__()
        self.codec = codec

    def forward(self, x: Tensor) -> Tensor:
        """
        Map a batch of codec code frames into the codec's (dequantized) embedding space.
        Every frame is embedded independently of the others.

        Args:
            x: Audio codes of shape (B*T, C), one multi-codebook frame per row.

        Returns:
            Embedding tensor of shape (B*T, D).
        """
        # The codec expects (B, C, T) input plus per-element lengths, but the FID
        # API hands us a flat batch of frames with no length information. We
        # therefore feed all frames to the codec as one single "utterance".
        num_frames = x.shape[0]
        lengths = torch.tensor(num_frames, device=x.device, dtype=torch.long).unsqueeze(0)  # shape (1,)
        tokens = x.permute(1, 0).unsqueeze(0)  # (1, C, B*T)
        dequantized = self.codec.dequantize(tokens=tokens, tokens_len=lengths)  # (1, D, B*T)
        # Split the single utterance back into independent per-frame embeddings.
        return rearrange(dequantized, 'B D T -> (B T) D')

    @property
    def num_features(self) -> int:
        # Dimensionality D of the dequantized embeddings, as required by the FID API.
        return self.codec.vector_quantizer.codebook_dim


class FrechetCodecDistance(FrechetInceptionDistance):
    """
    Frechet Distance between collections of real and generated codec frames,
    measured in the codec's embedding space, i.e. the continuous vectors
    obtained by dequantizing the codec frames. Each multi-codebook frame is
    treated as an independent example.

    Implemented by subclassing `torchmetrics.image.fid.FrechetInceptionDistance`
    with a `CodecEmbedder` as the custom feature extractor.
    """

    def __init__(self, codec_name: str):
        """
        Initializes the FrechetCodecDistance metric.

        Args:
            codec_name: The name of the codec model to use.
                Can be a local .nemo file or a HuggingFace or NGC model.
                If the name ends with ".nemo", it is assumed to be a local .nemo file.
                Otherwise, it should start with "nvidia/", and is assumed to be a HuggingFace or NGC model.
        """
        if codec_name.endswith(".nemo"):
            # Local checkpoint on disk
            codec = AudioCodecModel.restore_from(codec_name, strict=False)
        elif codec_name.startswith("nvidia/"):
            # Hosted model (HuggingFace or NGC)
            codec = AudioCodecModel.from_pretrained(codec_name)
        else:
            raise ValueError(
                f"Invalid codec name: {codec_name}. Must be a local .nemo file or a HuggingFace or NGC model name starting with 'nvidia/'"
            )
        codec.eval()
        super().__init__(feature=CodecEmbedder(codec))
        self.codec = codec
        # Guards compute() against being called on an empty metric state.
        self.updated_since_last_reset = False

    def _encode_audio_file(self, audio_path: str) -> Tuple[Tensor, Tensor]:
        """
        Encodes an audio file into codec codes using the audio codec.

        Args:
            audio_path: Path to the audio file.

        Returns:
            Tuple of tensors containing the codec codes and the lengths of the codec codes.
        """
        segment = AudioSegment.from_file(audio_path, target_sr=self.codec.sample_rate)
        assert np.issubdtype(segment.samples.dtype, np.floating)
        # Certain ways of normalizing audio can result in samples that are
        # slightly outside of [-1, 1], so allow a small tolerance before warning.
        eps = 0.01
        lo, hi = segment.samples.min(), segment.samples.max()
        if lo < (-1.0 - eps) or hi > (1.0 + eps):
            logging.warning(f"Audio samples are not normalized: min={lo}, max={hi}")
        audio = torch.tensor(segment.samples, device=self.codec.device).unsqueeze(0)
        audio_len = torch.tensor(audio.shape[1], device=self.codec.device).unsqueeze(0)
        return self.codec.encode(audio=audio, audio_len=audio_len)

    def update(self, codes: Tensor, codes_len: Tensor, is_real: bool):
        """
        Updates the metric with a batch of codec frames.

        Args:
            codes: Tensor of shape (B, C, T) containing the codec codes.
            codes_len: Tensor of shape (B,) containing the lengths of the codec codes.
            is_real: Boolean indicating whether the codes are real or generated.
        """
        if codes.numel() == 0:
            logging.warning("FCD: No valid codes to update, skipping update")
            return
        if codes.shape[1] != self.codec.num_codebooks:
            logging.warning(
                f"FCD: Number of codebooks mismatch: {codes.shape[1]} != {self.codec.num_codebooks}, skipping update"
            )
            return

        # Drop padding: keep only the valid frames of each batch element.
        valid_codes = [codes[idx, :, : codes_len[idx]] for idx in range(codes.shape[0])]

        # Each frame is treated independently, so all batch elements can be
        # concatenated along time into a single (B*T, C) tensor of frames.
        frames = torch.cat(valid_codes, dim=-1).permute(1, 0)
        if len(frames) == 0:
            logging.warning("FCD: No valid codes to update, skipping update")
            return

        super().update(frames, real=is_real)
        self.updated_since_last_reset = True

    def reset(self):
        """
        Resets the metric. Should be called after each compute.
        """
        super().reset()
        self.updated_since_last_reset = False

    def update_from_audio_file(self, audio_path: str, is_real: bool):
        """
        Updates the metric with codes representing a single audio file.
        Uses the codec to encode the audio file into codec codes and updates the metric.

        Args:
            audio_path: Path to the audio file.
            is_real: Boolean indicating whether the audio file is real or generated.
        """
        codes, codes_len = self._encode_audio_file(audio_path=audio_path)
        self.update(codes=codes, codes_len=codes_len, is_real=is_real)

    def compute(self) -> Tensor:
        """
        Computes the Frechet Distance between the real and generated codec frame distributions.
        """
        if not self.updated_since_last_reset:
            logging.warning("FCD: No updates since last reset, returning 0")
            return torch.tensor(0.0, device=self.device)
        fcd = super().compute()
        fcd_value = fcd.cpu().item()
        # Tolerate slightly negative results caused by numerical issues, but
        # treat anything clearly below zero as an error.
        min_allowed_fcd = -0.01
        if fcd_value < min_allowed_fcd:
            logging.warning(f"FCD value is negative: {fcd_value}")
            raise ValueError(f"FCD value is negative: {fcd_value}")
        # FCD should be non-negative, so clamp away numerical noise.
        return fcd.clamp(min=0)
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import nemo.collections.asr as nemo_asr
from nemo.collections.asr.metrics.wer import word_error_rate_detail
from nemo.collections.tts.metrics.frechet_codec_distance import FrechetCodecDistance
from nemo.utils import logging

# Optional import for UTMOSv2 (audio quality metric)
Expand Down Expand Up @@ -193,10 +194,17 @@ def evaluate(
sv_model_type="titanet",
asr_model_name="stt_en_conformer_transducer_large",
with_utmosv2=True,
with_fcd=True,
codec_model_path=None,
):
audio_file_lists = find_generated_audio_files(generated_audio_dir)
records = read_manifest(manifest_path)
assert len(audio_file_lists) == len(records)
if with_fcd:
if codec_model_path is None:
raise ValueError("codec_model_path is required when with_fcd is True")
codes_file_lists = find_generated_codec_files(generated_audio_dir)
assert len(codes_file_lists) == len(records)

device = "cuda"

Expand Down Expand Up @@ -225,14 +233,21 @@ def evaluate(
)
speaker_verification_model = speaker_verification_model.to(device)
speaker_verification_model.eval()
# The model `titanet_small` prints thousands of lines during initialization, so suppress logs temporarily

logging.info("Loading `titanet_small` model...")
speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
model_name='titanet_small'
)
# The model `titanet_small` prints thousands of lines during initialization, so suppress logs temporarily
with logging.temp_verbosity(logging.ERROR):
speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
model_name='titanet_small'
)
speaker_verification_model_alternate = speaker_verification_model_alternate.to(device)
speaker_verification_model_alternate.eval()

if with_fcd:
fcd_metric = FrechetCodecDistance(codec_name=codec_model_path).to(device)
else:
fcd_metric = None

if with_utmosv2:
if not UTMOSV2_AVAILABLE:
logging.warning(
Expand All @@ -253,6 +268,10 @@ def evaluate(
if context_audio_filepath is not None:
context_audio_filepath = os.path.join(audio_dir, context_audio_filepath)

# Update the FCD metric with real (ground truth) codes
if fcd_metric is not None:
fcd_metric.update_from_audio_file(gt_audio_filepath, True)

pred_audio_filepath = audio_file_lists[ridx]

if with_utmosv2 and UTMOSV2_AVAILABLE:
Expand Down Expand Up @@ -294,12 +313,18 @@ def evaluate(
logging.info(f"{ridx} GT Text: {gt_text}")
logging.info(f"{ridx} Pr Text: {pred_text}")
# Format cer and wer to 2 decimal places
logging.info("CER:", "{:.4f} | WER: {:.4f}".format(detailed_cer[0], detailed_wer[0]))
logging.info(f"CER: {detailed_cer[0]:.4f} | WER: {detailed_wer[0]:.4f}")

pred_texts.append(pred_text)
gt_texts.append(gt_text)
gt_audio_texts.append(gt_audio_text)

# Update FCD metric with generated codes
if fcd_metric is not None:
predicted_codes = torch.load(codes_file_lists[ridx]).unsqueeze(0) # B, C, T
predicted_codes_lens = torch.tensor([predicted_codes.size(-1)], dtype=torch.int, device=device)
fcd_metric.update(predicted_codes, predicted_codes_lens, False)

pred_context_ssim = 0.0
gt_context_ssim = 0.0
with torch.inference_mode():
Expand Down Expand Up @@ -377,6 +402,13 @@ def evaluate(
}
)

# compute frechet distance for the whole dataset
if fcd_metric is not None:
fcd = fcd_metric.compute().cpu().item()
fcd_metric.reset()
else:
fcd = float('nan')

filewise_metrics_keys_to_save = [
'cer',
'wer',
Expand All @@ -387,12 +419,11 @@ def evaluate(
'pred_audio_filepath',
'context_audio_filepath',
]
filtered_filewise_metrics = []
for m in filewise_metrics:
filtered_filewise_metrics.append({k: m[k] for k in filewise_metrics_keys_to_save})
# Filter filewise metrics to only keep only the metrics we want to save
filtered_filewise_metrics = [{k: m[k] for k in filewise_metrics_keys_to_save} for m in filewise_metrics]

# Sort filewise metrics by cer in reverse
filewise_metrics.sort(key=lambda x: x['cer'], reverse=True)
filtered_filewise_metrics.sort(key=lambda x: x['cer'], reverse=True)

avg_metrics = {}
avg_metrics['cer_filewise_avg'] = sum([m['detailed_cer'][0] for m in filewise_metrics]) / len(filewise_metrics)
Expand Down Expand Up @@ -423,9 +454,10 @@ def evaluate(
)[0]
avg_metrics["utmosv2_avg"] = sum([m['utmosv2'] for m in filewise_metrics]) / len(filewise_metrics)
avg_metrics["total_gen_audio_seconds"] = total_generated_audio_seconds
avg_metrics["frechet_codec_distance"] = fcd
pprint.pprint(avg_metrics)

return avg_metrics, filewise_metrics
return avg_metrics, filtered_filewise_metrics


def main():
Expand Down
Loading
Loading