diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index 2897b714889f..dffd38ec8473 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -101,6 +101,7 @@ def append_metrics_to_csv(csv_path: str, checkpoint_name: str, dataset: str, met metrics.get('wer_gt_audio_cumulative', ''), metrics.get('utmosv2_avg', ''), metrics.get('total_gen_audio_seconds', ''), + metrics.get('frechet_codec_distance', ''), ] with open(csv_path, "a") as f: f.write(",".join(str(v) for v in values) + "\n") @@ -203,7 +204,7 @@ def run_inference_and_evaluation( "wer_cumulative,ssim_pred_gt_avg,ssim_pred_context_avg,ssim_gt_context_avg," "ssim_pred_gt_avg_alternate,ssim_pred_context_avg_alternate," "ssim_gt_context_avg_alternate,cer_gt_audio_cumulative,wer_gt_audio_cumulative," - "utmosv2_avg,total_gen_audio_seconds" + "utmosv2_avg,total_gen_audio_seconds,frechet_codec_distance" ) for dataset in datasets: @@ -244,13 +245,14 @@ def run_inference_and_evaluation( f"Dataset length mismatch: {len(test_dataset)} vs {len(manifest_records)} manifest records" ) - rtf_metrics_list, _ = runner.run_inference_on_dataset( + rtf_metrics_list, _, codec_file_paths = runner.run_inference_on_dataset( dataset=test_dataset, output_dir=repeat_audio_dir, manifest_records=manifest_records, audio_base_dir=meta['audio_dir'], save_cross_attention_maps=True, save_context_audio=(repeat_idx == 0), # Only save context audio once + save_predicted_codes=eval_config.with_fcd, # Code files are only needed for FCD computation ) # Compute mean RTF metrics @@ -268,6 +270,8 @@ def run_inference_and_evaluation( asr_model_name=eval_config.asr_model_name, language=language, with_utmosv2=eval_config.with_utmosv2, + with_fcd=eval_config.with_fcd, + codec_model_path=eval_config.codec_model_path, ) metrics, filewise_metrics = evaluate_generated_audio_dir( @@ -294,6 +298,10 @@ def run_inference_and_evaluation( violin_path = Path(eval_dir) / 
f"{dataset}_violin_{repeat_idx}.png" create_violin_plot(filewise_metrics, violin_plot_metrics, violin_path) + # Delete temporary predicted codes files + for codec_file_path in codec_file_paths: + os.remove(codec_file_path) + if skip_evaluation or not metrics_all_repeats: continue @@ -511,6 +519,7 @@ def create_argument_parser() -> argparse.ArgumentParser: nargs='*', default=['cer', 'pred_context_ssim', 'utmosv2'], ) + eval_group.add_argument('--disable_fcd', action='store_true', help="Disable Frechet Codec Distance computation") # Quality targets (for CI/CD) target_group = parser.add_argument_group('Quality Targets') @@ -580,6 +589,8 @@ def main(): sv_model=args.sv_model, asr_model_name=args.asr_model_name, with_utmosv2=not args.disable_utmosv2, + with_fcd=not args.disable_fcd, + codec_model_path=args.codecmodel_path if not args.disable_fcd else None, ) cer, ssim = None, None diff --git a/nemo/collections/tts/metrics/frechet_codec_distance.py b/nemo/collections/tts/metrics/frechet_codec_distance.py new file mode 100644 index 000000000000..df1140bc5af3 --- /dev/null +++ b/nemo/collections/tts/metrics/frechet_codec_distance.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Tuple + +import numpy as np +import torch +from einops import rearrange +from torch import Tensor, nn +from torchmetrics.image.fid import FrechetInceptionDistance + +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.tts.models import AudioCodecModel +from nemo.utils import logging + + +class CodecEmbedder(nn.Module): + """ + Converts codec codes to dequantized codec embeddings. + The class implements the right API to be used as a custom feature extractor + provided to `torchmetrics.image.fid`. + """ + + def __init__(self, codec: AudioCodecModel): + super().__init__() + self.codec = codec + + def forward(self, x: Tensor) -> Tensor: + """ + Embeds a batch of audio codes into the codec's (dequantized) embedding space. + Each frame is treated independently. + + Args: + x: Audio codes tensor of shape (B*T, C) + + Returns: + Embeddings tensor of shape (B*T, D) + """ + # We treat all frames as one large batch element, since the codec requires (B, C, T) input and + # we don't have the per-batch-element lengths at this point due to FID API limitations + + # Construct a length tensor: one batch element, all frames. + x_len = torch.tensor(x.shape[0], device=x.device, dtype=torch.long).unsqueeze(0) # (1,) + tokens = x.permute(1, 0).unsqueeze(0) # 1, C, B*T + embeddings = self.codec.dequantize(tokens=tokens, tokens_len=x_len) # (B, D, T) + # we treat each time step as a separate example + embeddings = rearrange(embeddings, 'B D T -> (B T) D') + return embeddings + + @property + def num_features(self) -> int: + return self.codec.vector_quantizer.codebook_dim + + +class FrechetCodecDistance(FrechetInceptionDistance): + """ + A metric that measures the Frechet Distance between a collection of real and + generated codec frames. The distance is measured in the codec's embedding space, + i.e. the continuous vectors obtained by dequantizing the codec frames.
Each + multi-codebook frame is treated as a separate example. + + We subclass `torchmetrics.image.fid.FrechetInceptionDistance` and use the codec + embedder as a custom feature extractor. + """ + + def __init__(self, codec_name: str): + """ + Initializes the FrechetCodecDistance metric. + + Args: + codec_name: The name of the codec model to use. + Can be a local .nemo file or a HuggingFace or NGC model. + If the name ends with ".nemo", it is assumed to be a local .nemo file. + Otherwise, it should start with "nvidia/", and is assumed to be a HuggingFace or NGC model. + """ + if codec_name.endswith(".nemo"): + # Local .nemo file + codec = AudioCodecModel.restore_from(codec_name, strict=False) + elif codec_name.startswith("nvidia/"): + # Model on HuggingFace or NGC + codec = AudioCodecModel.from_pretrained(codec_name) + else: + raise ValueError( + f"Invalid codec name: {codec_name}. Must be a local .nemo file or a HuggingFace or NGC model name starting with 'nvidia/'" + ) + codec.eval() + feature = CodecEmbedder(codec) + super().__init__(feature=feature) + self.codec = codec + self.updated_since_last_reset = False + + def _encode_audio_file(self, audio_path: str) -> Tuple[Tensor, Tensor]: + """ + Encodes an audio file using the audio codec. + + Args: + audio_path: Path to the audio file. + + Returns: + Tuple of tensors containing the codec codes and the lengths of the codec codes. 
+ """ + audio_segment = AudioSegment.from_file(audio_path, target_sr=self.codec.sample_rate) + assert np.issubdtype(audio_segment.samples.dtype, np.floating) + audio_min = audio_segment.samples.min() + audio_max = audio_segment.samples.max() + eps = 0.01 # certain ways of normalizing audio can result in samples that are slightly outside of [-1, 1] + if audio_min < (-1.0 - eps) or audio_max > (1.0 + eps): + logging.warning(f"Audio samples are not normalized: min={audio_min}, max={audio_max}") + samples = torch.tensor(audio_segment.samples, device=self.codec.device).unsqueeze(0) + audio_len = torch.tensor(samples.shape[1], device=self.codec.device).unsqueeze(0) + codes, codes_len = self.codec.encode(audio=samples, audio_len=audio_len) + return codes, codes_len + + def update(self, codes: Tensor, codes_len: Tensor, is_real: bool): + """ + Updates the metric with a batch of codec frames. + + Args: + codes: Tensor of shape (B, C, T) containing the codec codes. + codes_len: Tensor of shape (B,) containing the lengths of the codec codes. + is_real: Boolean indicating whether the codes are real or generated. + """ + if codes.numel() == 0: + logging.warning("FCD: No valid codes to update, skipping update") + return + if codes.shape[1] != self.codec.num_codebooks: + logging.warning( + f"FCD: Number of codebooks mismatch: {codes.shape[1]} != {self.codec.num_codebooks}, skipping update" + ) + return + + # Keep only valid frames + codes_batch_all = [] + for batch_idx in range(codes.shape[0]): + codes_batch = codes[batch_idx, :, : codes_len[batch_idx]] # (C, T) + codes_batch_all.append(codes_batch) + + # Combine into a single tensor. We treat each frame independently so we can concatenate them all. 
+ codes_batch_all = torch.cat(codes_batch_all, dim=-1).permute(1, 0) # (B*T, C) + if len(codes_batch_all) == 0: + logging.warning("FCD: No valid codes to update, skipping update") + return + + # Update the metric + super().update(codes_batch_all, real=is_real) + self.updated_since_last_reset = True + + def reset(self): + """ + Resets the metric. Should be called after each compute. + """ + super().reset() + self.updated_since_last_reset = False + + def update_from_audio_file(self, audio_path: str, is_real: bool): + """ + Updates the metric with codes representing a single audio file. + Uses the codec to encode the audio file into codec codes and updates the metric. + + Args: + audio_path: Path to the audio file. + is_real: Boolean indicating whether the audio file is real or generated. + """ + codes, codes_len = self._encode_audio_file(audio_path=audio_path) + self.update(codes=codes, codes_len=codes_len, is_real=is_real) + + def compute(self) -> Tensor: + """ + Computes the Frechet Distance between the real and generated codec frame distributions. 
+ """ + if not self.updated_since_last_reset: + logging.warning("FCD: No updates since last reset, returning 0") + return torch.tensor(0.0, device=self.device) + fcd = super().compute() + min_allowed_fcd = -0.01 # a bit of tolerance for numerical issues + fcd_value = fcd.cpu().item() + if fcd_value < min_allowed_fcd: + logging.warning(f"FCD value is negative: {fcd_value}") + raise ValueError(f"FCD value is negative: {fcd_value}") + # FCD should be non-negative + fcd = fcd.clamp(min=0) + return fcd diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py index 324c0a2e6939..0a17406c182a 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluate_generated_audio.py @@ -31,6 +31,7 @@ import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate_detail +from nemo.collections.tts.metrics.frechet_codec_distance import FrechetCodecDistance from nemo.utils import logging # Optional import for UTMOSv2 (audio quality metric) @@ -193,10 +194,17 @@ def evaluate( sv_model_type="titanet", asr_model_name="stt_en_conformer_transducer_large", with_utmosv2=True, + with_fcd=True, + codec_model_path=None, ): audio_file_lists = find_generated_audio_files(generated_audio_dir) records = read_manifest(manifest_path) assert len(audio_file_lists) == len(records) + if with_fcd: + if codec_model_path is None: + raise ValueError("codec_model_path is required when with_fcd is True") + codes_file_lists = find_generated_codec_files(generated_audio_dir) + assert len(codes_file_lists) == len(records) device = "cuda" @@ -225,14 +233,21 @@ def evaluate( ) speaker_verification_model = speaker_verification_model.to(device) speaker_verification_model.eval() - # The model `titanet_small` prints thousands of lines during initialization, so suppress logs temporarily 
+ logging.info("Loading `titanet_small` model...") - speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( - model_name='titanet_small' - ) + # The model `titanet_small` prints thousands of lines during initialization, so suppress logs temporarily + with logging.temp_verbosity(logging.ERROR): + speaker_verification_model_alternate = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained( + model_name='titanet_small' + ) speaker_verification_model_alternate = speaker_verification_model_alternate.to(device) speaker_verification_model_alternate.eval() + if with_fcd: + fcd_metric = FrechetCodecDistance(codec_name=codec_model_path).to(device) + else: + fcd_metric = None + if with_utmosv2: if not UTMOSV2_AVAILABLE: logging.warning( @@ -253,6 +268,10 @@ def evaluate( if context_audio_filepath is not None: context_audio_filepath = os.path.join(audio_dir, context_audio_filepath) + # Update the FCD metric with real (ground truth) codes + if fcd_metric is not None: + fcd_metric.update_from_audio_file(gt_audio_filepath, True) + pred_audio_filepath = audio_file_lists[ridx] if with_utmosv2 and UTMOSV2_AVAILABLE: @@ -294,12 +313,18 @@ def evaluate( logging.info(f"{ridx} GT Text: {gt_text}") logging.info(f"{ridx} Pr Text: {pred_text}") # Format cer and wer to 2 decimal places - logging.info("CER:", "{:.4f} | WER: {:.4f}".format(detailed_cer[0], detailed_wer[0])) + logging.info(f"CER: {detailed_cer[0]:.4f} | WER: {detailed_wer[0]:.4f}") pred_texts.append(pred_text) gt_texts.append(gt_text) gt_audio_texts.append(gt_audio_text) + # Update FCD metric with generated codes + if fcd_metric is not None: + predicted_codes = torch.load(codes_file_lists[ridx]).unsqueeze(0) # B, C, T + predicted_codes_lens = torch.tensor([predicted_codes.size(-1)], dtype=torch.int, device=device) + fcd_metric.update(predicted_codes, predicted_codes_lens, False) + pred_context_ssim = 0.0 gt_context_ssim = 0.0 with torch.inference_mode(): @@ -377,6 +402,13 @@ def 
evaluate( } ) + # compute frechet distance for the whole dataset + if fcd_metric is not None: + fcd = fcd_metric.compute().cpu().item() + fcd_metric.reset() + else: + fcd = float('nan') + filewise_metrics_keys_to_save = [ 'cer', 'wer', @@ -387,12 +419,11 @@ def evaluate( 'pred_audio_filepath', 'context_audio_filepath', ] - filtered_filewise_metrics = [] - for m in filewise_metrics: - filtered_filewise_metrics.append({k: m[k] for k in filewise_metrics_keys_to_save}) + # Filter filewise metrics to only keep only the metrics we want to save + filtered_filewise_metrics = [{k: m[k] for k in filewise_metrics_keys_to_save} for m in filewise_metrics] # Sort filewise metrics by cer in reverse - filewise_metrics.sort(key=lambda x: x['cer'], reverse=True) + filtered_filewise_metrics.sort(key=lambda x: x['cer'], reverse=True) avg_metrics = {} avg_metrics['cer_filewise_avg'] = sum([m['detailed_cer'][0] for m in filewise_metrics]) / len(filewise_metrics) @@ -423,9 +454,10 @@ def evaluate( )[0] avg_metrics["utmosv2_avg"] = sum([m['utmosv2'] for m in filewise_metrics]) / len(filewise_metrics) avg_metrics["total_gen_audio_seconds"] = total_generated_audio_seconds + avg_metrics["frechet_codec_distance"] = fcd pprint.pprint(avg_metrics) - return avg_metrics, filewise_metrics + return avg_metrics, filtered_filewise_metrics def main(): diff --git a/nemo/collections/tts/modules/magpietts_inference/evaluation.py b/nemo/collections/tts/modules/magpietts_inference/evaluation.py index 3d3802061318..ff5440d0972b 100644 --- a/nemo/collections/tts/modules/magpietts_inference/evaluation.py +++ b/nemo/collections/tts/modules/magpietts_inference/evaluation.py @@ -39,12 +39,16 @@ class EvaluationConfig: asr_model_name: ASR model for transcription (e.g., "nvidia/parakeet-tdt-1.1b"). language: Language code for transcription (e.g., "en"). with_utmosv2: Whether to compute UTMOSv2 (Mean Opinion Score) metrics. + with_fcd: Whether to compute Frechet Codec Distance metric. 
+ codec_model_path: Path to the audio codec model. If None, will skip computing Frechet Codec Distance metric. """ sv_model: str = "titanet" asr_model_name: str = "nvidia/parakeet-tdt-1.1b" language: str = "en" with_utmosv2: bool = True + with_fcd: bool = True + codec_model_path: str = None def evaluate_generated_audio_dir( @@ -59,6 +63,7 @@ - ASR-based metrics: Character Error Rate (CER), Word Error Rate (WER) - Speaker similarity: Cosine similarity using speaker embeddings - Audio quality: UTMOSv2 scores (if enabled) + - Frechet Codec Distance (FCD) metric (if enabled) Args: manifest_path: Path to the evaluation manifest (NDJSON format). @@ -81,6 +86,8 @@ sv_model_type=config.sv_model, asr_model_name=config.asr_model_name, with_utmosv2=config.with_utmosv2, + with_fcd=config.with_fcd, + codec_model_path=config.codec_model_path, ) return avg_metrics, filewise_metrics @@ -141,6 +148,7 @@ def compute_mean_with_confidence_interval( 'wer_gt_audio_cumulative', 'utmosv2_avg', 'total_gen_audio_seconds', + 'frechet_codec_distance', ] # Default metrics to show in violin plots diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index f7809fdb7ded..fd165edb859d 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -307,6 +307,7 @@ def run_inference_on_dataset( audio_base_dir: Optional[str] = None, save_cross_attention_maps: bool = True, save_context_audio: bool = True, + save_predicted_codes: bool = True, ) -> Tuple[List[dict], List[str]]: """Run inference on a dataset. @@ -321,11 +322,13 @@ audio_base_dir: Base directory for audio paths (uses cached if None). save_cross_attention_maps: Whether to save attention map images. save_context_audio: Whether to copy context audio files.
+ save_predicted_codes: Whether to save predicted code files. Returns: Tuple of: - rtf_metrics: List of real-time factor metrics per batch. - generated_audio_paths: List of paths to generated audio files. + - codec_file_paths: List of paths to predicted codes files. """ # Use cached values if not provided if manifest_records is None: @@ -342,12 +345,18 @@ def run_inference_on_dataset( if self._use_longform: logging.info("Using longform inference path") return self._run_longform_inference( - dataset, output_dir, manifest_records, audio_base_dir, save_context_audio + dataset, output_dir, manifest_records, audio_base_dir, save_context_audio, save_predicted_codes ) else: logging.info("Using standard inference path") return self._run_standard_inference( - dataset, output_dir, manifest_records, audio_base_dir, save_cross_attention_maps, save_context_audio + dataset, + output_dir, + manifest_records, + audio_base_dir, + save_cross_attention_maps, + save_context_audio, + save_predicted_codes, ) def _run_standard_inference( @@ -358,6 +367,7 @@ def _run_standard_inference( audio_base_dir: str, save_cross_attention_maps: bool = True, save_context_audio: bool = True, + save_predicted_codes: bool = True, ) -> Tuple[List[dict], List[str]]: """Run standard single-pass inference on a dataset. @@ -368,11 +378,13 @@ def _run_standard_inference( audio_base_dir: Base directory for resolving audio paths. save_cross_attention_maps: Whether to save attention map images. save_context_audio: Whether to copy context audio files. + save_predicted_codes: Whether to save predicted code files. Returns: Tuple of: - rtf_metrics: List of real-time factor metrics per batch. - generated_audio_paths: List of paths to generated audio files. + - codec_file_paths: List of paths to predicted codes files. 
""" os.makedirs(output_dir, exist_ok=True) self._delete_old_generated_files(output_dir) @@ -388,6 +400,7 @@ def _run_standard_inference( item_idx = 0 all_rtf_metrics = [] generated_audio_paths = [] + codec_file_paths = [] for batch_idx, batch in enumerate(dataloader): logging.info(f"Processing batch {batch_idx + 1}/{len(dataloader)}") @@ -422,6 +435,8 @@ def _run_standard_inference( predicted_audio = output.predicted_audio predicted_audio_lens = output.predicted_audio_lens + predicted_codes = output.predicted_codes + predicted_codes_lens = output.predicted_codes_lens rtf_metrics = output.rtf_metrics cross_attention_maps = output.cross_attention_maps @@ -453,9 +468,14 @@ def _run_standard_inference( item_idx, ) + if save_predicted_codes: + codes_path = os.path.join(output_dir, f"predicted_codes_{item_idx}.pt") + predicted_codes_current = predicted_codes[idx, :, : predicted_codes_lens[idx]] # C, T + torch.save(predicted_codes_current, codes_path) + codec_file_paths.append(codes_path) item_idx += 1 - return all_rtf_metrics, generated_audio_paths + return all_rtf_metrics, generated_audio_paths, codec_file_paths @staticmethod def _batch_to_cuda(batch: dict) -> dict: @@ -581,7 +601,8 @@ def _run_longform_inference( manifest_records: List[dict], audio_base_dir: str, save_context_audio: bool = True, - ) -> Tuple[List[dict], List[str]]: + save_predicted_codes: bool = True, + ) -> Tuple[List[dict], List[str], List[str]]: """Run longform inference with automatic sentence chunking. Processes text sentence-by-sentence using generate_long_form_speech(). @@ -592,11 +613,13 @@ def _run_longform_inference( manifest_records: List of manifest record dictionaries. audio_base_dir: Base directory for resolving audio paths. save_context_audio: Whether to copy context audio files. + save_predicted_codes: Whether to save predicted code files. Returns: Tuple of: - rtf_metrics: List of real-time factor metrics per batch. - generated_audio_paths: List of paths to generated audio files. 
+ - codec_file_paths: List of paths to predicted codes files. """ os.makedirs(output_dir, exist_ok=True) self._delete_old_generated_files(output_dir) @@ -611,6 +634,7 @@ def _run_longform_inference( all_rtf_metrics = [] generated_audio_paths = [] + codec_file_paths = [] global_item_idx = 0 for batch_idx, batch in enumerate(dataloader): @@ -742,9 +766,15 @@ def _run_longform_inference( sample_idx, ) + if save_predicted_codes: + codes_path = os.path.join(output_dir, f"predicted_codes_{sample_idx}.pt") + predicted_codes_current = predicted_codes[b_idx, :, : predicted_codes_lens[b_idx]] # C, T + torch.save(predicted_codes_current, codes_path) + codec_file_paths.append(codes_path) + global_item_idx += 1 - return all_rtf_metrics, generated_audio_paths + return all_rtf_metrics, generated_audio_paths, codec_file_paths def _compute_end_of_text_flags( self, diff --git a/tests/collections/tts/metrics/__init__.py b/tests/collections/tts/metrics/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/collections/tts/metrics/test_frechet_codec_distance.py b/tests/collections/tts/metrics/test_frechet_codec_distance.py new file mode 100644 index 000000000000..d3a59072562a --- /dev/null +++ b/tests/collections/tts/metrics/test_frechet_codec_distance.py @@ -0,0 +1,126 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import torch + +from nemo.collections.tts.metrics.frechet_codec_distance import FrechetCodecDistance +from nemo.collections.tts.models import AudioCodecModel + + +class TestFrechetCodecDistance: + codec_name = "nvidia/low-frame-rate-speech-codec-22khz" + + @pytest.fixture + def device(self): + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + @pytest.fixture + def codec(self, device): + return AudioCodecModel.from_pretrained(self.codec_name).to(device) + + @pytest.fixture + def metric(self, codec, device): + return FrechetCodecDistance(codec_name=self.codec_name).to(device) + + @pytest.mark.unit + def test_same_distribution(self, metric, device, codec): + """Test that FCD is close to zero when comparing identical distributions.""" + B, C, T = 3, codec.num_codebooks, 20 + codes = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes_len = torch.randint(low=1, high=T, size=(B,), device=device) + + # Update with same codes for both real and fake + metric.update(codes, codes_len, is_real=True) + metric.update(codes, codes_len, is_real=False) + + eps = 0.01 + fcd = metric.compute() + assert fcd < eps and fcd >= 0, f"FCD value is {fcd} but should be close to 0" + metric.reset() + + @pytest.mark.unit + def test_different_distribution(self, metric, device, codec): + """Test that FCD is positive when comparing different distributions.""" + B, C, T = 3, codec.num_codebooks, 20 + + # Generate two different sets of codes + codes1 = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes2 = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes_len = torch.randint(low=1, high=T, size=(B,), device=device) + + metric.update(codes1, codes_len, is_real=True) + metric.update(codes2, codes_len, is_real=False) + + fcd = metric.compute() + assert fcd > 0, f"FCD value is {fcd} but should be positive for different distributions" +
metric.reset() + + @pytest.mark.filterwarnings("ignore:The.*compute.*method of metric.*was called before the.*update.*method") + @pytest.mark.unit + def test_empty_distribution(self, metric): + """Test that computing the FCD on empty distributions returns 0.""" + fcd = metric.compute() + assert fcd == 0.0 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.unit + def test_gpu_compatibility(self, metric, device, codec): + """Test that the metric works correctly on GPU.""" + assert metric.device.type == "cuda" + B, C, T = 3, codec.num_codebooks, 20 + codes = torch.randint(low=0, high=codec.codebook_size, size=(B, C, T), device=device) + codes_len = torch.randint(low=1, high=T, size=(B,), device=device) + + metric.update(codes, codes_len, is_real=True) + metric.update(codes, codes_len, is_real=False) + + fcd = metric.compute() + + eps = 0.01 + assert isinstance(fcd, torch.Tensor) + assert fcd.device.type == "cuda" + assert fcd < eps and fcd >= 0, f"FCD value is {fcd} but should be close to 0" + + @pytest.mark.unit + def test_update_from_audio_file(self, metric): + """Test the update_from_audio_file method.""" + + # Test with both "real" and "fake" audio files (different files) + metric.update_from_audio_file("tests/.data/tts/mini_ljspeech/wavs/LJ019-0373.wav", is_real=True) + metric.update_from_audio_file("tests/.data/tts/mini_ljspeech/wavs/LJ050-0234.wav", is_real=False) + + fcd = metric.compute() + assert isinstance(fcd, torch.Tensor) + assert fcd > 0, f"FCD value is {fcd} but should be positive given that we tested different audio files" + + @pytest.mark.unit + def test_empty_codes_update(self, metric, device): + """Test that the FCD metric doesn't crash when provided with empty codes.""" + B, C, T = 1, 0, 100 + codes = torch.ones(B, C, T, device=device) + codes_len = T * torch.ones(B, device=device) + # if it crashes PyTest will report it + metric.update(codes, codes_len, is_real=True) + + @pytest.mark.unit + def 
test_codebooks_mismatch_update(self, metric, device, codec): + """Test that the FCD metric doesn't crash when provided with incorrect number of codebooks.""" + B = 2 + C = codec.num_codebooks - 1 # intentionally missing one codebook + T = 10 + codes = torch.ones(B, C, T, device=device) + codes_len = T * torch.ones(B, device=device, dtype=torch.long) + # if it crashes PyTest will report it + metric.update(codes, codes_len, is_real=True)