diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py index 5ad368d92b0d..137d58be97b1 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -55,7 +55,7 @@ from nemo.collections.asr.parts.utils.speaker_utils import ( audio_rttm_map, get_uniqname_from_filepath, - timestamps_to_pyannote_object, + timestamps_to_supervisions, ) from nemo.collections.asr.parts.utils.transcribe_utils import read_and_maybe_sort_manifest from nemo.collections.asr.parts.utils.vad_utils import ( @@ -274,9 +274,9 @@ def convert_pred_mat_to_segments( bypass_postprocessing (bool, optional): if True, postprocessing will be bypassed. Defaults to False. Returns: - all_hypothesis (list): list of pyannote objects for each audio file. - all_reference (list): list of pyannote objects for each audio file. - all_uems (list): list of pyannote objects for each audio file. + all_hypothesis (list): list of (uniq_id, list[SupervisionSegment]) per audio file. + all_reference (list): list of (uniq_id, list[SupervisionSegment]) per audio file. + all_uems (list): list of (uniq_id, list[SupervisionSegment]) per audio file. """ all_hypothesis, all_reference, all_uems = [], [], [] cfg_vad_params = OmegaConf.structured(postprocessing_cfg) @@ -294,7 +294,7 @@ def convert_pred_mat_to_segments( uniq_id = audio_rttm_values["uniq_id"] else: uniq_id = get_uniqname_from_filepath(audio_rttm_values["audio_filepath"]) - all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object( + all_hypothesis, all_reference, all_uems = timestamps_to_supervisions( speaker_timestamps, uniq_id, audio_rttm_values, diff --git a/examples/voice_agent/environment.yaml b/examples/voice_agent/environment.yaml index 1ade2ff7fa6c..0cab180abb83 100644 --- a/examples/voice_agent/environment.yaml +++ b/examples/voice_agent/environment.yaml @@ -328,9 +328,6 @@ dependencies: - pulp==3.3.0 - pure-eval==0.2.3 - py-cpuinfo==9.0.0 - - pyannote-core==5.0.0 - - pyannote-database==5.1.3 - - pyannote-metrics==3.2.1 - pyarrow==21.0.0 - pybase64==1.4.2 - pybind11==3.0.1 diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index 86342a8b6f81..ad56a04ccbd2 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -13,25 +13,241 @@ # limitations under the License. 
from itertools import permutations -from typing import Dict, List, Optional, Tuple +from typing import IO, Any, Dict, Iterable, List, Optional, Tuple import editdistance import numpy as np -import pandas as pd -from pyannote.core import Segment, Timeline -from pyannote.metrics.diarization import DiarizationErrorRate +from lhotse import SupervisionSegment from scipy.optimize import linear_sum_assignment as scipy_linear_sum_assignment +from nemo.collections.asr.metrics.md_eval import ( + EPSILON, + NOEVAL_SD, + DiarizationErrorResult, + SpeakerMap, + SpeakerOverlap, + _annotation_to_rttm_data, + _iter_annotation_segments, + _labels_to_rttm_data, + _merge_rttm_dicts, + _merge_uem_dicts, + _uem_list_to_uem_data, + add_exclusion_zones_to_uem, + create_speaker_segs, + evaluate, + get_uem_data, + map_speakers, + uem_from_rttm, +) from nemo.utils import logging __all__ = [ + 'get_partial_ref_labels', + 'get_online_DER_stats', 'score_labels', + 'evaluate_der', + 'score_labels_from_rttm_labels', 'calculate_session_cpWER', 'calculate_session_cpWER_bruteforce', 'concat_perm_word_error_rate', + # Lhotse-backed annotation/segment/timeline helpers. + 'make_diar_segment', + 'make_diar_annotation', + 'make_uem_timeline', + 'unique_speakers', + 'write_supervisions_to_rttm', ] +# ─── Lhotse-backed annotation helpers ────────────────────────────────────── +# +# NeMo's diarization code uses lhotse's ``SupervisionSegment`` / +# ``SupervisionSet`` (already a hard dependency via ``requirements_asr.txt``) +# as the carrier type for diarization annotations and UEM timelines. The +# helpers below provide a small adapter layer used throughout the DER +# pipeline: +# +# * ``make_diar_segment`` — build a single ``SupervisionSegment`` +# ``(start, end, speaker)``. +# * ``make_diar_annotation`` — build an annotation as a list of +# ``SupervisionSegment`` from ``"start end speaker"`` label strings. +# * ``make_uem_timeline`` — build a UEM (evaluation regions) timeline +# as a list of ``SupervisionSegment`` with ``speaker="UEM"`` so all +# annotation-like objects are uniform. +# * ``unique_speakers`` — return the unique speaker labels in an +# annotation. +# * ``write_supervisions_to_rttm`` — serialize an annotation to RTTM lines. +# +# Downstream consumers (``md_eval._annotation_to_rttm_data``) duck-type the +# input so any iterable of objects exposing ``start``/``end`` (or +# ``duration``) and ``speaker`` is accepted. + + +_DIAR_RECORDING_ID_PLACEHOLDER = "__diar__" + + +def make_diar_segment( + start: float, + end: float, + speaker: str, + recording_id: str = _DIAR_RECORDING_ID_PLACEHOLDER, + segment_id: Optional[str] = None, +) -> SupervisionSegment: + """Build a single diarization segment as a lhotse ``SupervisionSegment``. + + Args: + start: Segment start time in seconds. + end: Segment end time in seconds. + speaker: Speaker label. + recording_id: Recording (file) identifier (analogous to a recording + URI). Defaults to a placeholder when the caller does not yet + know the file id. + segment_id: Optional unique segment id; auto-generated when omitted. + + Returns: + A ``SupervisionSegment`` with ``start``, ``duration``, and + ``speaker`` populated. 
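+
+    A minimal, illustrative sketch (the recording id, speaker label, and times
+    below are made up)::
+
+        >>> seg = make_diar_segment(0.5, 2.0, "spk0", recording_id="rec1")
+        >>> (seg.start, seg.duration, seg.speaker)
+        (0.5, 1.5, 'spk0')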
+ """ + duration = max(0.0, float(end) - float(start)) + if segment_id is None: + segment_id = f"{recording_id}-{float(start):.6f}-{float(end):.6f}-{speaker}" + return SupervisionSegment( + id=segment_id, + recording_id=recording_id, + start=float(start), + duration=duration, + speaker=str(speaker), + ) + + +def make_diar_annotation( + labels: Iterable[str], + uniq_name: str = "", +) -> List[SupervisionSegment]: + """Build a diarization annotation from ``"start end speaker"`` label strings. + + Returns a list of ``SupervisionSegment`` accepted by every NeMo DER + helper. + + Args: + labels: Iterable of label strings, each formatted as + ``"start end speaker"``. + uniq_name: Recording / file identifier (used as the recording id of + each emitted supervision). + + Returns: + List of ``SupervisionSegment`` objects, one per label line. + """ + recording_id = uniq_name or _DIAR_RECORDING_ID_PLACEHOLDER + segments: List[SupervisionSegment] = [] + for idx, label in enumerate(labels): + parts = label.strip().split() + if len(parts) < 3: + continue + start, end, speaker = parts[0], parts[1], parts[2] + segments.append( + make_diar_segment( + start=float(start), + end=float(end), + speaker=speaker, + recording_id=recording_id, + segment_id=f"{recording_id}-{idx}", + ) + ) + return segments + + +def make_uem_timeline( + uem_lines: Iterable[Iterable[float]], + uniq_id: str, +) -> List[SupervisionSegment]: + """Build a UEM (evaluation region) timeline as a list of supervisions. + + Each region is represented as a ``SupervisionSegment`` with + ``speaker="UEM"`` so the same iteration patterns used for annotations + also work for UEMs. + + Args: + uem_lines: Iterable of ``[start, end]`` pairs in seconds. + uniq_id: Recording / file identifier. + + Returns: + List of ``SupervisionSegment`` objects representing the evaluation + regions for ``uniq_id``. + """ + segments: List[SupervisionSegment] = [] + for idx, span in enumerate(uem_lines): + span_list = list(span) + if len(span_list) < 2: + continue + start, end = float(span_list[0]), float(span_list[1]) + segments.append( + SupervisionSegment( + id=f"{uniq_id}-uem-{idx}", + recording_id=uniq_id, + start=start, + duration=max(0.0, end - start), + speaker="UEM", + ) + ) + return segments + + +def unique_speakers(annotation: Any) -> List[str]: + """Return the unique speaker labels in an annotation-like object. + + Accepts the lhotse-based annotation objects used throughout NeMo's DER + pipeline (list of ``SupervisionSegment`` / ``SupervisionSet``) as well as + any iterable whose items expose ``start`` / ``end`` (or ``duration``) and + ``speaker``. If the input exposes a ``.labels()`` method it is used + directly. + """ + if hasattr(annotation, "labels") and not isinstance(annotation, (list, tuple)): + try: + return list(annotation.labels()) + except TypeError: + pass + seen: List[str] = [] + seen_set = set() + for _start, _end, speaker in _iter_annotation_segments(annotation): + if speaker not in seen_set: + seen.append(speaker) + seen_set.add(speaker) + return seen + + +def write_supervisions_to_rttm( + annotation: Any, + file_handle: IO[str], + recording_id: Optional[str] = None, + channel: int = 1, +) -> None: + """Write an annotation-like object to ``file_handle`` in NIST RTTM format. + + Args: + annotation: Iterable of ``SupervisionSegment`` (or any object + accepted by :func:`md_eval._iter_annotation_segments`). + file_handle: An open text file handle. + recording_id: Recording identifier emitted in the second RTTM + column. 
When omitted, the ``recording_id`` of the first + supervision (or an empty string) is used. + channel: Channel id (1-indexed); RTTM convention defaults to ``1``. + """ + if recording_id is None: + first = next(iter(annotation), None) + recording_id = getattr(first, "recording_id", "") if first is not None else "" + + for start, end, speaker in _iter_annotation_segments(annotation): + duration = end - start + if duration <= 0: + continue + file_handle.write( + "SPEAKER {rid} {chnl} {start:.3f} {dur:.3f} {spk} \n".format( + rid=recording_id, chnl=channel, start=start, dur=duration, spk=speaker + ) + ) + + def get_partial_ref_labels(pred_labels: List[str], ref_labels: List[str]) -> List[str]: """ For evaluation of online diarization performance, generate partial reference labels @@ -109,22 +325,109 @@ def get_online_DER_stats( return der_dict, der_stat_dict -def uem_timeline_from_file(uem_file, uniq_name=''): - """ - Generate pyannote timeline segments for uem file +def _build_mapping_from_data( + ref_data: Dict, + sys_data: Dict, + uem_data: Optional[Dict], +) -> Dict[str, SpeakerMap]: + """Build per-file optimal speaker mappings from parsed RTTM data.""" + mapping_dict: Dict[str, SpeakerMap] = {} + for file_id in sorted(ref_data.keys()): + for chnl in sorted(ref_data[file_id].keys()): + ref_spkr_data = ref_data[file_id][chnl].get("SPEAKER") + sys_spkr_data = sys_data.get(file_id, {}).get(chnl, {}).get("SPEAKER", {}) + if not ref_spkr_data: + continue + + for segs in sys_spkr_data.values(): + for seg in segs: + seg.setdefault("RTBEG", seg["TBEG"]) + seg.setdefault("RTEND", seg["TEND"]) + seg["RTDUR"] = seg["RTEND"] - seg["RTBEG"] + seg["RTMID"] = seg["RTBEG"] + seg["RTDUR"] / 2.0 + + ref_rttm = ref_data[file_id][chnl].get("RTTM", []) + uem = uem_data.get(file_id, {}).get(chnl) if uem_data else None + if uem is None: + uem = uem_from_rttm(ref_rttm) + + uem_sd_eval = add_exclusion_zones_to_uem(NOEVAL_SD, uem, ref_rttm) + if not uem_sd_eval: + uem_sd_eval = uem + + eval_segs = create_speaker_segs(uem_sd_eval, ref_spkr_data, sys_spkr_data) + spkr_overlap: SpeakerOverlap = {} + for seg in eval_segs: + if not seg["REF"]: + continue + for rs in seg["REF"]: + for ss in seg["SYS"]: + spkr_overlap.setdefault(rs, {})[ss] = spkr_overlap.get(rs, {}).get(ss, 0.0) + seg["TDUR"] + mapping_dict[file_id] = map_speakers(spkr_overlap) + return mapping_dict + + +def _extract_errors(cum: Dict) -> Tuple[float, float, float, float]: + """Extract (DER, CER, FA, MISS) from aggregate stats.""" + scored = cum.get("SCORED_SPEAKER", 0.0) or EPSILON + if scored <= EPSILON: + raise ValueError("Total evaluation time is 0. Abort.") + missed = cum.get("MISSED_SPEAKER", 0.0) + falarm = cum.get("FALARM_SPEAKER", 0.0) + error = cum.get("SPEAKER_ERROR", 0.0) + DER = (missed + falarm + error) / scored + CER = error / scored + FA = falarm / scored + MISS = missed / scored + return DER, CER, FA, MISS + + +def _default_uem_from_ref_sys( + ref_data: Dict[str, Dict[str, Dict[str, Any]]], + sys_data: Dict[str, Dict[str, Dict[str, Any]]], +) -> Dict[str, Dict[str, List[Dict[str, float]]]]: + """Auto-derive a UEM that spans the union of reference and system extents. + + NeMo's DER wrappers historically delegated to an external scoring engine + that, when no UEM was provided, built its evaluation map from the union of + reference and hypothesis extents. 
Matching that convention here keeps DER + numbers reported by NeMo consistent with previously published results + (any system time extending past the last reference segment is correctly + counted as false alarm). + + The underlying ``md_eval.evaluate`` function defaults to a stricter + NIST ``md-eval-22.pl`` behaviour (eval map = reference extent only). This + helper bridges the two by constructing an explicit single-segment UEM + per ``(file_id, channel)`` pair that covers + ``[min(ref ∪ sys TBEG), max(ref ∪ sys TEND)]`` and passing it down so + ``evaluate`` uses it verbatim. - file format - UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME - """ - timeline = Timeline(uri=uniq_name) - with open(uem_file, 'r', encoding='utf-8') as f: - lines = f.readlines() - for line in lines: - line = line.strip() - _, _, start_time, end_time = line.split() - timeline.add(Segment(float(start_time), float(end_time))) + Args: + ref_data: Merged reference RTTM data dict + (``{file_id: {chnl: {"RTTM": [...], "SPEAKER": {...}, ...}}}``). + sys_data: Merged system RTTM data dict, same shape as ``ref_data``. - return timeline + Returns: + UEM data dict ``{file_id: {chnl: [{"TBEG": ..., "TEND": ...}]}}`` + with one segment per ``(file_id, chnl)`` pair found in either input. + Empty when both inputs are empty. + """ + valid_types = {"SEGMENT", "SPEAKER", "SU", "EDIT", "FILLER", "IP", "CB", "A/P", "LEXEME", "NON-LEX"} + file_ids = set(ref_data.keys()) | set(sys_data.keys()) + uem_data: Dict[str, Dict[str, List[Dict[str, float]]]] = {} + for file_id in file_ids: + chnls = set(ref_data.get(file_id, {}).keys()) | set(sys_data.get(file_id, {}).keys()) + for chnl in chnls: + tbeg, tend = float("inf"), float("-inf") + for src in (ref_data, sys_data): + rttm = src.get(file_id, {}).get(chnl, {}).get("RTTM", []) + for tok in rttm: + if tok.get("TYPE") in valid_types: + tbeg = min(tbeg, tok["TBEG"]) + tend = max(tend, tok["TEND"]) + if tend > tbeg: + uem_data.setdefault(file_id, {})[chnl] = [{"TBEG": tbeg, "TEND": tend}] + return uem_data def score_labels( @@ -135,92 +438,121 @@ def score_labels( collar: float = 0.25, ignore_overlap: bool = True, verbose: bool = True, -) -> Optional[Tuple[DiarizationErrorRate, Dict]]: +) -> Optional[Tuple[DiarizationErrorResult, Dict]]: """ - Calculate DER, CER, FA and MISS rate from hypotheses and references. Hypothesis results are - coming from Pyannote-formatted speaker diarization results and References are coming from - Pyannote-formatted RTTM data. + Calculate DER, CER, FA and MISS rate from hypotheses and references. Hypothesis and + reference annotations are lists of :class:`lhotse.SupervisionSegment` (typically + produced by :func:`labels_to_supervisions` from RTTM label strings). + + Internally uses the md-eval engine (a Python port of NIST md-eval-22.pl) + for DER computation. Args: AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath - all_reference (list[uniq_name,Annotation]): reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation + all_reference (list[uniq_name, list[SupervisionSegment]]): reference annotations + for score calculation. + all_hypothesis (list[uniq_name, list[SupervisionSegment]]): hypothesis annotations + for score calculation. all_uem (list[list[float]]): List of UEM segments for each audio file. 
If UEM file is not provided, it will be read from manifestpath - collar (float): Length of collar (in seconds) for diarization error rate calculation + collar (float): No-score collar **half-width** in seconds, following NIST + ``md-eval-22.pl`` semantics. The total no-score zone around every + reference boundary is ``2 * collar`` seconds. This matches the + historical NeMo public contract. + + Note on cross-implementation parity: some external annotation + libraries define their ``collar`` argument as the **total** width + of the no-score zone (i.e., they use ``collar / 2`` on each side). + To reproduce a NeMo result with such a library, pass ``2 * collar`` + there. For example, NeMo's ``collar=0.05`` is equivalent to those + libraries' ``collar=0.10``. ignore_overlap (bool): If True, overlapping segments in reference and hypothesis will be ignored verbose (bool): If True, warning messages will be printed Returns: - metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. - This object contains detailed scores of each audiofile. + metric (DiarizationErrorResult): Diarization Error Rate metric object. + This object contains detailed scores of each audiofile. mapping (dict): Mapping dict containing the mapping speaker label for each audio input itemized_errors (tuple): Tuple containing (DER, CER, FA, MISS) for each audio file. - DER: Diarization Error Rate, which is sum of all three errors, CER + FA + MISS. - CER: Confusion Error Rate, which is sum of all errors - FA: False Alarm Rate, which is the number of false alarm segments - MISS: Missed Detection Rate, which is the number of missed detection segments - - < Caveat > - Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of - "no score" collar from left to right. Therefore, if 0.25s is applied for "no score" - collar in md-eval.pl, 0.5s should be applied for pyannote.metrics. """ - metric = None - if len(all_reference) == len(all_hypothesis): - metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap) - - mapping_dict, correct_spk_count = {}, 0 - for idx, (reference, hypothesis) in enumerate(zip(all_reference, all_hypothesis)): - ref_key, ref_labels = reference - _, hyp_labels = hypothesis - if len(ref_labels.labels()) == len(hyp_labels.labels()): - correct_spk_count += 1 - if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): - logging.info( - f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, " - f"Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" - ) - uem_obj = None - if all_uem is not None: - metric(ref_labels, hyp_labels, uem=all_uem[idx], detailed=True) - elif AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) is not None: - uem_file = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) - uem_obj = uem_timeline_from_file(uem_file=uem_file, uniq_name=ref_key) - metric(ref_labels, hyp_labels, uem=uem_obj, detailed=True) - else: - metric(ref_labels, hyp_labels, detailed=True) - mapping_dict[ref_key] = metric.optimal_mapping(ref_labels, hyp_labels) - - spk_count_acc = correct_spk_count / len(all_reference) - DER = abs(metric) - if metric['total'] == 0: - raise ValueError("Total evaluation time is 0. 
Abort.") - CER = metric['confusion'] / metric['total'] - FA = metric['false alarm'] / metric['total'] - MISS = metric['missed detection'] / metric['total'] - - itemized_errors = (DER, CER, FA, MISS) - + if len(all_reference) != len(all_hypothesis): if verbose: - pd.set_option('display.max_rows', None) # Show all rows - pd.set_option('display.max_columns', None) # Show all columns - pd.set_option('display.width', None) # Adjust width to avoid line wrapping - pd.set_option('display.max_colwidth', None) # Show full content of each cell - logging.info(f"\n{metric.report()}") - logging.info( - f"Cumulative Results for collar {collar} sec and ignore_overlap {ignore_overlap}: \n" - f"| FA: {FA:.4f} | MISS: {MISS:.4f} | CER: {CER:.4f} | DER: {DER:.4f} | " - f"Spk. Count Acc. {spk_count_acc:.4f}\n" - ) - - return metric, mapping_dict, itemized_errors - elif verbose: - logging.warning( - "Check if each ground truth RTTMs were present in the provided manifest file. " - "Skipping calculation of Diariazation Error Rate" - ) - return None + logging.warning( + "Check if each ground truth RTTMs were present in the provided manifest file. " + "Skipping calculation of Diarization Error Rate" + ) + return None + + ref_dicts = [] + sys_dicts = [] + uem_dicts = [] + correct_spk_count = 0 + + for idx, (reference, hypothesis) in enumerate(zip(all_reference, all_hypothesis)): + ref_key, ref_labels = reference + _, hyp_labels = hypothesis + + ref_n_spk = len(unique_speakers(ref_labels)) + hyp_n_spk = len(unique_speakers(hyp_labels)) + if ref_n_spk == hyp_n_spk: + correct_spk_count += 1 + if verbose and ref_n_spk != hyp_n_spk: + logging.info(f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, " f"Ref: {ref_n_spk}, Hyp: {hyp_n_spk}") + + ref_dicts.append(_annotation_to_rttm_data(ref_key, ref_labels)) + sys_dicts.append(_annotation_to_rttm_data(ref_key, hyp_labels)) + + if all_uem is not None: + uem_obj = all_uem[idx] + uem_segs = [[seg.start, seg.end] for seg in uem_obj] + uem_dicts.append(_uem_list_to_uem_data(ref_key, uem_segs)) + elif AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) is not None: + uem_file_data = get_uem_data(AUDIO_RTTM_MAP[ref_key]['uem_filepath']) + if uem_file_data: + uem_dicts.append(uem_file_data) + + ref_data = _merge_rttm_dicts(ref_dicts) + sys_data = _merge_rttm_dicts(sys_dicts) + uem_data = _merge_uem_dicts(uem_dicts) if uem_dicts else None + if uem_data is None: + uem_data = _default_uem_from_ref_sys(ref_data, sys_data) + + all_scores, cum = evaluate( + ref_data, + sys_data, + uem_data=uem_data, + collar=collar, + opt_1=ignore_overlap, + verbose=False, + ) + + mapping_dict = _build_mapping_from_data(ref_data, sys_data, uem_data) + + DER, CER, FA, MISS = _extract_errors(cum) + itemized_errors = (DER, CER, FA, MISS) + spk_count_acc = correct_spk_count / len(all_reference) + + metric = DiarizationErrorResult( + all_scores=all_scores, + cum=cum, + mapping_dict=mapping_dict, + collar=collar, + ignore_overlap=ignore_overlap, + ) + + if verbose: + logging.info(f"\n{metric.report()}") + logging.info( + f"Cumulative Results for collar {collar} sec and ignore_overlap {ignore_overlap}: \n" + f"| FA: {FA:.4f} | MISS: {MISS:.4f} | CER: {CER:.4f} | DER: {DER:.4f} | " + f"Spk. Count Acc. 
{spk_count_acc:.4f}\n" + ) + + return metric, mapping_dict, itemized_errors def evaluate_der(audio_rttm_map_dict, all_reference, all_hypothesis, diar_eval_mode='all'): @@ -271,6 +603,99 @@ def evaluate_der(audio_rttm_map_dict, all_reference, all_hypothesis, diar_eval_m return diar_score +def score_labels_from_rttm_labels( + ref_labels_list: List[Tuple[str, List[str]]], + hyp_labels_list: List[Tuple[str, List[str]]], + uem_segments_list: Optional[List[Tuple[str, List[List[float]]]]] = None, + collar: float = 0.25, + ignore_overlap: bool = True, + verbose: bool = True, +) -> Optional[Tuple[DiarizationErrorResult, Dict[str, SpeakerMap], Tuple[float, float, float, float]]]: + """Score diarization directly from plain ``"start end speaker"`` label strings. + + Convenience function for callers that have labels as ``"start end speaker"`` + strings rather than pre-built supervision lists. + + Args: + ref_labels_list: List of ``(uniq_id, [label_strings])`` reference labels. + hyp_labels_list: List of ``(uniq_id, [label_strings])`` hypothesis labels. + uem_segments_list: Optional list of ``(uniq_id, [[start, end], ...])`` UEM segments. + collar: No-score collar **half-width** in seconds, following NIST + ``md-eval-22.pl`` semantics. The total no-score zone around every + reference boundary is ``2 * collar`` seconds. To reproduce a NeMo + result with an external annotation library that defines its + ``collar`` argument as the **total** width of the no-score zone, + pass ``2 * collar`` there. For example, NeMo's ``collar=0.05`` is + equivalent to those libraries' ``collar=0.10``. + ignore_overlap: If ``True``, restrict scoring to single-speaker regions. + verbose: If ``True``, log detailed results. + + Returns: + Same format as :func:`score_labels`, or ``None`` if counts don't match. + """ + if len(ref_labels_list) != len(hyp_labels_list): + if verbose: + logging.warning( + "Reference and hypothesis label lists must have the same length. 
" + "Skipping calculation of Diarization Error Rate" + ) + return None + + ref_dicts = [_labels_to_rttm_data(uid, labels) for uid, labels in ref_labels_list] + sys_dicts = [_labels_to_rttm_data(uid, labels) for uid, labels in hyp_labels_list] + + uem_dicts = [] + if uem_segments_list: + for uid, segs in uem_segments_list: + uem_dicts.append(_uem_list_to_uem_data(uid, segs)) + + ref_data = _merge_rttm_dicts(ref_dicts) + sys_data = _merge_rttm_dicts(sys_dicts) + uem_data = _merge_uem_dicts(uem_dicts) if uem_dicts else None + if uem_data is None: + uem_data = _default_uem_from_ref_sys(ref_data, sys_data) + + all_scores, cum = evaluate( + ref_data, + sys_data, + uem_data=uem_data, + collar=collar, + opt_1=ignore_overlap, + verbose=False, + ) + + mapping_dict = _build_mapping_from_data(ref_data, sys_data, uem_data) + + DER, CER, FA, MISS = _extract_errors(cum) + itemized_errors = (DER, CER, FA, MISS) + + correct_spk_count = 0 + for (_, ref_labels), (_, hyp_labels) in zip(ref_labels_list, hyp_labels_list): + ref_spkrs = {lbl.strip().split()[2] for lbl in ref_labels} + hyp_spkrs = {lbl.strip().split()[2] for lbl in hyp_labels} + if len(ref_spkrs) == len(hyp_spkrs): + correct_spk_count += 1 + spk_count_acc = correct_spk_count / len(ref_labels_list) + + metric = DiarizationErrorResult( + all_scores=all_scores, + cum=cum, + mapping_dict=mapping_dict, + collar=collar, + ignore_overlap=ignore_overlap, + ) + + if verbose: + logging.info(f"\n{metric.report()}") + logging.info( + f"Cumulative Results for collar {collar} sec and ignore_overlap {ignore_overlap}: \n" + f"| FA: {FA:.4f} | MISS: {MISS:.4f} | CER: {CER:.4f} | DER: {DER:.4f} | " + f"Spk. Count Acc. {spk_count_acc:.4f}\n" + ) + + return metric, mapping_dict, itemized_errors + + def calculate_session_cpWER_bruteforce(spk_hypothesis: List[str], spk_reference: List[str]) -> Tuple[float, str, str]: """ Calculate cpWER with brute-force permutation search. Matches MeetEval's cpWER algorithm: diff --git a/nemo/collections/asr/metrics/md_eval.py b/nemo/collections/asr/metrics/md_eval.py new file mode 100644 index 000000000000..a94c914f3464 --- /dev/null +++ b/nemo/collections/asr/metrics/md_eval.py @@ -0,0 +1,1394 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is mostly a Python port of the NIST 'md-eval' (version 22) +Perl script originally found in the dscore repository: +https://github.com/nryant/dscore/blob/master/scorelib/md-eval-22.pl + +Original Author: NIST (National Institute of Standards and Technology) +Ported by: [Your Name/Organization] + +Bipartite speaker matching uses ``scipy.optimize.linear_sum_assignment`` +(Hungarian algorithm, minimisation) in place of the Perl ``weighted_bipartite_graph_match``, +with negated overlap costs so that maximising total overlap becomes a minimisation problem. 
+ +Data-flow is a direct translation of the Perl: + ``get_rttm_data`` → ``evaluate`` → ``score_speaker_diarization`` + → ``create_speaker_segs`` (builds overlap matrix) + → ``map_speakers`` (``linear_sum_assignment``) + → ``score_speaker_segments`` (tallies DER components) + → ``format_sd_scores`` +""" +# ============================================================================== +# ORIGINAL BSD-2-CLAUSE LICENSE NOTICE (from dscore) +# ============================================================================== +# Copyright (c) 2017, Neville Ryant +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# ============================================================================== + +import re +import warnings +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import numpy as np +from scipy.optimize import linear_sum_assignment + +from nemo.utils import logging + +__all__ = [ + 'EPSILON', + 'DEFAULT_COLLAR', + 'DEFAULT_EXTEND', + 'NOEVAL_SD', + 'NOSCORE_SD', + 'RTTM_DATATYPES', + 'get_rttm_data', + 'get_uem_data', + 'uem_from_rttm', + 'add_exclusion_zones_to_uem', + 'add_collars_to_uem', + 'exclude_overlapping_speech_from_uem', + 'create_speaker_segs', + 'map_speakers', + 'score_speaker_diarization', + 'evaluate', + 'DiarizationErrorResult', +] + +# ─── Type aliases ────────────────────────────────────────────────────────── +Token = Dict[str, Any] +UEMSegment = Dict[str, float] +SpeakerData = Dict[str, List[Token]] +SpeakerOverlap = Dict[str, Dict[str, float]] +SpeakerMap = Dict[str, str] +ScoreSegment = Dict[str, Any] +SDStats = Dict[str, Any] + +# ─── Constants ───────────────────────────────────────────────────────────── + +EPSILON: float = 1e-8 +MISS_NAME: str = " MISS" +FA_NAME: str = " FALSE ALARM" +DEFAULT_COLLAR: float = 0.00 +DEFAULT_EXTEND: float = 0.50 + +NOEVAL_SD: Dict[str, Dict[str, int]] = { + "NOSCORE": {"": 1}, +} +NOSCORE_SD: Dict[str, Dict[str, int]] = { + "NOSCORE": {"": 1}, + "NON-LEX": {"laugh": 1, "breath": 1, "lipsmack": 1, "cough": 1, "sneeze": 1, "other": 1}, +} + +RTTM_DATATYPES: Dict[str, Dict[str, int]] = { + "SEGMENT": {"eval": 1, "": 1}, + "NOSCORE": {"": 1}, + "NO_RT_METADATA": {"": 1}, + "LEXEME": { + "lex": 1, + "fp": 1, + "frag": 1, + "un-lex": 1, + "for-lex": 1, + "alpha": 1, + "acronym": 1, + "interjection": 1, + "propernoun": 1, + "other": 1, + }, + "NON-LEX": {"laugh": 1, "breath": 1, "lipsmack": 1, "cough": 1, "sneeze": 1, "other": 1}, + "NON-SPEECH": {"noise": 1, "music": 1, "other": 1}, + "FILLER": { + "filled_pause": 1, + "discourse_marker": 1, + "discourse_response": 1, + "explicit_editing_term": 1, + "other": 1, + }, + "EDIT": {"repetition": 1, "restart": 1, "revision": 1, "simple": 1, "complex": 1, "other": 1}, + "IP": {"edit": 1, "filler": 1, "edit&filler": 1, "other": 1}, + "SU": { + "statement": 1, + "backchannel": 1, + "question": 1, + "incomplete": 1, + "unannotated": 1, + "other": 1, + }, + "CB": {"coordinating": 1, "clausal": 1, "other": 1}, + "A/P": {"": 1}, + "SPEAKER": {"": 1}, + "SPKR-INFO": {"adult_male": 1, "adult_female": 1, "child": 1, "unknown": 1}, +} + + +# ─── RTTM / UEM parsing ─────────────────────────────────────────────────── + + +def get_rttm_data(data: Dict[str, Dict[str, Dict[str, Any]]], rttm_file: Optional[str]) -> None: + """Parse one RTTM file into a nested data dictionary (in-place). + + The resulting structure is:: + + data[file_id][chnl]['SPEAKER'][spkr] = [token, ...] + data[file_id][chnl]['RTTM'] = [token, ...] (all non-SPKR-INFO) + data[file_id][chnl]['SPKR-INFO'][spkr]= {'GENDER': str} + data[file_id][chnl]['LEXEME'] = [token, ...] + + Args: + data: Mutable dictionary to populate with parsed RTTM data. + rttm_file: Path to the RTTM file. If ``None``, the function returns immediately. + + Raises: + ValueError: If a record has fewer than 9 fields or contains a negative duration. 
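+
+    For example, a SPEAKER record such as (file and speaker ids are illustrative)::
+
+        SPEAKER file1 1 10.00 5.00 <NA> <NA> spk0 <NA>
+
+    is parsed into a token with ``TBEG=10.0``, ``TDUR=5.0``, ``TEND=15.0`` and is
+    appended to both ``data["file1"]["1"]["SPEAKER"]["spk0"]`` and
+    ``data["file1"]["1"]["RTTM"]``.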
+ """ + if rttm_file is None: + return + + with open(rttm_file, encoding='utf-8') as fh: + for raw in fh: + line = raw.strip() + if not line or line[0] in "#;": + continue + fields = line.split() + if fields[0] == "": + fields = fields[1:] + if len(fields) < 9: + raise ValueError(f"Insufficient fields in RTTM file '{rttm_file}'\n record: {line!r}") + + data_type = fields[0].upper() + tbeg_str = fields[3].replace("*", "") + tdur_str = fields[4].replace("*", "") + token: Token = { + "TYPE": data_type, + "FILE": fields[1], + "CHNL": fields[2].lower(), + "TBEG": 0.0 if tbeg_str.lower() == "" else float(tbeg_str), + "TDUR": 0.0 if tdur_str.lower() == "" else float(tdur_str), + "WORD": fields[5].lower(), + "SUBT": fields[6].lower(), + "SPKR": fields[7] if len(fields) > 7 else "", + "CONF": fields[8].lower() if len(fields) > 8 else "-", + } + if token["TDUR"] < 0: + raise ValueError(f"Negative duration in '{rttm_file}': {line!r}") + + file_id = token["FILE"] + chnl = token["CHNL"] + + if data_type != "SPKR-INFO": + token["TEND"] = token["TBEG"] + token["TDUR"] + token["TMID"] = token["TBEG"] + token["TDUR"] / 2.0 + + if data_type == "SPKR-INFO": + (data.setdefault(file_id, {}).setdefault(chnl, {}).setdefault("SPKR-INFO", {}))[token["SPKR"]] = { + "GENDER": token["SUBT"] + } + + elif data_type == "SPEAKER": + ( + data.setdefault(file_id, {}) + .setdefault(chnl, {}) + .setdefault("SPEAKER", {}) + .setdefault(token["SPKR"], []) + ).append(token) + data[file_id][chnl].setdefault("RTTM", []).append(token) + + elif data_type == "LEXEME": + data.setdefault(file_id, {}).setdefault(chnl, {}) + data[file_id][chnl].setdefault("LEXEME", []).append(token) + data[file_id][chnl].setdefault("RTTM", []).append(token) + + else: + data.setdefault(file_id, {}).setdefault(chnl, {}) + data[file_id][chnl].setdefault("RTTM", []).append(token) + + # Post-parse: sort SPEAKER segments; stamp gender from SPKR-INFO + for file_id in data: + for chnl in data[file_id]: + spkr_info = data[file_id][chnl].get("SPKR-INFO", {}) + spkr_data = data[file_id][chnl].get("SPEAKER", {}) + for spkr, segs in spkr_data.items(): + gender = spkr_info.get(spkr, {}).get("GENDER") or "unknown" + spkr_info.setdefault(spkr, {})["GENDER"] = gender + segs.sort(key=lambda t: t["TMID"]) + for tok in segs: + tok["SUBT"] = gender + + +def get_uem_data( + uem_file: Optional[str], keep_directory: bool = False +) -> Optional[Dict[str, Dict[str, List[UEMSegment]]]]: + """Parse a UEM (Un-partitioned Evaluation Map) file. + + Args: + uem_file: Path to the UEM file. If ``None``, returns ``None``. + keep_directory: If ``False``, strip directory prefixes from file IDs. + + Returns: + Nested dictionary ``data[file_id][chnl] = [seg, ...]`` where each segment is a + dict with keys ``FILE``, ``CHNL``, ``TBEG``, ``TEND``. Returns ``None`` when + *uem_file* is ``None``. + + Raises: + ValueError: If a record has fewer than 4 fields. 
+ """ + if uem_file is None: + return None + + data: Dict[str, Dict[str, List[UEMSegment]]] = {} + with open(uem_file, encoding='utf-8') as fh: + for raw in fh: + line = raw.strip() + if not line or line[0] in "#;": + continue + fields = line.split() + if len(fields) < 4: + raise ValueError(f"Insufficient UEM fields: {line!r}") + seg: UEMSegment = { + "FILE": fields[0], + "CHNL": fields[1].lower(), + "TBEG": float(re.sub(r"[^0-9.]", "", fields[2])), + "TEND": float(re.sub(r"[^0-9.]", "", fields[3])), + } + if not keep_directory: + seg["FILE"] = re.sub(r".*/", "", seg["FILE"]) + seg["FILE"] = re.sub(r"\.[^.]*$", "", seg["FILE"]) + data.setdefault(seg["FILE"], {}).setdefault(seg["CHNL"], []).append(seg) + + for file_id in data: + for chnl in data[file_id]: + data[file_id][chnl].sort(key=lambda s: s["TBEG"]) + return data + + +def uem_from_rttm(rttm_data: List[Token]) -> List[UEMSegment]: + """Derive a single UEM partition spanning all RTTM tokens. + + Args: + rttm_data: List of RTTM token dicts, each with keys ``TYPE``, ``TBEG``, ``TEND``. + + Returns: + A single-element list ``[{"TBEG": ..., "TEND": ...}]`` covering the full time range. + """ + valid = {"SEGMENT", "SPEAKER", "SU", "EDIT", "FILLER", "IP", "CB", "A/P", "LEXEME", "NON-LEX"} + tbeg, tend = 1e30, 0.0 + for tok in rttm_data: + if tok["TYPE"] in valid: + tbeg = min(tbeg, tok["TBEG"]) + tend = max(tend, tok["TEND"]) + return [{"TBEG": tbeg, "TEND": tend}] + + +# ─── UEM manipulation helpers ───────────────────────────────────────────── + + +def _key_end_before_beg(e: Dict[str, Any]) -> Tuple[float, int]: + """Sort key: order by time; END (0) before BEG (1) at equal times.""" + return (e["TIME"], 0 if e["EVENT"] == "END" else 1) + + +def _key_beg_before_end(e: Dict[str, Any]) -> Tuple[float, int]: + """Sort key: order by time; BEG (0) before END (1) at equal times.""" + return (e["TIME"], 0 if e["EVENT"] == "BEG" else 1) + + +def add_exclusion_zones_to_uem( + excluded_tokens: Dict[str, Dict[str, int]], + uem_score: Optional[List[UEMSegment]], + rttm_data: List[Token], + max_extend: Optional[float] = None, +) -> List[UEMSegment]: + """Remove excluded-token regions from the UEM. + + Direct port of the Perl ``add_exclusion_zones_to_uem`` subroutine. Extends NON-LEX + no-score zones by *max_extend* seconds toward speech anchors. + + Args: + excluded_tokens: Mapping ``{type: {subtype: 1}, ...}`` specifying which tokens to exclude. + uem_score: Current scored UEM segments. May be ``None``. + rttm_data: List of all RTTM tokens for this file/channel. + max_extend: Maximum extension (seconds) for NON-LEX no-score zones. + + Returns: + Updated UEM segment list with exclusion zones removed. Returns *uem_score* unchanged + when there is nothing to exclude. 
+ """ + if not excluded_tokens: + return uem_score + + max_ext = max_extend if (max_extend and max_extend >= EPSILON) else EPSILON + ns_events: List[Dict[str, Any]] = [] + + for tok in rttm_data: + if tok.get("TDUR", 0) <= 0: + continue + ttype = tok["TYPE"] + subt = tok["SUBT"] + + if ttype == "LEXEME": + if not excluded_tokens.get("LEXEME", {}).get(subt): + ns_events.append({"TYPE": "LEX", "EVENT": "BEG", "TIME": tok["TBEG"]}) + ns_events.append({"TYPE": "LEX", "EVENT": "END", "TIME": tok["TEND"]}) + elif ttype == "SPEAKER": + ns_events.append({"TYPE": "SEG", "EVENT": "BEG", "TIME": tok["TBEG"]}) + ns_events.append({"TYPE": "SEG", "EVENT": "END", "TIME": tok["TEND"]}) + elif excluded_tokens.get(ttype, {}).get(subt): + ns_events.append({"TYPE": "NSZ", "EVENT": "BEG", "TIME": tok["TBEG"]}) + ns_events.append({"TYPE": "NSZ", "EVENT": "END", "TIME": tok["TEND"]}) + + ns_events.sort(key=_key_end_before_beg) + + # Phase 1: build noscore-zone boundary events + events: List[Dict[str, Any]] = [] + evaluating = 1 + tseg = tend_nsz = tend_lex = 0.0 + lex_cnt = nsz_cnt = 0 + + for ev in ns_events: + etype, eevt, etime = ev["TYPE"], ev["EVENT"], ev["TIME"] + + if etype == "LEX": + if eevt == "BEG": + lex_cnt += 1 + else: + lex_cnt -= 1 + if lex_cnt == 0: + tend_lex = etime + elif etype == "NSZ": + if eevt == "BEG": + nsz_cnt += 1 + else: + nsz_cnt -= 1 + if nsz_cnt == 0: + tend_nsz = etime + elif etype == "SEG": + tseg = etime + + if evaluating: + if nsz_cnt == 0 or etype != "NSZ": + continue + tstop = etime if lex_cnt > 0 else max(tend_lex, tseg, etime - max_ext) + events.append({"TYPE": "NSZ", "EVENT": "BEG", "TIME": tstop}) + evaluating = 0 + elif nsz_cnt == 0 and (lex_cnt > 0 or etype == "SEG"): + tstart = min(tend_nsz + max_ext, etime) + events.append({"TYPE": "NSZ", "EVENT": "END", "TIME": tstart}) + evaluating = 1 + elif nsz_cnt == 1 and etype == "NSZ" and eevt == "BEG" and etime > tend_nsz + 2 * max_ext: + events.append({"TYPE": "NSZ", "EVENT": "END", "TIME": tend_nsz + max_ext}) + events.append({"TYPE": "NSZ", "EVENT": "BEG", "TIME": etime - max_ext}) + evaluating = 0 + + # Phase 2: merge NSZ events with UEM + for uem in uem_score or []: + if uem["TEND"] - uem["TBEG"] > 0: + events.append({"TYPE": "UEM", "EVENT": "BEG", "TIME": uem["TBEG"]}) + events.append({"TYPE": "UEM", "EVENT": "END", "TIME": uem["TEND"]}) + + events.sort(key=_key_end_before_beg) + + evl_cnt = nsz_cnt = evaluating = 0 + tbeg = 0.0 + uem_ex: List[UEMSegment] = [] + for ev in events: + if ev["TYPE"] == "UEM": + evl_cnt += 1 if ev["EVENT"] == "BEG" else -1 + elif ev["TYPE"] == "NSZ": + nsz_cnt += 1 if ev["EVENT"] == "BEG" else -1 + + if evaluating and (evl_cnt == 0 or nsz_cnt > 0) and ev["TIME"] > tbeg: + uem_ex.append({"TBEG": tbeg, "TEND": ev["TIME"]}) + evaluating = 0 + elif evl_cnt > 0 and nsz_cnt == 0: + tbeg = ev["TIME"] + evaluating = 1 + + return uem_ex if uem_ex else uem_score + + +def add_collars_to_uem( + uem_eval: List[UEMSegment], + ref_spkr_data: SpeakerData, + collar: float, +) -> List[UEMSegment]: + """Remove ±collar-second zones around every reference speaker boundary. + + Args: + uem_eval: Evaluation UEM segments. + ref_spkr_data: Reference speaker data ``{spkr: [token, ...]}``. + collar: No-score collar width in seconds. + + Returns: + New UEM segment list with collar zones removed. 
+ """ + events: List[Dict[str, Any]] = [] + for uem in uem_eval: + events.append({"EVENT": "BEG", "TIME": uem["TBEG"]}) + events.append({"EVENT": "END", "TIME": uem["TEND"]}) + for segs in ref_spkr_data.values(): + for seg in segs: + events.append({"EVENT": "END", "TIME": seg["TBEG"] - collar}) + events.append({"EVENT": "BEG", "TIME": seg["TBEG"] + collar}) + events.append({"EVENT": "END", "TIME": seg["TEND"] - collar}) + events.append({"EVENT": "BEG", "TIME": seg["TEND"] + collar}) + + events.sort(key=_key_beg_before_end) + + evaluate = 0 + tbeg = 0.0 + uem_out: List[UEMSegment] = [] + for ev in events: + if ev["EVENT"] == "BEG": + evaluate += 1 + if evaluate == 1: + tbeg = ev["TIME"] + else: + evaluate -= 1 + if evaluate == 0 and ev["TIME"] > tbeg: + uem_out.append({"TBEG": tbeg, "TEND": ev["TIME"]}) + return uem_out + + +def exclude_overlapping_speech_from_uem(uem_data: List[UEMSegment], rttm_data: List[Token]) -> List[UEMSegment]: + """Remove regions where two or more reference speakers overlap simultaneously. + + Args: + uem_data: Current UEM segments to modify. + rttm_data: List of all RTTM tokens for this file/channel. + + Returns: + New UEM segment list with overlap regions excluded. + """ + spkr_evs: List[Dict[str, Any]] = [] + for tok in rttm_data: + if tok["TYPE"] == "SPEAKER" and tok["TDUR"] > 0: + spkr_evs.append({"EVENT": "BEG", "TIME": tok["TBEG"]}) + spkr_evs.append({"EVENT": "END", "TIME": tok["TEND"]}) + + spkr_evs.sort(key=_key_end_before_beg) + + events: List[Dict[str, Any]] = [] + spkr_cnt = 0 + tbeg_ovlap = 0.0 + for ev in spkr_evs: + if ev["EVENT"] == "BEG": + spkr_cnt += 1 + if spkr_cnt == 2: + tbeg_ovlap = ev["TIME"] + else: + spkr_cnt -= 1 + if spkr_cnt == 1: + events.append({"TYPE": "NSZ", "EVENT": "BEG", "TIME": tbeg_ovlap}) + events.append({"TYPE": "NSZ", "EVENT": "END", "TIME": ev["TIME"]}) + + for uem in uem_data: + if uem["TEND"] - uem["TBEG"] > 0: + events.append({"TYPE": "UEM", "EVENT": "BEG", "TIME": uem["TBEG"]}) + events.append({"TYPE": "UEM", "EVENT": "END", "TIME": uem["TEND"]}) + + events.sort(key=_key_end_before_beg) + + tbeg = 0.0 + evl_cnt = nsz_cnt = evaluating = 0 + uem_ex: List[UEMSegment] = [] + for ev in events: + if ev["TYPE"] == "UEM": + evl_cnt += 1 if ev["EVENT"] == "BEG" else -1 + elif ev["TYPE"] == "NSZ": + nsz_cnt += 1 if ev["EVENT"] == "BEG" else -1 + + if evaluating and (evl_cnt == 0 or nsz_cnt > 0) and ev["TIME"] > tbeg: + uem_ex.append({"TBEG": tbeg, "TEND": ev["TIME"]}) + evaluating = 0 + elif evl_cnt > 0 and nsz_cnt == 0: + tbeg = ev["TIME"] + evaluating = 1 + return uem_ex + + +# ─── Speaker segment timeline ───────────────────────────────────────────── + + +def create_speaker_segs( + uem_score: Optional[List[UEMSegment]], + ref_data: SpeakerData, + sys_data: SpeakerData, +) -> List[ScoreSegment]: + """Build a piecewise-constant timeline of ``(ref_spkrs, sys_spkrs)`` sets. + + Ports the Perl ``create_speaker_segs`` exactly: + - UEM gates which time regions are evaluated + - Segments are cut at every event boundary + - At equal times, END events are processed before BEG events + (with ε-tolerance matching the Perl epsilon comparison) + + Args: + uem_score: Scored UEM segments. May be ``None``. + ref_data: Reference speaker data ``{spkr: [token, ...]}``. + sys_data: System speaker data ``{spkr: [token, ...]}``. + + Returns: + List of score segments, each a dict with keys ``REF``, ``SYS``, ``TBEG``, + ``TEND``, ``TDUR``. 
+ """ + events: List[Dict[str, Any]] = [] + for uem in uem_score or []: + if uem["TEND"] > uem["TBEG"] + EPSILON: + events.append({"TYPE": "UEM", "EVENT": "BEG", "TIME": uem["TBEG"]}) + events.append({"TYPE": "UEM", "EVENT": "END", "TIME": uem["TEND"]}) + + for spkr, segs in ref_data.items(): + for seg in segs: + if seg["TDUR"] > 0: + events.append({"TYPE": "REF", "SPKR": spkr, "EVENT": "BEG", "TIME": seg["TBEG"]}) + events.append({"TYPE": "REF", "SPKR": spkr, "EVENT": "END", "TIME": seg["TEND"]}) + + for spkr, segs in sys_data.items(): + for seg in segs: + if seg["TDUR"] > 0: + tbeg = seg.get("RTBEG", seg["TBEG"]) + tend = seg.get("RTEND", seg["TEND"]) + events.append({"TYPE": "SYS", "SPKR": spkr, "EVENT": "BEG", "TIME": tbeg}) + events.append({"TYPE": "SYS", "SPKR": spkr, "EVENT": "END", "TIME": tend}) + + events.sort(key=_key_end_before_beg) + + evaluate = 0 + tbeg = 0.0 + ref_spkrs: Dict[str, int] = {} + sys_spkrs: Dict[str, int] = {} + segments: List[ScoreSegment] = [] + + for ev in events: + if evaluate and tbeg < ev["TIME"] - EPSILON: + tend = ev["TIME"] + segments.append( + { + "REF": dict(ref_spkrs), + "SYS": dict(sys_spkrs), + "TBEG": tbeg, + "TEND": tend, + "TDUR": tend - tbeg, + } + ) + tbeg = tend + + if ev["TYPE"] == "UEM": + if ev["EVENT"] == "BEG": + evaluate = 1 + tbeg = ev["TIME"] + else: + evaluate = 0 + else: + spkrs = ref_spkrs if ev["TYPE"] == "REF" else sys_spkrs + spkr = ev["SPKR"] + if ev["EVENT"] == "BEG": + spkrs[spkr] = spkrs.get(spkr, 0) + 1 + if spkrs[spkr] > 1: + warnings.warn(f"Speaker {spkr} speaking more than once at t={ev['TIME']}") + else: + cnt = spkrs.get(spkr, 0) - 1 + if cnt <= 0: + spkrs.pop(spkr, None) + else: + spkrs[spkr] = cnt + + return segments + + +# ─── Bipartite speaker matching ─────────────────────────────────────────── + + +def map_speakers(spkr_overlap: SpeakerOverlap) -> SpeakerMap: + """Map reference speakers to system speakers to maximise total overlap time. + + Direct replacement for the Perl ``weighted_bipartite_graph_match``. + ``scipy.optimize.linear_sum_assignment`` minimises cost; the overlap matrix is + negated to turn maximisation into minimisation. + + Args: + spkr_overlap: Mapping ``{ref_spkr: {sys_spkr: seconds_overlap}}``. + + Returns: + Mapping ``{ref_spkr: sys_spkr}`` containing only genuinely-overlapping pairs. 
+ """ + if not spkr_overlap: + return {} + + ref_spkrs = sorted(spkr_overlap.keys()) + sys_spkrs_set: set = set() + for r in spkr_overlap: + sys_spkrs_set.update(spkr_overlap[r].keys()) + sys_spkrs = sorted(sys_spkrs_set) + + nref = len(ref_spkrs) + nsys = len(sys_spkrs) + + cost = np.zeros((nref, nsys)) + for i, ref in enumerate(ref_spkrs): + for j, sys_ in enumerate(sys_spkrs): + cost[i, j] = -spkr_overlap[ref].get(sys_, 0.0) + + row_ind, col_ind = linear_sum_assignment(cost) + + result: SpeakerMap = {} + for i, j in zip(row_ind, col_ind): + ref = ref_spkrs[i] + sys_ = sys_spkrs[j] + if spkr_overlap[ref].get(sys_, 0.0) > 0: + result[ref] = sys_ + return result + + +# ─── Per-segment speaker scoring ───────────────────────────────────────── + + +def _speakers_match(ref_spkrs: Dict[str, int], sys_spkrs: Dict[str, int], spkr_map: SpeakerMap) -> bool: + """Check whether every ref speaker in a segment maps to a present sys speaker.""" + if len(ref_spkrs) != len(sys_spkrs): + return False + for rs in ref_spkrs: + mapped = spkr_map.get(rs) + if mapped is None or mapped not in sys_spkrs: + return False + return True + + +def _speaker_mapping_scores( + spkr_map: SpeakerMap, + spkr_info: Dict[str, Dict[str, Dict[str, Any]]], +) -> Dict[str, Dict]: + """Count-based (NSPK) speaker-type confusion statistics.""" + stats: Dict[str, Dict] = {"REF": {}, "SYS": {}, "JOINT": {}} + imap: Dict[str, str] = {} + + for rs, info in spkr_info["REF"].items(): + if not info.get("TIME"): + continue + rt = info.get("TYPE", "unknown") + stats["REF"][rt] = stats["REF"].get(rt, 0) + 1 + ss = spkr_map.get(rs) + st = spkr_info["SYS"].get(ss, {}).get("TYPE", MISS_NAME) if ss else MISS_NAME + stats["JOINT"].setdefault(rt, {})[st] = stats["JOINT"].get(rt, {}).get(st, 0) + 1 + if ss: + imap[ss] = rs + + for ss, info in spkr_info["SYS"].items(): + if not info.get("TIME"): + continue + st = info.get("TYPE", "unknown") + stats["SYS"][st] = stats["SYS"].get(st, 0) + 1 + if ss not in imap: + stats["JOINT"].setdefault(FA_NAME, {})[st] = stats["JOINT"].get(FA_NAME, {}).get(st, 0) + 1 + + return stats + + +def _score_speaker_segments( + stats: SDStats, + score_segs: List[ScoreSegment], + ref_wds: List[Token], + spkr_map: SpeakerMap, + spkr_info: Dict[str, Dict[str, Dict[str, Any]]], +) -> None: + """Accumulate DER components over all scored segments (in-place). + + Args: + stats: Mutable statistics dictionary to update. + score_segs: Scored timeline segments. + ref_wds: Sorted reference word tokens. + spkr_map: Reference-to-system speaker mapping. + spkr_info: Speaker metadata dict with REF/SYS sub-dicts. 
+ """ + ref_wds_list = sorted(ref_wds, key=lambda w: w["TMID"]) + wi = 0 + + for seg in score_segs: + dur = seg["TDUR"] + nref = len(seg["REF"]) + nsys = len(seg["SYS"]) + + stats["SCORED_TIME"] = stats.get("SCORED_TIME", 0.0) + dur + stats["SCORED_SPEECH"] = stats.get("SCORED_SPEECH", 0.0) + (dur if nref else 0.0) + stats["MISSED_SPEECH"] = stats.get("MISSED_SPEECH", 0.0) + (dur if nref and not nsys else 0.0) + stats["FALARM_SPEECH"] = stats.get("FALARM_SPEECH", 0.0) + (dur if nsys and not nref else 0.0) + stats["SCORED_SPEAKER"] = stats.get("SCORED_SPEAKER", 0.0) + dur * nref + stats["MISSED_SPEAKER"] = stats.get("MISSED_SPEAKER", 0.0) + dur * max(nref - nsys, 0) + stats["FALARM_SPEAKER"] = stats.get("FALARM_SPEAKER", 0.0) + dur * max(nsys - nref, 0) + + nmap = sum(1 for rs in seg["REF"] if spkr_map.get(rs) in seg["SYS"]) + stats["SPEAKER_ERROR"] = stats.get("SPEAKER_ERROR", 0.0) + dur * (min(nref, nsys) - nmap) + + while wi < len(ref_wds_list) and ref_wds_list[wi]["TMID"] < seg["TBEG"]: + wi += 1 + tmp = wi + sw = mw = ew = 0 + while tmp < len(ref_wds_list) and ref_wds_list[tmp]["TMID"] <= seg["TEND"]: + if ref_wds_list[tmp].get("SCOREABLE"): + sw += 1 + if not nsys: + mw += 1 + if not _speakers_match(seg["REF"], seg["SYS"], spkr_map): + ew += 1 + tmp += 1 + stats["SCORED_WORDS"] = stats.get("SCORED_WORDS", EPSILON) + sw + stats["MISSED_WORDS"] = stats.get("MISSED_WORDS", EPSILON) + mw + stats["ERROR_WORDS"] = stats.get("ERROR_WORDS", EPSILON) + ew + + num_ref: Dict[str, int] = {} + num_sys: Dict[str, int] = {} + for rs in seg["REF"]: + rt = spkr_info["REF"].get(rs, {}).get("TYPE", "unknown") + num_ref[rt] = num_ref.get(rt, 0) + 1 + for ss in seg["SYS"]: + st = spkr_info["SYS"].get(ss, {}).get("TYPE", "unknown") + num_sys[st] = num_sys.get(st, 0) + 1 + + tt = stats["TYPE"].setdefault("TIME", {"REF": {}, "SYS": {}, "JOINT": {}}) + for rt, nrt in num_ref.items(): + tt["REF"][rt] = tt["REF"].get(rt, 0.0) + nrt * dur + for st, nst in num_sys.items(): + tt["JOINT"].setdefault(rt, {})[st] = tt["JOINT"].get(rt, {}).get(st, 0.0) + min(nrt, nst) * dur + tt["JOINT"].setdefault(rt, {})[MISS_NAME] = ( + tt["JOINT"].get(rt, {}).get(MISS_NAME, 0.0) + max(nrt - nsys, 0) * dur + ) + for st, nst in num_sys.items(): + tt["SYS"][st] = tt["SYS"].get(st, 0.0) + nst * dur + tt["JOINT"].setdefault(FA_NAME, {})[st] = ( + tt["JOINT"].get(FA_NAME, {}).get(st, 0.0) + max(nst - nref, 0) * dur + ) + + +# ─── Main diarization scoring ───────────────────────────────────────────── + + +def score_speaker_diarization( + file_id: str, + chnl: str, + ref_spkr_data: SpeakerData, + sys_spkr_data: SpeakerData, + ref_wds: List[Token], + uem_eval: List[UEMSegment], + rttm_data: List[Token], + collar: float = DEFAULT_COLLAR, + opt_1: bool = False, + noscore_sd: Optional[Dict[str, Dict[str, int]]] = None, + max_extend: float = DEFAULT_EXTEND, +) -> Tuple[SDStats, SpeakerMap]: + """Score speaker diarization for a single file/channel pair. + + Ports the Perl ``score_speaker_diarization`` subroutine exactly. + + Args: + file_id: File identifier string. + chnl: Channel identifier string. + ref_spkr_data: Reference speaker data ``{spkr: [token, ...]}``. + sys_spkr_data: System speaker data ``{spkr: [token, ...]}``. + ref_wds: Reference word (LEXEME) tokens. + uem_eval: Evaluation UEM segments. + rttm_data: All RTTM tokens for this file/channel. + collar: No-score collar width in seconds. + opt_1: If ``True``, restrict scoring to single-speaker regions. + noscore_sd: No-score conditions for speaker diarization. 
+ max_extend: Maximum extension for NON-LEX no-score zones. + + Returns: + A tuple ``(stats, spkr_map)`` where *stats* is a dictionary of DER component + accumulators and *spkr_map* is the optimal ref→sys speaker mapping. + """ + stats: SDStats = { + "EVAL_WORDS": EPSILON, + "SCORED_WORDS": EPSILON, + "MISSED_WORDS": EPSILON, + "ERROR_WORDS": EPSILON, + "EVAL_TIME": 0.0, + "EVAL_SPEECH": 0.0, + "SCORED_TIME": 0.0, + "SCORED_SPEECH": 0.0, + "MISSED_SPEECH": 0.0, + "FALARM_SPEECH": 0.0, + "SCORED_SPEAKER": 0.0, + "MISSED_SPEAKER": 0.0, + "FALARM_SPEAKER": 0.0, + "SPEAKER_ERROR": 0.0, + "TYPE": {}, + } + + ref_wds_list = sorted(ref_wds, key=lambda w: w["TMID"]) + wi = 0 + for seg in uem_eval or []: + stats["EVAL_TIME"] += seg["TEND"] - seg["TBEG"] + while wi < len(ref_wds_list) and ref_wds_list[wi]["TMID"] < seg["TBEG"]: + wi += 1 + while wi < len(ref_wds_list) and ref_wds_list[wi]["TMID"] <= seg["TEND"]: + stats["EVAL_WORDS"] += 1 + wi += 1 + + eval_segs = create_speaker_segs(uem_eval, ref_spkr_data, sys_spkr_data) + spkr_info: Dict[str, Dict[str, Dict[str, Any]]] = {"REF": {}, "SYS": {}} + spkr_overlap: SpeakerOverlap = {} + + for seg in eval_segs: + for rs in seg["REF"]: + spkr_info["REF"].setdefault(rs, {"TIME": 0.0}) + spkr_info["REF"][rs]["TIME"] += seg["TDUR"] + if ref_spkr_data.get(rs): + spkr_info["REF"][rs]["TYPE"] = ref_spkr_data[rs][0].get("SUBT", "unknown") + for ss in seg["SYS"]: + spkr_info["SYS"].setdefault(ss, {"TIME": 0.0}) + spkr_info["SYS"][ss]["TIME"] += seg["TDUR"] + if sys_spkr_data.get(ss): + spkr_info["SYS"][ss]["TYPE"] = sys_spkr_data[ss][0].get("SUBT", "unknown") + + if not seg["REF"]: + continue + stats["EVAL_SPEECH"] += seg["TDUR"] + for rs in seg["REF"]: + for ss in seg["SYS"]: + spkr_overlap.setdefault(rs, {})[ss] = spkr_overlap.get(rs, {}).get(ss, 0.0) + seg["TDUR"] + + spkr_map = map_speakers(spkr_overlap) + + uem_score = add_collars_to_uem(uem_eval, ref_spkr_data, collar) if collar > 0 else uem_eval + + if noscore_sd: + uem_score = add_exclusion_zones_to_uem(noscore_sd, uem_score, rttm_data) + noscore_nl = noscore_sd.get("NON-LEX") + if noscore_nl: + uem_score = add_exclusion_zones_to_uem({"NON-LEX": noscore_nl}, uem_score, rttm_data, max_extend) + + if opt_1: + uem_score = exclude_overlapping_speech_from_uem(uem_score, rttm_data) + + score_segs = create_speaker_segs(uem_score, ref_spkr_data, sys_spkr_data) + _score_speaker_segments(stats, score_segs, ref_wds_list, spkr_map, spkr_info) + stats["TYPE"]["NSPK"] = _speaker_mapping_scores(spkr_map, spkr_info) + + return stats, spkr_map + + +# ─── Output formatting ──────────────────────────────────────────────────── + + +def _summarize_speaker_type_performance(cls: str, stats: Dict[str, Dict]) -> str: + """Format speaker-type confusion matrix as a multi-line string. + + Args: + cls: Either ``"NSPK"`` (count-weighted) or ``"TIME"`` (time-weighted). + stats: Type confusion statistics dict with keys ``REF``, ``SYS``, ``JOINT``. + + Returns: + Formatted confusion matrix string. 
+ """ + sys_types = sorted(stats.get("SYS", {}).keys()) + label = " REF\\SYS (count) " if cls == "NSPK" else " REF\\SYS (seconds) " + + lines = [label + "".join(f"{st:<20}" for st in sys_types + [MISS_NAME])] + ref_tot = sum(stats.get("REF", {}).values()) + + for rt in sorted(stats.get("REF", {}).keys()) + [FA_NAME]: + parts = [f"{rt:<16}"] + for st in sys_types + [MISS_NAME]: + if rt == FA_NAME and st == MISS_NAME: + continue + val = stats.get("JOINT", {}).get(rt, {}).get(st, 0) + pct = min(999.9, 100.0 * val / ref_tot) if ref_tot else 9e9 + if cls == "NSPK": + parts.append(f"{int(val):>11} /{pct:>6.1f}%") + else: + parts.append(f"{val:>11.2f} /{pct:>6.1f}%") + lines.append("".join(parts)) + + return "\n".join(lines) + + +def format_sd_scores(condition: str, scores: SDStats) -> str: + """Format speaker diarization scores as a human-readable string. + + Args: + condition: Label for the evaluation condition (e.g. ``"ALL"``). + scores: Aggregated DER component statistics dictionary. + + Returns: + Multi-line formatted string containing DER breakdown and confusion matrices. + """ + scored = scores.get("SCORED_SPEAKER", 0.0) or EPSILON + missed = scores.get("MISSED_SPEAKER", 0.0) + falarm = scores.get("FALARM_SPEAKER", 0.0) + error = scores.get("SPEAKER_ERROR", 0.0) + der = 100.0 * (missed + falarm + error) / scored + + lines = [ + f"\n*** Performance analysis for Speaker Diarization for {condition} ***\n", + f"SCORED SPEAKER TIME ={scored:f} secs", + f"MISSED SPEAKER TIME ={missed:f} secs", + f"FALARM SPEAKER TIME ={falarm:f} secs", + f"SPEAKER ERROR TIME ={error:f} secs", + f" OVERALL SPEAKER DIARIZATION ERROR = {der:.2f} percent of scored speaker time" f" `({condition})", + "---------------------------------------------", + " Speaker type confusion matrix -- speaker weighted", + _summarize_speaker_type_performance("NSPK", scores.get("TYPE", {}).get("NSPK", {})), + "---------------------------------------------", + " Speaker type confusion matrix -- time weighted", + _summarize_speaker_type_performance("TIME", scores.get("TYPE", {}).get("TIME", {})), + "---------------------------------------------", + ] + return "\n".join(lines) + + +# ─── Top-level evaluate ─────────────────────────────────────────────────── + + +def evaluate( + ref_data: Dict[str, Dict[str, Dict[str, Any]]], + sys_data: Dict[str, Dict[str, Dict[str, Any]]], + uem_data: Optional[Dict[str, Dict[str, List[UEMSegment]]]] = None, + collar: float = DEFAULT_COLLAR, + opt_1: bool = False, + noeval_sd: Optional[Dict[str, Dict[str, int]]] = None, + noscore_sd: Optional[Dict[str, Dict[str, int]]] = None, + max_extend: float = DEFAULT_EXTEND, + verbose: bool = True, +) -> Tuple[Dict[str, Dict[str, SDStats]], SDStats]: + """Evaluate speaker diarization across all files and channels. + + Ports the Perl ``evaluate`` subroutine (speaker-diarization path only). + + Args: + ref_data: Parsed reference RTTM data + ``{file_id: {chnl: {"SPEAKER": ..., "RTTM": ...}}}``. + sys_data: Parsed system RTTM data (same structure as *ref_data*). + uem_data: Parsed UEM data ``{file_id: {chnl: [seg, ...]}}``. If ``None``, + UEM partitions are derived from the reference RTTM. + collar: No-score collar width in seconds. + opt_1: If ``True``, restrict scoring to single-speaker regions. + noeval_sd: No-eval conditions. Defaults to :data:`NOEVAL_SD`. + noscore_sd: No-score conditions. Defaults to :data:`NOSCORE_SD`. + max_extend: Maximum extension for NON-LEX no-score zones. + verbose: If ``True``, log the final DER summary. 
+ + Returns: + A tuple ``(all_scores, cum)`` where *all_scores* maps + ``file_id → chnl → stats`` and *cum* is the aggregate statistics dictionary. + """ + noeval_sd = noeval_sd if noeval_sd is not None else NOEVAL_SD + noscore_sd = noscore_sd if noscore_sd is not None else NOSCORE_SD + + all_scores: Dict[str, Dict[str, SDStats]] = {} + + for file_id in sorted(ref_data.keys()): + for chnl in sorted(ref_data[file_id].keys()): + ref_rttm = ref_data[file_id][chnl].get("RTTM", []) + ref_spkr_data = ref_data[file_id][chnl].get("SPEAKER") + if not ref_spkr_data: + continue + + sys_spkr_data = sys_data.get(file_id, {}).get(chnl, {}).get("SPEAKER", {}) + ref_wds = ref_data[file_id][chnl].get("LEXEME", []) + + uem = uem_data.get(file_id, {}).get(chnl) if uem_data else None + if uem is None: + uem = uem_from_rttm(ref_rttm) + + for segs in sys_spkr_data.values(): + for seg in segs: + seg.setdefault("RTBEG", seg["TBEG"]) + seg.setdefault("RTEND", seg["TEND"]) + seg["RTDUR"] = seg["RTEND"] - seg["RTBEG"] + seg["RTMID"] = seg["RTBEG"] + seg["RTDUR"] / 2.0 + + uem_sd_eval = add_exclusion_zones_to_uem(noeval_sd, uem, ref_rttm) + if not uem_sd_eval: + uem_sd_eval = uem + + stats, _ = score_speaker_diarization( + file_id, + chnl, + ref_spkr_data, + sys_spkr_data, + ref_wds, + uem_sd_eval, + ref_rttm, + collar=collar, + opt_1=opt_1, + noscore_sd=noscore_sd, + max_extend=max_extend, + ) + all_scores.setdefault(file_id, {})[chnl] = stats + + # Aggregate across files/channels + cum: SDStats = { + "EVAL_TIME": 0.0, + "EVAL_SPEECH": 0.0, + "SCORED_TIME": 0.0, + "SCORED_SPEECH": 0.0, + "MISSED_SPEECH": 0.0, + "FALARM_SPEECH": 0.0, + "SCORED_SPEAKER": 0.0, + "MISSED_SPEAKER": 0.0, + "FALARM_SPEAKER": 0.0, + "SPEAKER_ERROR": 0.0, + "EVAL_WORDS": 0.0, + "SCORED_WORDS": 0.0, + "MISSED_WORDS": 0.0, + "ERROR_WORDS": 0.0, + "TYPE": { + "NSPK": {"REF": {}, "SYS": {}, "JOINT": {}}, + "TIME": {"REF": {}, "SYS": {}, "JOINT": {}}, + }, + } + + _scalar_keys = ( + "EVAL_TIME", + "EVAL_SPEECH", + "SCORED_TIME", + "SCORED_SPEECH", + "MISSED_SPEECH", + "FALARM_SPEECH", + "SCORED_SPEAKER", + "MISSED_SPEAKER", + "FALARM_SPEAKER", + "SPEAKER_ERROR", + "EVAL_WORDS", + "SCORED_WORDS", + "MISSED_WORDS", + "ERROR_WORDS", + ) + + for file_id in all_scores: + for chnl in all_scores[file_id]: + s = all_scores[file_id][chnl] + for k in _scalar_keys: + cum[k] += s.get(k, 0.0) + for cls in ("NSPK", "TIME"): + src = s.get("TYPE", {}).get(cls, {}) + dst = cum["TYPE"][cls] + for kind in ("REF", "SYS"): + for t, v in src.get(kind, {}).items(): + dst[kind][t] = dst[kind].get(t, 0) + v + for rt, sm in src.get("JOINT", {}).items(): + for st, v in sm.items(): + dst["JOINT"].setdefault(rt, {})[st] = dst["JOINT"].get(rt, {}).get(st, 0) + v + + if verbose: + logging.info(format_sd_scores("ALL", cum)) + + return all_scores, cum + + +# ─── DER result wrapper ──────────────────────────────────────────────────── +# +# ``DiarizationErrorResult`` is the result object returned by the public DER +# entry points (``score_labels`` and ``score_labels_from_rttm_labels`` in +# ``der.py``). It exposes a small, dict-like interface that the rest of NeMo +# (and downstream user code) consume. +# ─────────────────────────────────────────────────────────────────────────── + + +class DiarizationErrorResult: + """Result object returned by NeMo's DER entry points. 
+ + Supports: + - ``abs(result)`` → overall DER (float) + - ``result['total']``, ``result['confusion']``, ``result['false alarm']``, + ``result['missed detection']`` + - ``result.results_`` → list of ``(uniq_id, score_dict)`` per file + - ``result.optimal_mapping(ref, hyp)`` → speaker mapping dict for a file + - ``result.report()`` → formatted string summary + + Args: + all_scores: Per-file stats ``{file_id: {chnl: SDStats}}``. + cum: Aggregate stats dict. + mapping_dict: ``{uniq_id: {ref_spkr: sys_spkr}}``. + collar: Collar value used. + ignore_overlap: Whether overlap was ignored. + """ + + def __init__( + self, + all_scores: Dict[str, Dict[str, SDStats]], + cum: SDStats, + mapping_dict: Dict[str, SpeakerMap], + collar: float, + ignore_overlap: bool, + ): + self._all_scores = all_scores + self._cum = cum + self._mapping_dict = mapping_dict + self._collar = collar + self._ignore_overlap = ignore_overlap + + scored = cum.get("SCORED_SPEAKER", 0.0) or EPSILON + self._total = scored + self._confusion = cum.get("SPEAKER_ERROR", 0.0) + self._false_alarm = cum.get("FALARM_SPEAKER", 0.0) + self._missed = cum.get("MISSED_SPEAKER", 0.0) + self._der = (self._confusion + self._false_alarm + self._missed) / self._total + + self.results_: List[Tuple[str, Dict[str, float]]] = [] + for file_id in sorted(all_scores.keys()): + for chnl in sorted(all_scores[file_id].keys()): + s = all_scores[file_id][chnl] + s_scored = s.get("SCORED_SPEAKER", 0.0) or EPSILON + self.results_.append( + ( + file_id, + { + "total": s_scored, + "confusion": s.get("SPEAKER_ERROR", 0.0), + "false alarm": s.get("FALARM_SPEAKER", 0.0), + "missed detection": s.get("MISSED_SPEAKER", 0.0), + }, + ) + ) + + def __abs__(self) -> float: + return self._der + + def __getitem__(self, key: str) -> float: + return { + "total": self._total, + "confusion": self._confusion, + "false alarm": self._false_alarm, + "missed detection": self._missed, + }[key] + + def optimal_mapping(self, ref_labels: Any, hyp_labels: Any) -> SpeakerMap: + """Return the optimal speaker mapping for a given ref/hyp pair. + + When called with a string key, the mapping is looked up directly. For + annotation-like objects, the recording id is taken from a ``.uri`` / + ``.recording_id`` attribute when present, otherwise the object's + ``str(...)`` representation. + """ + if isinstance(ref_labels, str): + key = ref_labels + else: + key = getattr(ref_labels, 'uri', None) or getattr(ref_labels, 'recording_id', None) or str(ref_labels) + return self._mapping_dict.get(key, {}) + + def report(self) -> str: + """Return a human-readable string of per-file DER scores.""" + lines = [] + header = f"{'file':<40} {'total':>10} {'confusion':>10} {'false alarm':>12} {'missed':>10} {'DER':>8}" + lines.append(header) + lines.append("-" * len(header)) + for file_id, score in self.results_: + total = score["total"] + conf = score["confusion"] + fa = score["false alarm"] + miss = score["missed detection"] + der = 100.0 * (conf + fa + miss) / total if total > 0 else 0.0 + lines.append(f"{file_id:<40} {total:>10.2f} {conf:>10.2f} {fa:>12.2f} {miss:>10.2f} {der:>7.2f}%") + total = self._total + lines.append("-" * len(header)) + lines.append( + f"{'TOTAL':<40} {total:>10.2f} {self._confusion:>10.2f} " + f"{self._false_alarm:>12.2f} {self._missed:>10.2f} {abs(self) * 100:>7.2f}%" + ) + return "\n".join(lines) + + +def _iter_annotation_segments(annotation: Any) -> Iterator[Tuple[float, float, str]]: + """Yield ``(start, end, speaker)`` tuples from an annotation-like object. 
+ + Supports lhotse ``SupervisionSet`` / iterable of ``SupervisionSegment`` + (the preferred representation), as well as any iterable of objects + exposing ``.start`` plus either ``.end`` or ``.duration`` and ``.speaker``. + Inputs that expose an ``.itertracks(yield_label=True)`` iterator are also + supported for compatibility with annotation objects from external + annotation libraries. + """ + if hasattr(annotation, "itertracks"): + for segment, _track, speaker in annotation.itertracks(yield_label=True): + yield float(segment.start), float(segment.end), str(speaker) + return + + for item in annotation: + start = float(item.start) + if hasattr(item, "end") and item.end is not None: + end = float(item.end) + elif hasattr(item, "duration"): + end = start + float(item.duration) + else: + raise TypeError(f"Annotation item of type {type(item).__name__} has no 'end' or 'duration' attribute.") + speaker = getattr(item, "speaker", None) + if speaker is None: + raise TypeError(f"Annotation item of type {type(item).__name__} has no 'speaker' attribute.") + yield start, end, str(speaker) + + +def _annotation_to_rttm_data( + uniq_id: str, + annotation: Any, +) -> Dict[str, Dict[str, Dict[str, Any]]]: + """Convert an annotation-like object into the nested RTTM data structure + expected by :func:`evaluate`. + + Accepts any of the following: + * a lhotse ``SupervisionSet`` or iterable of ``SupervisionSegment`` + (each item exposes ``.start``, ``.end``/``.duration``, ``.speaker``). + * any iterable of objects that have ``.start``, ``.end`` (or + ``.duration``) and ``.speaker``. + * any object exposing ``.itertracks(yield_label=True)`` (compatibility + with annotation objects from external annotation libraries). + + Args: + uniq_id: Unique file identifier used as ``file_id``. + annotation: An annotation-like object as described above. + + Returns: + RTTM data dict ``{file_id: {chnl: {"SPEAKER": ..., "RTTM": ...}}}``. + """ + data: Dict[str, Dict[str, Dict[str, Any]]] = {} + chnl = "1" + + for tbeg, tend, speaker in _iter_annotation_segments(annotation): + tdur = tend - tbeg + if tdur <= 0: + continue + token: Token = { + "TYPE": "SPEAKER", + "FILE": uniq_id, + "CHNL": chnl, + "TBEG": tbeg, + "TDUR": tdur, + "TEND": tend, + "TMID": tbeg + tdur / 2.0, + "WORD": "", + "SUBT": "", + "SPKR": str(speaker), + "CONF": "-", + } + ( + data.setdefault(uniq_id, {}).setdefault(chnl, {}).setdefault("SPEAKER", {}).setdefault(str(speaker), []) + ).append(token) + data[uniq_id][chnl].setdefault("RTTM", []).append(token) + + # Sort speaker segments by midpoint (matching get_rttm_data post-parse) + for file_id in data: + for ch in data[file_id]: + for spkr, segs in data[file_id][ch].get("SPEAKER", {}).items(): + segs.sort(key=lambda t: t["TMID"]) + + return data + + +def _labels_to_rttm_data( + uniq_id: str, + labels: List[str], +) -> Dict[str, Dict[str, Dict[str, Any]]]: + """Convert a list of ``"start end speaker"`` label strings into the nested RTTM + data structure expected by :func:`evaluate`. + + Args: + uniq_id: Unique file identifier. + labels: List of label strings, each formatted as ``"start end speaker"``. + + Returns: + RTTM data dict ``{file_id: {chnl: {"SPEAKER": ..., "RTTM": ...}}}``. 
+ """ + data: Dict[str, Dict[str, Dict[str, Any]]] = {} + chnl = "1" + + for label in labels: + parts = label.strip().split() + tbeg, tend = float(parts[0]), float(parts[1]) + speaker = parts[2] + tdur = tend - tbeg + if tdur <= 0: + continue + token: Token = { + "TYPE": "SPEAKER", + "FILE": uniq_id, + "CHNL": chnl, + "TBEG": tbeg, + "TDUR": tdur, + "TEND": tend, + "TMID": tbeg + tdur / 2.0, + "WORD": "", + "SUBT": "", + "SPKR": speaker, + "CONF": "-", + } + (data.setdefault(uniq_id, {}).setdefault(chnl, {}).setdefault("SPEAKER", {}).setdefault(speaker, [])).append( + token + ) + data[uniq_id][chnl].setdefault("RTTM", []).append(token) + + for file_id in data: + for ch in data[file_id]: + for spkr, segs in data[file_id][ch].get("SPEAKER", {}).items(): + segs.sort(key=lambda t: t["TMID"]) + + return data + + +def _uem_list_to_uem_data( + uniq_id: str, + uem_segments: List[List[float]], +) -> Dict[str, Dict[str, List[UEMSegment]]]: + """Convert a list of ``[start, end]`` pairs into UEM data structure. + + Args: + uniq_id: Unique file identifier. + uem_segments: List of ``[start_time, end_time]`` pairs. + + Returns: + UEM data dict ``{file_id: {chnl: [seg, ...]}}``. + """ + chnl = "1" + segs = [{"TBEG": float(s), "TEND": float(e)} for s, e in uem_segments] + return {uniq_id: {chnl: segs}} + + +def _merge_rttm_dicts(dicts: List[Dict[str, Dict[str, Dict[str, Any]]]]) -> Dict[str, Dict[str, Dict[str, Any]]]: + """Merge multiple single-file RTTM data dicts into one combined dict.""" + merged: Dict[str, Dict[str, Dict[str, Any]]] = {} + for d in dicts: + for file_id, channels in d.items(): + merged.setdefault(file_id, {}).update(channels) + return merged + + +def _merge_uem_dicts(dicts: List[Dict[str, Dict[str, List[UEMSegment]]]]) -> Dict[str, Dict[str, List[UEMSegment]]]: + """Merge multiple single-file UEM data dicts into one combined dict.""" + merged: Dict[str, Dict[str, List[UEMSegment]]] = {} + for d in dicts: + for file_id, channels in d.items(): + merged.setdefault(file_id, {}).update(channels) + return merged diff --git a/nemo/collections/asr/parts/utils/diarization_utils.py b/nemo/collections/asr/parts/utils/diarization_utils.py index a99f86d6a634..2f33207a4d88 100644 --- a/nemo/collections/asr/parts/utils/diarization_utils.py +++ b/nemo/collections/asr/parts/utils/diarization_utils.py @@ -23,9 +23,12 @@ from typing import Dict, List, Optional, Tuple import numpy as np -from pyannote.metrics.diarization import DiarizationErrorRate -from nemo.collections.asr.metrics.der import calculate_session_cpWER, concat_perm_word_error_rate +from nemo.collections.asr.metrics.der import ( + calculate_session_cpWER, + concat_perm_word_error_rate, + score_labels_from_rttm_labels, +) from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.models import ClusteringDiarizer from nemo.collections.asr.parts.utils.speaker_utils import ( @@ -670,44 +673,61 @@ def evaluate(self, ref_seglst, hyp_seglst, chunk_size=10.0, verbose=True): assert ref_session_id == hyp_session_id, "Session IDs of reference and hypothesis should match" - # Only care about the sessions in reference only session_id = ref_session_id ref_speaker_words = defaultdict(list) hyp_speaker_words = defaultdict(list) - der_metric = DiarizationErrorRate(collar=2 * self.collar, skip_overlap=self.ignore_overlap) cpwer_metric = calculate_session_cpWER der_list, cpwer_list = [], [] + + cum_ref_labels: List[str] = [] + cum_hyp_labels: List[str] = [] + for chunk_idx in range(max_idx): - ref_seglst = 
chunked_ref_seglst[chunk_idx] - hyp_seglst = chunked_hyp_seglst[chunk_idx] + ref_seglst_chunk = chunked_ref_seglst[chunk_idx] + hyp_seglst_chunk = chunked_hyp_seglst[chunk_idx] if len(ref_speaker_words) == 0: ref_speaker_words = ['' for _ in ref_speakers] if len(hyp_speaker_words) == 0: hyp_speaker_words = ['' for _ in hyp_speakers] - hyp_speaker_timestamps, hyp_speaker_word = convert_seglst(hyp_seglst, hyp_speakers) - ref_speaker_timestamps, ref_speaker_word = convert_seglst(ref_seglst, ref_speakers) + hyp_speaker_timestamps, hyp_speaker_word = convert_seglst(hyp_seglst_chunk, hyp_speakers) + ref_speaker_timestamps, ref_speaker_word = convert_seglst(ref_seglst_chunk, ref_speakers) for idx, speaker in enumerate(ref_speakers): ref_speaker_words[idx] += ref_speaker_word[idx] + for st, et in ref_speaker_timestamps[idx]: + cum_ref_labels.append(f"{st} {et} {speaker}") for idx, speaker in enumerate(hyp_speakers): hyp_speaker_words[idx] += hyp_speaker_word[idx] + for st, et in hyp_speaker_timestamps[idx]: + cum_hyp_labels.append(f"{st} {et} {speaker}") - # Normalize the text for spk_idx in range(len(hyp_speaker_words)): hyp_speaker_words[spk_idx] = ( hyp_speaker_words[spk_idx].translate(str.maketrans('', '', string.punctuation)).lower() ) cpWER, min_perm_hyp_trans, ref_trans = cpwer_metric(ref_speaker_words, hyp_speaker_words) + der = 0.0 + if cum_ref_labels: + result = score_labels_from_rttm_labels( + ref_labels_list=[(session_id, list(cum_ref_labels))], + hyp_labels_list=[(session_id, list(cum_hyp_labels))], + collar=self.collar, + ignore_overlap=self.ignore_overlap, + verbose=False, + ) + if result is not None: + der = abs(result[0]) * 100 + if verbose: logging.info( f"Session ID: {session_id} Chunk ID: {chunk_idx} from 0.0s to {(chunk_idx+1)*chunk_size}s" ) - logging.info(f"DER: {abs(der_metric)*100:.2f}%, cpWER: {cpWER*100:.2f}%") + logging.info(f"DER: {der:.2f}%, cpWER: {cpWER*100:.2f}%") - der_list.append(abs(der_metric) * 100) + der_list.append(der) cpwer_list.append(cpWER * 100) return der_list, cpwer_list @@ -890,9 +910,9 @@ def run_diarization(self, diar_model_config, word_timestamps) -> Dict[str, List[ Returns: diar_hyp (dict): A dictionary containing rttm results which are indexed by a unique ID. - score Tuple[pyannote object, dict]: - A tuple containing pyannote metric instance and mapping dictionary between - speakers in hypotheses and speakers in reference RTTM files. + score (Tuple[DiarizationErrorResult, dict]): + A tuple containing the DER result object and a mapping dictionary + between speakers in hypotheses and speakers in reference RTTM files. """ if diar_model_config.diarizer.asr.parameters.asr_based_vad: @@ -949,11 +969,11 @@ def gather_eval_results( decimals: int = 4, ) -> Dict[str, Dict[str, float]]: """ - Gather diarization evaluation results from pyannote DiarizationErrorRate metric object. + Gather diarization evaluation results from DiarizationErrorResult metric object. Args: - metric (DiarizationErrorRate metric): - DiarizationErrorRate metric pyannote object + metric (DiarizationErrorResult): + DiarizationErrorResult metric object from md_eval trans_info_dict (dict): Dictionary containing word timestamps, speaker labels and words from all sessions. Each session is indexed by unique ID as a key. 
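For reference, a minimal usage sketch of the chunk-level scoring path introduced above (hedged: the label strings and session id are made up, and only the arguments exercised by this call site are shown):

    ref_labels = ["0.0 5.0 spk_A", "5.0 10.0 spk_B"]   # "start end speaker" strings
    hyp_labels = ["0.0 5.2 spk_A", "5.2 10.0 spk_B"]
    result = score_labels_from_rttm_labels(
        ref_labels_list=[("session_1", ref_labels)],
        hyp_labels_list=[("session_1", hyp_labels)],
        collar=0.25,
        ignore_overlap=False,
        verbose=False,
    )
    if result is not None:
        der_percent = abs(result[0]) * 100  # result[0] is a DiarizationErrorResult
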
diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 99bd29760ffc..98a79a83eafb 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -22,11 +22,12 @@ import numpy as np import soundfile as sf import torch +from lhotse import SupervisionSegment from omegaconf.listconfig import ListConfig -from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm from nemo.collections.asr.data.audio_to_label import repeat_signal +from nemo.collections.asr.metrics.der import make_diar_annotation, make_uem_timeline, write_supervisions_to_rttm from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering from nemo.utils import logging @@ -310,17 +311,24 @@ def merge_stamps(lines): return overlap_stamps -def labels_to_pyannote_object(labels, uniq_name=''): - """ - Convert the given labels to pyannote object to calculate DER and for visualization - """ - annotation = Annotation(uri=uniq_name) - for label in labels: - start, end, speaker = label.strip().split() - start, end = float(start), float(end) - annotation[Segment(start, end)] = speaker +def labels_to_supervisions(labels, uniq_name=''): + """Convert ``"start end speaker"`` label strings to a diarization annotation. + + Returns a ``list`` of :class:`lhotse.SupervisionSegment`, which is the + annotation type used throughout NeMo's DER pipeline. The returned object + can be passed to ``score_labels`` / ``score_labels_from_rttm_labels`` + and other DER helpers. - return annotation + Args: + labels: Iterable of label strings, each formatted as + ``"start end speaker"``. + uniq_name: Recording / file identifier (used as the recording id of + each emitted supervision). + + Returns: + List of :class:`lhotse.SupervisionSegment` objects, one per label. + """ + return make_diar_annotation(labels, uniq_name=uniq_name) def labels_to_rttmfile(labels, uniq_id, out_rttm_dir): @@ -457,8 +465,10 @@ def perform_clustering( Enable TQDM progress bar. Returns: - all_reference (list[uniq_name,Annotation]): reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation + all_reference (list[uniq_name, list[SupervisionSegment]]): + reference annotations for score calculation. + all_hypothesis (list[uniq_name, list[SupervisionSegment]]): + hypothesis annotations for score calculation. 
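+
+        Example of the returned structure (illustrative; the session id is made up):
+
+            all_hypothesis = [["session_1", [SupervisionSegment(...), ...]], ...]
+            all_reference = [["session_1", [SupervisionSegment(...), ...]], ...]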
""" all_hypothesis = [] @@ -518,13 +528,13 @@ def perform_clustering( if out_rttm_dir: labels_to_rttmfile(labels, uniq_id, out_rttm_dir) lines_cluster_labels.extend([f'{uniq_id} {seg_line}\n' for seg_line in lines]) - hypothesis = labels_to_pyannote_object(labels, uniq_name=uniq_id) + hypothesis = labels_to_supervisions(labels, uniq_name=uniq_id) all_hypothesis.append([uniq_id, hypothesis]) rttm_file = audio_rttm_values.get('rttm_filepath', None) if rttm_file is not None and os.path.exists(rttm_file) and not no_references: ref_labels = rttm_to_labels(rttm_file) - reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) + reference = labels_to_supervisions(ref_labels, uniq_name=uniq_id) all_reference.append([uniq_id, reference]) else: no_references = True @@ -1361,17 +1371,20 @@ def get_online_subsegments_from_buffer( return sigs_list, sig_rangel_list, sig_indexes -def timestamps_to_pyannote_object( +def timestamps_to_supervisions( speaker_timestamps: List[Tuple[float, float]], uniq_id: str, audio_rttm_values: Dict[str, str], - all_hypothesis: List[Tuple[str, Timeline]], - all_reference: List[Tuple[str, Timeline]], - all_uems: List[Tuple[str, Timeline]], + all_hypothesis: List[Tuple[str, List[SupervisionSegment]]], + all_reference: List[Tuple[str, List[SupervisionSegment]]], + all_uems: List[Tuple[str, List[SupervisionSegment]]], out_rttm_dir: str | None, ): - """ - Convert speaker timestamps to pyannote.core.Timeline object. + """Convert speaker timestamps into the diarization annotation lists used by DER. + + Hypothesis / reference / UEM are represented as lists of + :class:`lhotse.SupervisionSegment`; every consumer in NeMo's DER pipeline + accepts this representation. Args: speaker_timestamps (List[Tuple[float, float]]): @@ -1380,51 +1393,50 @@ def timestamps_to_pyannote_object( Unique ID of each speaker. audio_rttm_values (Dict[str, str]): Dictionary of manifest values. - all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): - List of hypothesis in pyannote.core.Timeline object. - all_reference (List[Tuple[str, pyannote.core.Timeline]]): - List of reference in pyannote.core.Timeline object. - all_uems (List[Tuple[str, pyannote.core.Timeline]]): - List of uems in pyannote.core.Timeline object. + all_hypothesis (List[Tuple[str, List[SupervisionSegment]]]): + Accumulator list of hypothesis annotations. + all_reference (List[Tuple[str, List[SupervisionSegment]]]): + Accumulator list of reference annotations. + all_uems (List[Tuple[str, List[SupervisionSegment]]]): + Accumulator list of UEM timelines. out_rttm_dir (str | None): Directory to save RTTMs Returns: - all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): - List of hypothesis in pyannote.core.Timeline object with an added Timeline object. - all_reference (List[Tuple[str, pyannote.core.Timeline]]): - List of reference in pyannote.core.Timeline object with an added Timeline object. - all_uems (List[Tuple[str, pyannote.core.Timeline]]): - List of uems in pyannote.core.Timeline object with an added Timeline object. + Updated ``(all_hypothesis, all_reference, all_uems)`` tuple, each + entry of the form ``(uniq_id, list_of_SupervisionSegment)``. 
""" offset, dur = float(audio_rttm_values.get('offset', None)), float(audio_rttm_values.get('duration', None)) hyp_labels = generate_diarization_output_lines( speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps) ) - hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=uniq_id) + hypothesis = labels_to_supervisions(hyp_labels, uniq_name=uniq_id) if out_rttm_dir is not None and os.path.exists(out_rttm_dir): with open(f'{out_rttm_dir}/{uniq_id}.rttm', 'w') as f: - hypothesis.write_rttm(f) + write_supervisions_to_rttm(hypothesis, f, recording_id=uniq_id) all_hypothesis.append([uniq_id, hypothesis]) rttm_file = audio_rttm_values.get('rttm_filepath', None) if rttm_file is not None and os.path.exists(rttm_file): uem_lines = [[offset, dur + offset]] org_ref_labels = rttm_to_labels(rttm_file) ref_labels = org_ref_labels - reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) + reference = labels_to_supervisions(ref_labels, uniq_name=uniq_id) uem_obj = get_uem_object(uem_lines, uniq_id=uniq_id) all_uems.append(uem_obj) all_reference.append([uniq_id, reference]) return all_hypothesis, all_reference, all_uems -def get_uem_object(uem_lines: List[List[float]], uniq_id: str): - """ - Generate pyannote timeline segments for uem file. +def get_uem_object(uem_lines: List[List[float]], uniq_id: str) -> List[SupervisionSegment]: + """Generate the UEM (evaluation regions) timeline for a session. file format UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME + Returns a ``list`` of :class:`lhotse.SupervisionSegment` (each with + ``speaker="UEM"``), which is the UEM representation used throughout + NeMo's DER pipeline. + Args: uem_lines (list): list of session ID and start, end times. Example: @@ -1432,13 +1444,9 @@ def get_uem_object(uem_lines: List[List[float]], uniq_id: str): uniq_id (str): Unique session ID. Returns: - timeline (pyannote.core.Timeline): pyannote timeline object. + List of :class:`lhotse.SupervisionSegment` representing the UEM. 
""" - timeline = Timeline(uri=uniq_id) - for uem_stt_end in uem_lines: - start_time, end_time = uem_stt_end - timeline.add(Segment(float(start_time), float(end_time))) - return timeline + return make_uem_timeline(uem_lines, uniq_id=uniq_id) def embedding_normalize(embs, use_std=False, eps=1e-10): diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index 0848a4f9c3f1..b2725d3f622e 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -31,12 +31,21 @@ import pandas as pd import torch import yaml +from lhotse import SupervisionSegment from omegaconf import DictConfig, OmegaConf -from pyannote.core import Annotation, Segment -from pyannote.metrics import detection from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm + +from nemo.collections.asr.metrics.der import make_diar_segment +from nemo.collections.asr.metrics.md_eval import ( + EPSILON, + _annotation_to_rttm_data, + _merge_rttm_dicts, + _merge_uem_dicts, + _uem_list_to_uem_data, + evaluate, +) from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging @@ -79,7 +88,7 @@ def load_postprocessing_from_yaml(postprocessing_yaml: str = None) -> PostProces postprocessing_params = OmegaConf.structured(PostProcessingParams()) if postprocessing_yaml is None: logging.info( - f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied." + "No postprocessing YAML file has been provided. Default postprocessing configurations will be applied." ) else: # Load postprocessing params from the provided YAML file @@ -332,14 +341,15 @@ def generate_overlap_vad_seq( if num_workers is not None and num_workers > 1: with multiprocessing.Pool(processes=num_workers) as p: inputs = zip(frame_filepathlist, repeat(per_args)) - results = list( - tqdm( - p.imap(generate_overlap_vad_seq_per_file_star, inputs), - total=len(frame_filepathlist), - desc='generating preds', - leave=True, - ) - ) + # Force-consume ``imap`` so the worker pool actually runs each task; + # the per-file results are written to disk by the worker, not returned. + for _ in tqdm( + p.imap(generate_overlap_vad_seq_per_file_star, inputs), + total=len(frame_filepathlist), + desc='generating preds', + leave=True, + ): + pass else: for frame_filepath in tqdm(frame_filepathlist, desc='generating preds', leave=False): @@ -416,7 +426,7 @@ def generate_overlap_vad_seq_per_tensor( if j <= target_len - 1: preds[j] = torch.cat((preds[j], og_pred.unsqueeze(0)), 0) - preds = torch.stack([torch.nanquantile(l, q=0.5) for l in preds]) + preds = torch.stack([torch.nanquantile(per_frame_preds, q=0.5) for per_frame_preds in preds]) nan_idx = torch.isnan(preds) last_non_nan_pred = preds[~nan_idx][-1] preds[nan_idx] = last_non_nan_pred @@ -524,7 +534,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Reference Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. - Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py + Implementation: see the equivalent reference implementation in the External + Annotation Library's audio toolkit (``utils/signal.py``). 
Args: sequence (torch.Tensor) : A tensor of frame level predictions. @@ -616,8 +627,8 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc Reference: Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015. - Implementation: - https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py + Implementation: see the equivalent reference implementation in the + External Annotation Library's audio toolkit (``utils/signal.py``). Args: speech_segments (torch.Tensor): @@ -742,9 +753,9 @@ def generate_vad_segment_table_per_file(pred_filepath: str, per_args: dict) -> s if preds.shape[0] == 0: with open(save_path, "w", encoding='utf-8') as fp: if per_args.get("use_rttm", False): - fp.write(f"SPEAKER 1 0 0 speech \n") + fp.write("SPEAKER 1 0 0 speech \n") else: - fp.write(f"0 0 speech\n") + fp.write("0 0 speech\n") else: with open(save_path, "w", encoding='utf-8') as fp: for i in preds: @@ -825,32 +836,38 @@ def generate_vad_segment_table_per_file_star(args): return generate_vad_segment_table_per_file(*args) -def vad_construct_pyannote_object_per_file( +def vad_construct_supervisions_per_file( vad_table_filepath: str, groundtruth_RTTM_file: str -) -> Tuple[Annotation, Annotation]: - """ - Construct a Pyannote object for evaluation. +) -> Tuple[List[SupervisionSegment], List[SupervisionSegment]]: + """Construct annotation objects for VAD evaluation. + + Returns lists of :class:`lhotse.SupervisionSegment` that are accepted by + every NeMo DER helper. + Args: vad_table_filepath(str) : path of vad rttm-like table. groundtruth_RTTM_file(str): path of groundtruth rttm file. Returns: - reference(pyannote.Annotation): groundtruth - hypothesis(pyannote.Annotation): prediction + reference(List[SupervisionSegment]): groundtruth + hypothesis(List[SupervisionSegment]): prediction """ pred = pd.read_csv(vad_table_filepath, sep=" ", header=None) label = pd.read_csv(groundtruth_RTTM_file, sep=" ", delimiter=None, header=None) label = label.rename(columns={3: "start", 4: "dur", 7: "speaker"}) - # construct reference - reference = Annotation() - for index, row in label.iterrows(): - reference[Segment(row['start'], row['start'] + row['dur'])] = row['speaker'] - - # construct hypothsis - hypothesis = Annotation() - for index, row in pred.iterrows(): - hypothesis[Segment(float(row[0]), float(row[0]) + float(row[1]))] = 'Speech' + rec_id = os.path.splitext(os.path.basename(groundtruth_RTTM_file))[0] + reference: List[SupervisionSegment] = [] + for _index, row in label.iterrows(): + start = float(row['start']) + end = start + float(row['dur']) + reference.append(make_diar_segment(start, end, str(row['speaker']), recording_id=rec_id)) + + hypothesis: List[SupervisionSegment] = [] + for _index, row in pred.iterrows(): + start = float(row[0]) + end = start + float(row[1]) + hypothesis.append(make_diar_segment(start, end, "Speech", recording_id=rec_id)) return reference, hypothesis @@ -872,6 +889,143 @@ def get_parameter_grid(params: dict) -> list: return params_grid +class _DetectionErrorRateAccumulator: + """md-eval-backed replacement for the external library's DetectionErrorRate. 
+ + Detection Error Rate (DetER) is the fraction of reference-speech time that + is either missed or falsely detected: + + DetER = (false_alarm + missed) / scored + + Equivalent to Diarization Error Rate (DER) when the *speaker confusion* + component is dropped (which is appropriate for VAD, where every speech + frame is collapsed into a single "speech" speaker). + + This class mimics the call surface used in :func:`vad_tune_threshold_on_dev` + and :func:`frame_vad_eval_detection_error`: + + * ``metric(reference, hypothesis)`` accumulates one file pair. + * ``metric.report(display=False)`` returns a pandas ``DataFrame`` whose + last row carries the cumulative percentages (``("detection error rate", "%")``, + ``("false alarm", "%")``, ``("miss", "%")``). + * ``metric.reset()`` clears the accumulator. + """ + + _CUM_INDEX = "TOTAL" + + def __init__(self) -> None: + self._files: List[Tuple[str, Dict[str, float]]] = [] + self._counter = 0 + + def __call__( + self, + reference, + hypothesis, + uem=None, + file_id: Optional[str] = None, + ) -> None: + """Accumulate the (FA, MISS) for a single ``(reference, hypothesis)`` pair. + + Args: + reference: Annotation-like object accepted by md_eval (e.g. list of + :class:`lhotse.SupervisionSegment`). + hypothesis: Annotation-like object accepted by md_eval. + uem: Optional UEM as a list of ``[start, end]`` pairs or + annotation-like object iterable of ``SupervisionSegment``. + file_id: Optional unique file id; auto-generated when omitted. + """ + self._counter += 1 + fid = file_id or f"file_{self._counter:06d}" + + ref_data = _merge_rttm_dicts([_annotation_to_rttm_data(fid, reference)]) + sys_data = _merge_rttm_dicts([_annotation_to_rttm_data(fid, hypothesis)]) + + uem_data = None + if uem is not None: + if hasattr(uem, "__iter__") and not isinstance(uem, (list, tuple)): + uem = list(uem) + if uem and hasattr(uem[0], "start"): + uem_pairs = [[float(s.start), float(s.end)] for s in uem] + else: + uem_pairs = [[float(s), float(e)] for s, e in uem] + if uem_pairs: + uem_data = _merge_uem_dicts([_uem_list_to_uem_data(fid, uem_pairs)]) + + _, cum = evaluate( + ref_data, + sys_data, + uem_data=uem_data, + collar=0.0, + opt_1=False, + verbose=False, + ) + scored = cum.get("SCORED_SPEAKER", 0.0) + falarm = cum.get("FALARM_SPEAKER", 0.0) + missed = cum.get("MISSED_SPEAKER", 0.0) + self._files.append((fid, {"scored": scored, "false_alarm": falarm, "missed": missed})) + + def reset(self) -> None: + """Clear the internal accumulator.""" + self._files.clear() + self._counter = 0 + + def report(self, display: bool = False): + """Return a pandas DataFrame report of per-file and cumulative DetER. + + Mirrors the column layout consumed by the existing call sites: + ``report.iloc[[-1]][('detection error rate', '%')]``, + ``report.iloc[[-1]][('false alarm', '%')]``, and + ``report.iloc[[-1]][('miss', '%')]``. The last row is labelled + ``"TOTAL"`` and holds the aggregate metrics. + + Args: + display: Kept for API compatibility; ignored (no printing). + + Returns: + ``pandas.DataFrame`` with a 2-level column header + ``(metric, unit)`` and one row per accumulated file plus a final + ``TOTAL`` row. 
+ """ + del display # interface compatibility + rows = [] + index = [] + total_scored = 0.0 + total_falarm = 0.0 + total_missed = 0.0 + for fid, stats in self._files: + scored = stats["scored"] + fa = stats["false_alarm"] + miss = stats["missed"] + deter = 100.0 * (fa + miss) / scored if scored > EPSILON else 0.0 + fa_pct = 100.0 * fa / scored if scored > EPSILON else 0.0 + miss_pct = 100.0 * miss / scored if scored > EPSILON else 0.0 + rows.append( + { + ("detection error rate", "%"): deter, + ("false alarm", "%"): fa_pct, + ("miss", "%"): miss_pct, + } + ) + index.append(fid) + total_scored += scored + total_falarm += fa + total_missed += miss + + deter_total = 100.0 * (total_falarm + total_missed) / total_scored if total_scored > EPSILON else 0.0 + fa_total = 100.0 * total_falarm / total_scored if total_scored > EPSILON else 0.0 + miss_total = 100.0 * total_missed / total_scored if total_scored > EPSILON else 0.0 + rows.append( + { + ("detection error rate", "%"): deter_total, + ("false alarm", "%"): fa_total, + ("miss", "%"): miss_total, + } + ) + index.append(self._CUM_INDEX) + + return pd.DataFrame(rows, index=index) + + def vad_tune_threshold_on_dev( params: dict, vad_pred: str, @@ -905,7 +1059,7 @@ def vad_tune_threshold_on_dev( raise ValueError("Please check if the parameters are valid") paired_filenames, groundtruth_RTTM_dict, vad_pred_dict = pred_rttm_map(vad_pred, groundtruth_RTTM, vad_pred_method) - metric = detection.DetectionErrorRate() + metric = _DetectionErrorRateAccumulator() params_grid = get_parameter_grid(params) for param in params_grid: @@ -922,9 +1076,7 @@ def vad_tune_threshold_on_dev( for filename in paired_filenames: groundtruth_RTTM_file = groundtruth_RTTM_dict[filename] vad_table_filepath = os.path.join(vad_table_dir, filename + ".txt") - reference, hypothesis = vad_construct_pyannote_object_per_file( - vad_table_filepath, groundtruth_RTTM_file - ) + reference, hypothesis = vad_construct_supervisions_per_file(vad_table_filepath, groundtruth_RTTM_file) metric(reference, hypothesis) # accumulation # delete tmp table files @@ -1600,23 +1752,27 @@ def align_labels_to_frames(probs, labels, threshold=0.2): return labels.long().tolist() -def read_rttm_as_pyannote_object(rttm_file: str, speaker_override: Optional[str] = None) -> Annotation: - """ - Read rttm file and construct a Pyannote object. +def read_rttm_as_supervisions(rttm_file: str, speaker_override: Optional[str] = None) -> List[SupervisionSegment]: + """Read an RTTM file and return it as a list of supervision segments. + + Returns a ``list`` of :class:`lhotse.SupervisionSegment`, which is the + annotation type used throughout NeMo's DER pipeline. + Args: rttm_file(str) : path of rttm file. speaker_override(str) : if not None, all speakers will be replaced by this value. 
Returns: - annotation(pyannote.Annotation): annotation object + annotation(List[SupervisionSegment]): annotation object """ - annotation = Annotation() + rec_id = os.path.splitext(os.path.basename(rttm_file))[0] + annotation: List[SupervisionSegment] = [] data = pd.read_csv(rttm_file, sep=r"\s+", delimiter=None, header=None) data = data.rename(columns={3: "start", 4: "dur", 7: "speaker"}) - for index, row in data.iterrows(): - if speaker_override is not None: - annotation[Segment(row['start'], row['start'] + row['dur'])] = speaker_override - else: - annotation[Segment(row['start'], row['start'] + row['dur'])] = row['speaker'] + for _index, row in data.iterrows(): + start = float(row['start']) + end = start + float(row['dur']) + speaker = speaker_override if speaker_override is not None else str(row['speaker']) + annotation.append(make_diar_segment(start, end, speaker, recording_id=rec_id)) return annotation @@ -1644,41 +1800,46 @@ def convert_labels_to_speech_segments(labels: List[float], frame_length_in_sec: return segments -def frame_vad_construct_pyannote_object_per_file( +def frame_vad_construct_supervisions_per_file( prediction: Union[str, List[float]], groundtruth: Union[str, List[float]], frame_length_in_sec: float = 0.01 -) -> Tuple[Annotation, Annotation]: - """ - Construct a Pyannote object for evaluation. +) -> Tuple[List[SupervisionSegment], List[SupervisionSegment]]: + """Construct annotation objects for frame-level VAD evaluation. + + Returns lists of :class:`lhotse.SupervisionSegment`. + Args: prediction (str) : path of VAD predictions stored as RTTM or CSV-like txt. groundtruth (str): path of groundtruth rttm file. frame_length_in_sec(float): frame length in seconds Returns: - reference(pyannote.Annotation): groundtruth - hypothesis(pyannote.Annotation): prediction + reference(List[SupervisionSegment]): groundtruth + hypothesis(List[SupervisionSegment]): prediction """ - hypothesis = Annotation() + rec_id = "frame_vad" + hypothesis: List[SupervisionSegment] = [] if isinstance(groundtruth, str) and prediction.endswith('.rttm'): - hypothesis = read_rttm_as_pyannote_object(prediction, speaker_override='speech') + hypothesis = read_rttm_as_supervisions(prediction, speaker_override='speech') elif isinstance(groundtruth, str) and prediction.endswith('.txt'): pred = pd.read_csv(prediction, sep=" ", header=None) - for index, row in pred.iterrows(): - hypothesis[Segment(float(row[0]), float(row[0]) + float(row[1]))] = 'speech' + for _index, row in pred.iterrows(): + start = float(row[0]) + end = start + float(row[1]) + hypothesis.append(make_diar_segment(start, end, 'speech', recording_id=rec_id)) elif isinstance(groundtruth, list): segments = convert_labels_to_speech_segments(prediction, frame_length_in_sec) for segment in segments: - hypothesis[Segment(segment[0], segment[1])] = 'speech' + hypothesis.append(make_diar_segment(float(segment[0]), float(segment[1]), 'speech', recording_id=rec_id)) else: raise ValueError('prediction must be a path to rttm file or a list of frame labels.') - reference = Annotation() + reference: List[SupervisionSegment] = [] if isinstance(groundtruth, str) and groundtruth.endswith('.rttm'): - reference = read_rttm_as_pyannote_object(groundtruth, speaker_override='speech') + reference = read_rttm_as_supervisions(groundtruth, speaker_override='speech') elif isinstance(groundtruth, list): segments = convert_labels_to_speech_segments(groundtruth, frame_length_in_sec) for segment in segments: - reference[Segment(segment[0], segment[1])] = 'speech' + 
reference.append(make_diar_segment(float(segment[0]), float(segment[1]), 'speech', recording_id=rec_id)) else: raise ValueError('groundtruth must be a path to rttm file or a list of frame labels.') return reference, hypothesis @@ -1749,11 +1910,13 @@ def frame_vad_eval_detection_error( frame_length_in_sec: frame length in seconds, e.g. 0.02s Returns: auroc: AUROC score in 0~100% - report: Pyannote detection.DetectionErrorRate() report + report: pandas DataFrame with per-file and cumulative detection error + rate, false alarm, and miss percentages (matches the historical + external-engine report layout). """ all_probs = [] all_labels = [] - metric = detection.DetectionErrorRate() + metric = _DetectionErrorRateAccumulator() key_probs_map = {} predictions_list = list(Path(pred_dir).glob("*.frame")) for frame_pred in tqdm(predictions_list, desc="Evaluating VAD results", total=len(predictions_list)): @@ -1775,7 +1938,7 @@ def frame_vad_eval_detection_error( else: groundtruth = key_labels_map[key] - reference, hypothesis = frame_vad_construct_pyannote_object_per_file( + reference, hypothesis = frame_vad_construct_supervisions_per_file( prediction=key_pred_rttm_map[key], groundtruth=groundtruth, frame_length_in_sec=frame_length_in_sec, diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index bc5869a9e5d1..82dab597c4ea 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -11,8 +11,6 @@ librosa>=0.10.1 marshmallow optuna packaging -pyannote.core -pyannote.metrics pydub pyloudnorm resampy diff --git a/scripts/speaker_tasks/eval_diar_with_asr.py b/scripts/speaker_tasks/eval_diar_with_asr.py index 9fc651e953cd..69e06477797b 100644 --- a/scripts/speaker_tasks/eval_diar_with_asr.py +++ b/scripts/speaker_tasks/eval_diar_with_asr.py @@ -22,7 +22,7 @@ from nemo.collections.asr.parts.utils.manifest_utils import read_file from nemo.collections.asr.parts.utils.speaker_utils import ( get_uniqname_from_filepath, - labels_to_pyannote_object, + labels_to_supervisions, rttm_to_labels, ) @@ -79,23 +79,24 @@ """ -def get_pyannote_objs_from_rttms(rttm_file_path_list): - """Generate PyAnnote objects from RTTM file list +def get_supervisions_from_rttms(rttm_file_path_list): + """Generate diarization annotation objects from a list of RTTM files. + + Each entry in the returned list is ``[uniq_id, list[SupervisionSegment]]``. 
""" - pyannote_obj_list = [] + annotation_obj_list = [] for rttm_file in rttm_file_path_list: rttm_file = rttm_file.strip() if rttm_file is not None and os.path.exists(rttm_file): uniq_id = get_uniqname_from_filepath(rttm_file) ref_labels = rttm_to_labels(rttm_file) - reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) - pyannote_obj_list.append([uniq_id, reference]) - return pyannote_obj_list + reference = labels_to_supervisions(ref_labels, uniq_name=uniq_id) + annotation_obj_list.append([uniq_id, reference]) + return annotation_obj_list def make_meta_dict(hyp_rttm_list, ref_rttm_list): - """Create a temporary `audio_rttm_map_dict` for evaluation - """ + """Create a temporary `audio_rttm_map_dict` for evaluation""" meta_dict = {} for k, rttm_file in enumerate(ref_rttm_list): uniq_id = get_uniqname_from_filepath(rttm_file) @@ -107,8 +108,7 @@ def make_meta_dict(hyp_rttm_list, ref_rttm_list): def make_trans_info_dict(hyp_json_list_path): - """Create `trans_info_dict` from the `.json` files - """ + """Create `trans_info_dict` from the `.json` files""" trans_info_dict = {} for json_file in hyp_json_list_path: json_file = json_file.strip() @@ -120,8 +120,7 @@ def make_trans_info_dict(hyp_json_list_path): def read_file_path(list_path): - """Read file path and strip to remove line change symbol - """ + """Read file path and strip to remove line change symbol""" return sorted([x.strip() for x in read_file(list_path)]) @@ -146,8 +145,8 @@ def main( trans_info_dict = make_trans_info_dict(hyp_json_list) if hyp_json_list else None - all_hypothesis = get_pyannote_objs_from_rttms(hyp_rttm_list) - all_reference = get_pyannote_objs_from_rttms(ref_rttm_list) + all_hypothesis = get_supervisions_from_rttms(hyp_rttm_list) + all_reference = get_supervisions_from_rttms(ref_rttm_list) diar_score = evaluate_der( audio_rttm_map_dict=audio_rttm_map_dict, diff --git a/tests/collections/asr/utils/test_vad_utils_asr.py b/tests/collections/asr/utils/test_vad_utils_asr.py index a7672e1aa43d..46afc4d3b013 100644 --- a/tests/collections/asr/utils/test_vad_utils_asr.py +++ b/tests/collections/asr/utils/test_vad_utils_asr.py @@ -14,17 +14,17 @@ import numpy as np import pytest -from pyannote.core import Annotation, Segment +from lhotse import SupervisionSegment from nemo.collections.asr.parts.utils.vad_utils import ( align_labels_to_frames, convert_labels_to_speech_segments, - frame_vad_construct_pyannote_object_per_file, + frame_vad_construct_supervisions_per_file, get_frame_labels, get_nonspeech_segments, load_speech_overlap_segments_from_rttm, load_speech_segments_from_rttm, - read_rttm_as_pyannote_object, + read_rttm_as_supervisions, ) @@ -101,26 +101,41 @@ def test_convert_labels_to_speech_segments(self, test_data_dir): assert speech_segments_new == speech_segments @pytest.mark.unit - def test_read_rttm_as_pyannote_object(self, test_data_dir): + def test_read_rttm_as_supervisions(self, test_data_dir): rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test6.rttm") - pyannote_object = read_rttm_as_pyannote_object(rttm_file) - pyannote_object_gt = Annotation() - pyannote_object_gt[Segment(0.0, 2.0)] = 'speech' - assert pyannote_object == pyannote_object_gt + annotation = read_rttm_as_supervisions(rttm_file) + assert _annotation_equals(annotation, [(0.0, 2.0, 'speech')]) @pytest.mark.unit - def test_frame_vad_construct_pyannote_object_per_file(self, test_data_dir): + def test_frame_vad_construct_supervisions_per_file(self, test_data_dir): rttm_file, speech_segments = 
get_simple_rttm_without_overlap(test_data_dir + "/test7.rttm") # test for rttm input - ref, hyp = frame_vad_construct_pyannote_object_per_file(rttm_file, rttm_file) - pyannote_object_gt = Annotation() - pyannote_object_gt[Segment(0.0, 2.0)] = 'speech' - assert ref == hyp == pyannote_object_gt + ref, hyp = frame_vad_construct_supervisions_per_file(rttm_file, rttm_file) + expected = [(0.0, 2.0, 'speech')] + assert _annotation_equals(ref, expected) + assert _annotation_equals(hyp, expected) # test for list input speech_segments = load_speech_segments_from_rttm(rttm_file) frame_labels = get_frame_labels(speech_segments, 0.02, 0.0, 3.0, as_str=False) speech_segments_new = convert_labels_to_speech_segments(frame_labels, 0.02) assert speech_segments_new == speech_segments - ref, hyp = frame_vad_construct_pyannote_object_per_file(frame_labels, frame_labels, 0.02) - assert ref == hyp == pyannote_object_gt + ref, hyp = frame_vad_construct_supervisions_per_file(frame_labels, frame_labels, 0.02) + assert _annotation_equals(ref, expected) + assert _annotation_equals(hyp, expected) + + +def _annotation_equals(annotation, expected_segments, *, atol=1e-6): + """Compare a list of :class:`lhotse.SupervisionSegment` to expected ``(start, end, speaker)`` tuples.""" + assert isinstance(annotation, list) + assert all(isinstance(s, SupervisionSegment) for s in annotation) + if len(annotation) != len(expected_segments): + return False + for seg, (exp_start, exp_end, exp_spk) in zip(annotation, expected_segments): + if abs(float(seg.start) - exp_start) > atol: + return False + if abs(float(seg.end) - exp_end) > atol: + return False + if seg.speaker != exp_spk: + return False + return True diff --git a/tests/collections/speaker_tasks/utils/test_der.py b/tests/collections/speaker_tasks/utils/test_der.py new file mode 100644 index 000000000000..2b59a63435cc --- /dev/null +++ b/tests/collections/speaker_tasks/utils/test_der.py @@ -0,0 +1,1635 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for DER calculation in nemo.collections.asr.metrics.der and +nemo.collections.asr.metrics.md_eval. + +All expected values are pre-verified against an external annotation +library (3.x, the historical NeMo dependency that exposed +``Annotation`` / ``Segment`` / ``Timeline`` and a reference +``DiarizationErrorRate``). The values are hardcoded here so that +**this file does not import any external annotation library**. + +md-eval (NIST md-eval-22.pl) and the external reference engine share the +same DER semantics (optimal speaker mapping via the Hungarian algorithm, +same collar / overlap conventions) and produce identical results when +the UEM is equivalent. The only behavioural difference captured by these +tests is that md-eval derives the evaluation region (UEM) from the +*reference* extent whereas the external engine uses the *union* of +reference and hypothesis extents. Tests for that case use an explicit +UEM to keep the two engines aligned. 
+ +The :class:`TestLhotseAnnotation` group additionally covers the +lhotse-based replacement for the external annotation library types +(``Annotation`` / ``Segment`` / ``Timeline``) introduced in +:mod:`nemo.collections.asr.metrics.der`. Annotations are built as lists +of :class:`lhotse.SupervisionSegment` and must produce bit-identical DER +to the legacy label-string path. +""" + +import io + +import pytest +from lhotse import SupervisionSegment, SupervisionSet + +from nemo.collections.asr.metrics.der import ( + make_diar_annotation, + make_diar_segment, + make_uem_timeline, + score_labels, + score_labels_from_rttm_labels, + unique_speakers, + write_supervisions_to_rttm, +) +from nemo.collections.asr.metrics.md_eval import ( + EPSILON, + DiarizationErrorResult, + _iter_annotation_segments, + _labels_to_rttm_data, + _merge_rttm_dicts, + _merge_uem_dicts, + _uem_list_to_uem_data, + evaluate, +) + +# ─── Helpers ────────────────────────────────────────────────────────────── + + +def _seg(start: float, end: float, spk: str) -> str: + """Create a ``"start end speaker"`` label string.""" + return f"{start} {end} {spk}" + + +def _labels(*segments): + """Convert ``(start, end, speaker)`` tuples to label strings.""" + return [_seg(s, e, k) for s, e, k in segments] + + +def _score( + ref_segs, + hyp_segs, + collar=0.0, + ignore_overlap=False, + uem_segs=None, + file_id="file1", +): + """Score a single file through the public ``score_labels_from_rttm_labels`` API.""" + ref_labels = _labels(*ref_segs) + hyp_labels = _labels(*hyp_segs) + ref_list = [(file_id, ref_labels)] + hyp_list = [(file_id, hyp_labels)] + uem_list = [(file_id, uem_segs)] if uem_segs else None + result = score_labels_from_rttm_labels( + ref_list, + hyp_list, + uem_segments_list=uem_list, + collar=collar, + ignore_overlap=ignore_overlap, + verbose=False, + ) + assert result is not None, "score_labels_from_rttm_labels returned None" + return result + + +def _score_raw( + ref_segs, + hyp_segs, + collar=0.0, + ignore_overlap=False, + uem_segs=None, + file_id="file1", +): + """Score a single file through the low-level ``evaluate`` API in md_eval.""" + ref_labels = _labels(*ref_segs) + hyp_labels = _labels(*hyp_segs) + ref_data = _merge_rttm_dicts([_labels_to_rttm_data(file_id, ref_labels)]) + sys_data = _merge_rttm_dicts([_labels_to_rttm_data(file_id, hyp_labels)]) + uem_data = None + if uem_segs: + uem_data = _merge_uem_dicts([_uem_list_to_uem_data(file_id, uem_segs)]) + _, cum = evaluate( + ref_data, + sys_data, + uem_data=uem_data, + collar=collar, + opt_1=ignore_overlap, + verbose=False, + ) + scored = cum.get("SCORED_SPEAKER", 0.0) or EPSILON + missed = cum.get("MISSED_SPEAKER", 0.0) + falarm = cum.get("FALARM_SPEAKER", 0.0) + error = cum.get("SPEAKER_ERROR", 0.0) + return { + "DER": (missed + falarm + error) / scored, + "CER": error / scored, + "FA": falarm / scored, + "MISS": missed / scored, + "scored": scored, + } + + +def assert_der(actual, expected, tol=1e-6): + diff = abs(actual - expected) + assert diff <= tol, f"DER mismatch: actual={actual:.8f}, expected={expected:.8f}" + + +def _score_lhotse( + ref_segs, + hyp_segs, + collar=0.0, + ignore_overlap=False, + uem_segs=None, + file_id="file1", +): + """Score a single file through ``score_labels`` using lhotse-based annotations. + + Mirrors :func:`_score` but builds the reference and hypothesis as lists of + ``lhotse.SupervisionSegment`` (via :func:`make_diar_annotation`) instead of + label strings, exercising the new lhotse-based pipeline end-to-end. 
+ """ + ref_labels = _labels(*ref_segs) + hyp_labels = _labels(*hyp_segs) + ref_ann = make_diar_annotation(ref_labels, uniq_name=file_id) + hyp_ann = make_diar_annotation(hyp_labels, uniq_name=file_id) + all_uem = [make_uem_timeline(uem_segs, uniq_id=file_id)] if uem_segs else None + audio_rttm_map = {file_id: {}} + result = score_labels( + audio_rttm_map, + [(file_id, ref_ann)], + [(file_id, hyp_ann)], + all_uem=all_uem, + collar=collar, + ignore_overlap=ignore_overlap, + verbose=False, + ) + assert result is not None, "score_labels returned None" + return result + + +# ─── Tests: md_eval low-level engine ────────────────────────────────────── + + +class TestMdEvalBasic: + """Verify the md_eval engine produces correct DER for basic scenarios. + + Expected values verified against the external annotation library's + reference ``DiarizationErrorRate`` implementation. + """ + + @pytest.mark.unit + def test_perfect_match(self): + """Two speakers, perfect hypothesis → DER = 0.""" + r = _score_raw([(0, 5, "A"), (5, 10, "B")], [(0, 5, "A"), (5, 10, "B")]) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 10.0) + + @pytest.mark.unit + def test_complete_miss(self): + """Empty hypothesis → everything is missed.""" + r = _score_raw([(0, 5, "A"), (5, 10, "B")], []) + assert_der(r["DER"], 1.0) + assert_der(r["MISS"], 1.0) + assert_der(r["CER"], 0.0) + assert_der(r["FA"], 0.0) + assert_der(r["scored"], 10.0) + + @pytest.mark.unit + def test_speaker_swap_optimal_mapping(self): + """Swapped speaker labels → optimal mapping gives DER = 0.""" + r = _score_raw([(0, 5, "A"), (5, 10, "B")], [(0, 5, "B"), (5, 10, "A")]) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 10.0) + + @pytest.mark.unit + def test_partial_miss(self): + """Hypothesis covers first half only → 50% miss.""" + r = _score_raw([(0, 10, "A")], [(0, 5, "A")]) + assert_der(r["DER"], 0.5) + assert_der(r["MISS"], 0.5) + assert_der(r["scored"], 10.0) + + @pytest.mark.unit + def test_false_alarm_extend_with_uem(self): + """Hypothesis extends beyond reference; explicit UEM covers full range. + + With UEM [0, 10]: ref covers [0, 5], hyp covers [0, 10]. + Scored = 5.0 (only ref speech), FA = 5.0 → DER = 1.0. + """ + r = _score_raw([(0, 5, "A")], [(0, 10, "A")], uem_segs=[[0, 10]]) + assert_der(r["DER"], 1.0) + assert_der(r["FA"], 1.0) + assert_der(r["scored"], 5.0) + + @pytest.mark.unit + def test_false_alarm_extend_no_uem(self): + """Without explicit UEM, md-eval derives UEM from reference extent only. + + Hypothesis beyond ref boundary is not scored → DER = 0. 
+ """ + r = _score_raw([(0, 5, "A")], [(0, 10, "A")]) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 5.0) + + @pytest.mark.unit + def test_single_speaker_confusion(self): + """Single ref speaker A, hyp speaker B → optimal mapping A↔B gives DER = 0.""" + r = _score_raw([(0, 10, "A")], [(0, 10, "B")]) + assert_der(r["DER"], 0.0) + + @pytest.mark.unit + def test_gap_perfect(self): + """Silence between speakers; perfect hypothesis.""" + r = _score_raw([(0, 3, "A"), (7, 10, "B")], [(0, 3, "A"), (7, 10, "B")]) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 6.0) + + @pytest.mark.unit + def test_false_alarm_in_gap(self): + """Spurious speaker in a silence gap → false alarm.""" + r = _score_raw( + [(0, 3, "A"), (7, 10, "B")], + [(0, 3, "A"), (4, 6, "X"), (7, 10, "B")], + ) + assert_der(r["DER"], 1 / 3) + assert_der(r["FA"], 1 / 3) + assert_der(r["scored"], 6.0) + + +class TestMdEvalCollar: + """Verify collar (no-score zone) handling.""" + + @pytest.mark.unit + def test_collar_perfect(self): + """Perfect hypothesis with collar → DER = 0, scored shrinks by collar.""" + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 10, "B")], + collar=0.25, + ) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 9.0) + + @pytest.mark.unit + def test_collar_absorbs_offset(self): + """Hypothesis boundary offset (0.2s) inside collar (0.25s) → DER = 0.""" + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5.2, "A"), (5.2, 10, "B")], + collar=0.25, + ) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 9.0) + + @pytest.mark.unit + def test_collar_boundary_error_within(self): + """Gap of 1.0s centred on boundary; collar of 0.25s covers 0.5s total. + + ref: A=[0,5], B=[5,10]; hyp: A=[0,4.5], B=[5.5,10]; collar=0.25. + No-score zone: [4.75, 5.25]. Miss from [4.5, 4.75] = 0.25s. + Miss from [5.25, 5.5] = 0.25s. Total miss = 0.5s. Scored = 9.0. + DER = 0.5/9.0 ≈ 0.0556. + """ + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 4.5, "A"), (5.5, 10, "B")], + collar=0.25, + ) + assert_der(r["DER"], 0.5 / 9.0) + assert_der(r["MISS"], 0.5 / 9.0) + + @pytest.mark.unit + def test_collar_boundary_error_exceeds(self): + """Larger gap at boundary exceeding collar. + + ref: A=[0,5], B=[5,10]; hyp: A=[0,4], B=[6,10]; collar=0.25. + No-score zone: [4.75, 5.25]. Miss outside collar: [4, 4.75]=0.75 + [5.25, 6]=0.75 = 1.5s. + Scored = 9.0. DER = 1.5/9.0 ≈ 0.1667. 
+ """ + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 4, "A"), (6, 10, "B")], + collar=0.25, + ) + assert_der(r["DER"], 1.5 / 9.0) + assert_der(r["MISS"], 1.5 / 9.0) + + @pytest.mark.unit + def test_collar_3spk_perfect(self): + """Three speakers with large collar (0.5s) → scored = 7.0.""" + r = _score_raw( + [(0, 4, "A"), (4, 7, "B"), (7, 10, "C")], + [(0, 4, "A"), (4, 7, "B"), (7, 10, "C")], + collar=0.5, + ) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 7.0) + + +class TestMdEvalOverlap: + """Verify overlap handling with skip_overlap / ignore_overlap.""" + + @pytest.mark.unit + def test_overlap_perfect_skip(self): + """Overlapping ref [5,7]: skip_overlap=True → scored = 8.""" + r = _score_raw( + [(0, 7, "A"), (5, 10, "B")], + [(0, 7, "A"), (5, 10, "B")], + ignore_overlap=True, + ) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 8.0) + + @pytest.mark.unit + def test_overlap_perfect_noskip(self): + """Overlapping ref [5,7]: skip_overlap=False → scored = 12 (each speaker scored).""" + r = _score_raw( + [(0, 7, "A"), (5, 10, "B")], + [(0, 7, "A"), (5, 10, "B")], + ignore_overlap=False, + ) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 12.0) + + @pytest.mark.unit + def test_overlap_miss_one_speaker_skip(self): + """Overlap region [5,7]: hyp only has A (missed B). + + skip_overlap=True → overlap excluded. Scored = 8. + In non-overlap region [7,10]: B is present in ref, A covers it → confusion = 3.0. + DER = 3/8 = 0.375. + """ + r = _score_raw( + [(0, 7, "A"), (5, 10, "B")], + [(0, 10, "A")], + ignore_overlap=True, + ) + assert_der(r["DER"], 0.375) + assert_der(r["CER"], 0.375) + assert_der(r["scored"], 8.0) + + @pytest.mark.unit + def test_overlap_miss_one_speaker_noskip(self): + """Overlap region [5,7]: hyp only has A (missed B). + + skip_overlap=False → overlap included. Scored = 12. + Missed B in [5,7] = 2. Confusion B↔A in [7,10] = 3. Total = 5. + DER = 5/12 ≈ 0.4167. + """ + r = _score_raw( + [(0, 7, "A"), (5, 10, "B")], + [(0, 10, "A")], + ignore_overlap=False, + ) + assert_der(r["DER"], 5 / 12) + assert_der(r["CER"], 3 / 12) + assert_der(r["MISS"], 2 / 12) + assert_der(r["scored"], 12.0) + + +class TestMdEvalSpeakerCount: + """Verify speaker count mismatch scenarios.""" + + @pytest.mark.unit + def test_three_speakers_boundary_shift(self): + """Boundary shift between B and C: confusion in [6,7]. + + ref: A=[0,3], B=[3,7], C=[7,10]; hyp: A=[0,3], B=[3,6], C=[6,10]. + C is mapped to B in [6,7] → confusion = 1.0. Scored = 10. DER = 0.1. + """ + r = _score_raw( + [(0, 3, "A"), (3, 7, "B"), (7, 10, "C")], + [(0, 3, "A"), (3, 6, "B"), (6, 10, "C")], + ) + assert_der(r["DER"], 0.1) + assert_der(r["CER"], 0.1) + assert_der(r["scored"], 10.0) + + @pytest.mark.unit + def test_extra_hyp_speaker(self): + """Hypothesis has extra speaker C; ref only has A, B. + + ref: A=[0,5], B=[5,10]; hyp: A=[0,5], B=[5,8], C=[8,10]. + C covers ref B region → confusion = 2.0. DER = 0.2. + """ + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 8, "B"), (8, 10, "C")], + ) + assert_der(r["DER"], 0.2) + assert_der(r["CER"], 0.2) + + @pytest.mark.unit + def test_missing_hyp_speaker(self): + """Hypothesis missing speaker C; ref has A, B, C. + + ref: A=[0,5], B=[5,8], C=[8,10]; hyp: A=[0,5], B=[5,10]. + B covers ref C region → confusion = 2.0. DER = 0.2. 
+ """ + r = _score_raw( + [(0, 5, "A"), (5, 8, "B"), (8, 10, "C")], + [(0, 5, "A"), (5, 10, "B")], + ) + assert_der(r["DER"], 0.2) + assert_der(r["CER"], 0.2) + + +class TestMdEvalUEM: + """Verify UEM (Un-partitioned Evaluation Map) handling.""" + + @pytest.mark.unit + def test_uem_restricts_evaluation(self): + """UEM restricts to [2, 8] out of [0, 10]. + + ref: A=[0,10]; hyp: A=[0,5], B=[5,10]. UEM=[2,8]. + Scored region: ref A in [2,8] = 6.0. + B covers [5,8] of ref A → confusion = 3.0. DER = 3/6 = 0.5. + """ + r = _score_raw( + [(0, 10, "A")], + [(0, 5, "A"), (5, 10, "B")], + uem_segs=[[2, 8]], + ) + assert_der(r["DER"], 0.5) + assert_der(r["CER"], 0.5) + assert_der(r["scored"], 6.0) + + +# ─── Tests: der.py public API (score_labels_from_rttm_labels) ──────────── + + +class TestScoreLabelsFromRttmLabels: + """Test the public ``score_labels_from_rttm_labels`` function in der.py. + + Verifies: return type, DiarizationErrorResult interface, and DER values. + """ + + @pytest.mark.unit + def test_perfect_match_returns_correct_types(self): + metric, mapping, (DER, CER, FA, MISS) = _score( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 10, "B")], + ) + assert isinstance(metric, DiarizationErrorResult) + assert isinstance(mapping, dict) + assert_der(DER, 0.0) + assert_der(CER, 0.0) + assert_der(FA, 0.0) + assert_der(MISS, 0.0) + + @pytest.mark.unit + def test_result_abs_interface(self): + """``abs(metric)`` returns overall DER.""" + metric, _, _ = _score( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 10, "B")], + ) + assert_der(abs(metric), 0.0) + + @pytest.mark.unit + def test_result_getitem_interface(self): + """``metric['total']`` etc. return correct values.""" + metric, _, _ = _score( + [(0, 10, "A")], + [(0, 5, "A")], + ) + assert_der(metric["total"], 10.0) + assert_der(metric["confusion"], 0.0) + assert_der(metric["false alarm"], 0.0) + assert_der(metric["missed detection"], 5.0) + assert_der(abs(metric), 0.5) + + @pytest.mark.unit + def test_result_optimal_mapping(self): + """Speaker mapping is accessible via ``metric.optimal_mapping()``.""" + metric, _, _ = _score( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "B"), (5, 10, "A")], + ) + file_mapping = metric.optimal_mapping("file1", None) + assert "A" in file_mapping + assert file_mapping["A"] == "B" + assert file_mapping["B"] == "A" + + @pytest.mark.unit + def test_result_report(self): + """``metric.report()`` returns a non-empty string.""" + metric, _, _ = _score( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 10, "B")], + ) + report = metric.report() + assert isinstance(report, str) + assert len(report) > 0 + assert "file1" in report + + @pytest.mark.unit + def test_results_list(self): + """``metric.results_`` contains per-file score dicts.""" + metric, _, _ = _score( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 10, "B")], + ) + assert len(metric.results_) == 1 + file_id, scores = metric.results_[0] + assert file_id == "file1" + assert_der(scores["total"], 10.0) + assert_der(scores["confusion"], 0.0) + + @pytest.mark.unit + def test_complete_miss(self): + _, _, (DER, _, _, MISS) = _score( + [(0, 5, "A"), (5, 10, "B")], + [], + ) + assert_der(DER, 1.0) + assert_der(MISS, 1.0) + + @pytest.mark.unit + def test_speaker_swap(self): + _, _, (DER, _, _, _) = _score( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "B"), (5, 10, "A")], + ) + assert_der(DER, 0.0) + + @pytest.mark.unit + def test_partial_miss(self): + _, _, (DER, _, _, MISS) = _score([(0, 10, "A")], [(0, 5, "A")]) + assert_der(DER, 0.5) + assert_der(MISS, 0.5) + + 
@pytest.mark.unit + def test_collar(self): + _, _, (DER, _, _, _) = _score( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 10, "B")], + collar=0.25, + ) + assert_der(DER, 0.0) + + @pytest.mark.unit + def test_collar_offset(self): + _, _, (DER, _, _, _) = _score( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5.2, "A"), (5.2, 10, "B")], + collar=0.25, + ) + assert_der(DER, 0.0) + + @pytest.mark.unit + def test_overlap_skip(self): + _, _, (DER, _, _, _) = _score( + [(0, 7, "A"), (5, 10, "B")], + [(0, 7, "A"), (5, 10, "B")], + ignore_overlap=True, + ) + assert_der(DER, 0.0) + + @pytest.mark.unit + def test_overlap_miss_skip(self): + _, _, (DER, CER, _, _) = _score( + [(0, 7, "A"), (5, 10, "B")], + [(0, 10, "A")], + ignore_overlap=True, + ) + assert_der(DER, 0.375) + assert_der(CER, 0.375) + + @pytest.mark.unit + def test_overlap_miss_noskip(self): + _, _, (DER, CER, _, MISS) = _score( + [(0, 7, "A"), (5, 10, "B")], + [(0, 10, "A")], + ignore_overlap=False, + ) + assert_der(DER, 5 / 12) + assert_der(CER, 3 / 12) + assert_der(MISS, 2 / 12) + + @pytest.mark.unit + def test_three_speakers(self): + _, _, (DER, _, _, _) = _score( + [(0, 3, "A"), (3, 7, "B"), (7, 10, "C")], + [(0, 3, "A"), (3, 6, "B"), (6, 10, "C")], + ) + assert_der(DER, 0.1) + + @pytest.mark.unit + def test_extra_hyp_speaker(self): + _, _, (DER, CER, _, _) = _score( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 8, "B"), (8, 10, "C")], + ) + assert_der(DER, 0.2) + assert_der(CER, 0.2) + + @pytest.mark.unit + def test_missing_hyp_speaker(self): + _, _, (DER, CER, _, _) = _score( + [(0, 5, "A"), (5, 8, "B"), (8, 10, "C")], + [(0, 5, "A"), (5, 10, "B")], + ) + assert_der(DER, 0.2) + assert_der(CER, 0.2) + + @pytest.mark.unit + def test_false_alarm_in_gap(self): + _, _, (DER, _, FA, _) = _score( + [(0, 3, "A"), (7, 10, "B")], + [(0, 3, "A"), (4, 6, "X"), (7, 10, "B")], + ) + assert_der(DER, 1 / 3) + assert_der(FA, 1 / 3) + + @pytest.mark.unit + def test_uem_restrict(self): + _, _, (DER, CER, _, _) = _score( + [(0, 10, "A")], + [(0, 5, "A"), (5, 10, "B")], + uem_segs=[[2, 8]], + ) + assert_der(DER, 0.5) + assert_der(CER, 0.5) + + @pytest.mark.unit + def test_length_mismatch_returns_none(self): + """Mismatched ref/hyp list lengths should return None.""" + result = score_labels_from_rttm_labels( + [("f1", _labels((0, 5, "A")))], + [("f1", _labels((0, 5, "A"))), ("f2", _labels((0, 5, "B")))], + verbose=False, + ) + assert result is None + + +# ─── Tests: Multi-file scoring ─────────────────────────────────────────── + + +class TestMultiFile: + """Verify multi-file cumulative scoring.""" + + @pytest.mark.unit + def test_two_files_one_perfect_one_confusion(self): + """File1: perfect. File2: all confusion (mapped away). + + Combined: scored=10, DER=0 (optimal mapping maps C→B). + """ + ref_list = [ + ("file1", _labels((0, 5, "A"))), + ("file2", _labels((0, 5, "B"))), + ] + hyp_list = [ + ("file1", _labels((0, 5, "A"))), + ("file2", _labels((0, 5, "C"))), + ] + result = score_labels_from_rttm_labels( + ref_list, + hyp_list, + collar=0.0, + ignore_overlap=False, + verbose=False, + ) + assert result is not None + metric, _, (DER, _, _, _) = result + assert_der(DER, 0.0) + assert_der(metric["total"], 10.0) + assert len(metric.results_) == 2 + + @pytest.mark.unit + def test_two_files_one_miss(self): + """File1: perfect 5s. File2: complete miss 5s. + + Combined: scored=10, missed=5, DER=0.5. 
+ """ + ref_list = [ + ("file1", _labels((0, 5, "A"))), + ("file2", _labels((0, 5, "B"))), + ] + hyp_list = [ + ("file1", _labels((0, 5, "A"))), + ("file2", []), + ] + result = score_labels_from_rttm_labels( + ref_list, + hyp_list, + collar=0.0, + ignore_overlap=False, + verbose=False, + ) + assert result is not None + _, _, (DER, _, _, MISS) = result + assert_der(DER, 0.5) + assert_der(MISS, 0.5) + + +# ─── Tests: External-engine-verified values (cross-validated) ──────────── + + +class TestExternalEngineVerifiedValues: + """Cross-validation against an external annotation library. + + All expected values in this class have been computed with the external + annotation library's reference ``DiarizationErrorRate`` (3.x), using + its ``collar=2*collar_value`` convention, and hardcoded here. + + This class does **not** import the external library; it only checks + that the md-eval engine reproduces the same numbers. + """ + + @pytest.mark.unit + def test_external_perfect(self): + """External engine: DER=0.0, total=10.0.""" + r = _score_raw([(0, 5, "A"), (5, 10, "B")], [(0, 5, "A"), (5, 10, "B")]) + assert_der(r["DER"], 0.0) + + @pytest.mark.unit + def test_external_complete_miss(self): + """External engine: DER=1.0, MISS=1.0, total=10.0.""" + r = _score_raw([(0, 5, "A"), (5, 10, "B")], []) + assert_der(r["DER"], 1.0) + assert_der(r["MISS"], 1.0) + + @pytest.mark.unit + def test_external_swap(self): + """External engine: DER=0.0 (optimal mapping), total=10.0.""" + r = _score_raw([(0, 5, "A"), (5, 10, "B")], [(0, 5, "B"), (5, 10, "A")]) + assert_der(r["DER"], 0.0) + + @pytest.mark.unit + def test_external_partial_miss(self): + """External engine: DER=0.5, MISS=0.5, total=10.0.""" + r = _score_raw([(0, 10, "A")], [(0, 5, "A")]) + assert_der(r["DER"], 0.5) + + @pytest.mark.unit + def test_external_collar_perfect(self): + """External engine: DER=0.0, total=9.0 (collar removes 1s).""" + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 10, "B")], + collar=0.25, + ) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 9.0) + + @pytest.mark.unit + def test_external_collar_offset(self): + """External engine: DER=0.0, total=9.0. 0.2s offset within 0.25s collar.""" + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5.2, "A"), (5.2, 10, "B")], + collar=0.25, + ) + assert_der(r["DER"], 0.0) + + @pytest.mark.unit + def test_external_overlap_skip_perfect(self): + """External engine: DER=0.0, total=8.0. Overlap [5,7] excluded.""" + r = _score_raw( + [(0, 7, "A"), (5, 10, "B")], + [(0, 7, "A"), (5, 10, "B")], + ignore_overlap=True, + ) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 8.0) + + @pytest.mark.unit + def test_external_overlap_noskip_perfect(self): + """External engine: DER=0.0, total=12.0. 
Overlap scored for both speakers.""" + r = _score_raw( + [(0, 7, "A"), (5, 10, "B")], + [(0, 7, "A"), (5, 10, "B")], + ignore_overlap=False, + ) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 12.0) + + @pytest.mark.unit + def test_external_overlap_miss_skip(self): + """External engine: DER=0.375, CER=0.375, total=8.0.""" + r = _score_raw( + [(0, 7, "A"), (5, 10, "B")], + [(0, 10, "A")], + ignore_overlap=True, + ) + assert_der(r["DER"], 0.375) + assert_der(r["CER"], 0.375) + + @pytest.mark.unit + def test_external_overlap_miss_noskip(self): + """External engine: DER=5/12≈0.4167, CER=3/12=0.25, MISS=2/12≈0.1667, total=12.0.""" + r = _score_raw( + [(0, 7, "A"), (5, 10, "B")], + [(0, 10, "A")], + ignore_overlap=False, + ) + assert_der(r["DER"], 5 / 12) + assert_der(r["CER"], 3 / 12) + assert_der(r["MISS"], 2 / 12) + assert_der(r["scored"], 12.0) + + @pytest.mark.unit + def test_external_3spk_boundary(self): + """External engine: DER=0.1, CER=0.1, total=10.0.""" + r = _score_raw( + [(0, 3, "A"), (3, 7, "B"), (7, 10, "C")], + [(0, 3, "A"), (3, 6, "B"), (6, 10, "C")], + ) + assert_der(r["DER"], 0.1) + assert_der(r["CER"], 0.1) + + @pytest.mark.unit + def test_external_extra_hyp(self): + """External engine: DER=0.2, CER=0.2, total=10.0.""" + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 8, "B"), (8, 10, "C")], + ) + assert_der(r["DER"], 0.2) + assert_der(r["CER"], 0.2) + + @pytest.mark.unit + def test_external_missing_hyp(self): + """External engine: DER=0.2, CER=0.2, total=10.0.""" + r = _score_raw( + [(0, 5, "A"), (5, 8, "B"), (8, 10, "C")], + [(0, 5, "A"), (5, 10, "B")], + ) + assert_der(r["DER"], 0.2) + assert_der(r["CER"], 0.2) + + @pytest.mark.unit + def test_external_gap(self): + """External engine: DER=0.0, total=6.0.""" + r = _score_raw([(0, 3, "A"), (7, 10, "B")], [(0, 3, "A"), (7, 10, "B")]) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 6.0) + + @pytest.mark.unit + def test_external_false_alarm_in_gap(self): + """External engine: DER=1/3, FA=1/3, total=6.0.""" + r = _score_raw( + [(0, 3, "A"), (7, 10, "B")], + [(0, 3, "A"), (4, 6, "X"), (7, 10, "B")], + ) + assert_der(r["DER"], 1 / 3) + assert_der(r["FA"], 1 / 3) + + @pytest.mark.unit + def test_external_uem(self): + """External engine: DER=0.5, CER=0.5, total=6.0.""" + r = _score_raw( + [(0, 10, "A")], + [(0, 5, "A"), (5, 10, "B")], + uem_segs=[[2, 8]], + ) + assert_der(r["DER"], 0.5) + assert_der(r["CER"], 0.5) + assert_der(r["scored"], 6.0) + + @pytest.mark.unit + def test_external_collar_3spk(self): + """External engine: DER=0.0, total=7.0.""" + r = _score_raw( + [(0, 4, "A"), (4, 7, "B"), (7, 10, "C")], + [(0, 4, "A"), (4, 7, "B"), (7, 10, "C")], + collar=0.5, + ) + assert_der(r["DER"], 0.0) + assert_der(r["scored"], 7.0) + + @pytest.mark.unit + def test_external_collar_boundary_error(self): + """External engine: DER=0.5/9≈0.0556, MISS=0.5/9.""" + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 4.5, "A"), (5.5, 10, "B")], + collar=0.25, + ) + assert_der(r["DER"], 0.5 / 9.0) + assert_der(r["MISS"], 0.5 / 9.0) + + @pytest.mark.unit + def test_external_collar_boundary_error_large(self): + """External engine: DER=1.5/9≈0.1667, MISS=1.5/9.""" + r = _score_raw( + [(0, 5, "A"), (5, 10, "B")], + [(0, 4, "A"), (6, 10, "B")], + collar=0.25, + ) + assert_der(r["DER"], 1.5 / 9.0) + assert_der(r["MISS"], 1.5 / 9.0) + + @pytest.mark.unit + def test_external_single_speaker_confusion(self): + """External engine: DER=0.0 (optimal mapping maps B→A).""" + r = _score_raw([(0, 10, "A")], [(0, 10, "B")]) + 
assert_der(r["DER"], 0.0) + + @pytest.mark.unit + def test_external_multi_file(self): + """External engine, multi-file: file1 perfect + file2 relabelled → DER=0. + + Both engines map C→B via Hungarian algorithm. + """ + ref_dicts = [ + _labels_to_rttm_data("file1", _labels((0, 5, "A"))), + _labels_to_rttm_data("file2", _labels((0, 5, "B"))), + ] + sys_dicts = [ + _labels_to_rttm_data("file1", _labels((0, 5, "A"))), + _labels_to_rttm_data("file2", _labels((0, 5, "C"))), + ] + ref_data = _merge_rttm_dicts(ref_dicts) + sys_data = _merge_rttm_dicts(sys_dicts) + _, cum = evaluate(ref_data, sys_data, uem_data=None, collar=0.0, opt_1=False, verbose=False) + scored = cum.get("SCORED_SPEAKER", 0.0) or EPSILON + DER = ( + cum.get("MISSED_SPEAKER", 0.0) + cum.get("FALARM_SPEAKER", 0.0) + cum.get("SPEAKER_ERROR", 0.0) + ) / scored + assert_der(DER, 0.0) + assert_der(scored, 10.0) + + +# ─── Tests: regression for no-UEM scoring (parity with external lib) ───── + + +class TestNoUemAutoUnion: + """Regression tests for the auto-derived UEM used when no UEM is provided. + + Historically NeMo's DER was computed via an external annotation library + that, in the no-UEM path, built its scoring map from the union of the + reference and system extents. NIST ``md-eval-22.pl`` (which our + :func:`md_eval.evaluate` faithfully ports) instead defaults to the + reference extent only. The high-level wrappers in :mod:`der` bridge the + two by auto-deriving a ``ref ∪ sys`` UEM whenever the caller does not + supply one. These tests pin down that behaviour with hardcoded values + independently verified by hand and previously by the external library. + """ + + # Sortformer Diar 4spk-v1 dihard3-dev tutorial sample. + _REF = [(0.299, 2.770, "A"), (3.164, 5.147, "B")] + _HYP_RAW = [(0.400, 2.880, "spk0"), (3.200, 5.190, "spk1")] + _HYP_PP = [(0.340, 2.800, "spk0"), (3.220, 5.190, "spk1")] + + @pytest.mark.unit + def test_raw_binarization_matches_external_lib(self): + """Raw binarization output: DER must match the external-lib value 0.065110. + + Hand calculation: + ref total = 2.471 + 1.983 = 4.454 + miss = 0.099 + 0.038 = 0.137 + false alarm = 0.110 + 0.043 = 0.153 + DER = (0.137 + 0.153) / 4.454 = 0.065110 + """ + r = _score(self._REF, self._HYP_RAW, collar=0.0, ignore_overlap=False) + DER, _CER, FA, MISS = r[2] + assert_der(DER, 0.065110, tol=1e-5) + assert_der(FA, 0.153 / 4.454, tol=1e-5) + assert_der(MISS, 0.137 / 4.454, tol=1e-5) + + @pytest.mark.unit + def test_post_processed_matches_external_lib(self): + """Post-processed VAD output: DER must match the external-lib value 0.038168. 
+ + Hand calculation: + ref total = 4.454 + miss = 0.041 + 0.056 = 0.097 + false alarm = 0.030 + 0.043 = 0.073 + DER = (0.097 + 0.073) / 4.454 = 0.038168 + """ + r = _score(self._REF, self._HYP_PP, collar=0.0, ignore_overlap=False) + DER, _CER, FA, MISS = r[2] + assert_der(DER, 0.038168, tol=1e-5) + assert_der(FA, 0.073 / 4.454, tol=1e-5) + assert_der(MISS, 0.097 / 4.454, tol=1e-5) + + @pytest.mark.unit + def test_score_labels_lhotse_path_matches_external_lib_raw(self): + """The lhotse-backed ``score_labels`` entry point must give the same answer.""" + r = _score_lhotse(self._REF, self._HYP_RAW, collar=0.0, ignore_overlap=False) + DER, _CER, _FA, _MISS = r[2] + assert_der(DER, 0.065110, tol=1e-5) + + @pytest.mark.unit + def test_score_labels_lhotse_path_matches_external_lib_post(self): + r = _score_lhotse(self._REF, self._HYP_PP, collar=0.0, ignore_overlap=False) + DER, _CER, _FA, _MISS = r[2] + assert_der(DER, 0.038168, tol=1e-5) + + @pytest.mark.unit + def test_low_level_evaluate_keeps_nist_semantics(self): + """The low-level ``evaluate`` API must keep the NIST ref-extent default. + + Power users that call ``md_eval.evaluate`` directly should still see + the strict NIST behaviour (eval map = ref extent only) when they pass + ``uem_data=None``. The auto-union behaviour is intentionally limited + to the high-level wrappers in :mod:`der`. + """ + r = _score_raw(self._REF, self._HYP_RAW, collar=0.0, ignore_overlap=False) + assert_der(r["DER"], 0.055456, tol=1e-5) + r = _score_raw(self._REF, self._HYP_PP, collar=0.0, ignore_overlap=False) + assert_der(r["DER"], 0.028514, tol=1e-5) + + @pytest.mark.unit + def test_explicit_uem_overrides_auto_union(self): + """An explicit UEM must always take precedence over the auto-derived one.""" + # Use a UEM that exactly equals the reference extent — should reproduce + # the strict NIST numbers even through the high-level wrapper. + r = _score( + self._REF, + self._HYP_RAW, + collar=0.0, + ignore_overlap=False, + uem_segs=[[0.299, 5.147]], + ) + DER, _CER, _FA, _MISS = r[2] + assert_der(DER, 0.055456, tol=1e-5) + + @pytest.mark.unit + def test_collar_is_nist_half_width_raw(self): + """``collar=X`` in NeMo means ±X seconds (NIST half-width). + + The historical NeMo public contract is: ``score_labels(collar=X)`` punches + a ``±X`` second no-score zone around every reference boundary (NIST + ``md-eval-22.pl`` semantics). External annotation libraries that define + ``collar`` as the *total* width of the no-score zone agree with NeMo + when called with ``2 * X``. + + For the tutorial sample, NeMo at ``collar=0.05`` (== ext.lib at + ``collar=0.10``) produces RAW DER = 0.026093. Pinning down the + historical value ensures we don't silently shift NeMo's published + numbers when refactoring the collar plumbing. + """ + r = _score(self._REF, self._HYP_RAW, collar=0.05, ignore_overlap=False) + DER, _CER, _FA, _MISS = r[2] + assert_der(DER, 0.026093, tol=1e-5) + + @pytest.mark.unit + def test_collar_is_nist_half_width_post(self): + """Post-processed counterpart of :meth:`test_collar_is_nist_half_width_raw`. + + NeMo at ``collar=0.05`` (== ext.lib at ``collar=0.10``) produces + POST DER = 0.001410. + """ + r = _score(self._REF, self._HYP_PP, collar=0.05, ignore_overlap=False) + DER, _CER, _FA, _MISS = r[2] + assert_der(DER, 0.001410, tol=1e-5) + + @pytest.mark.unit + def test_collar_2x_equivalence_to_external_lib(self): + """Cross-engine equivalence: NeMo ``collar=X`` ≡ external lib ``collar=2X``. 
+ + The external library reports RAW DER = 0.043638 / POST DER = 0.016077 + when called directly with ``collar=0.10``. NeMo at ``collar=0.05`` + must produce the same numbers — the historical doubling-then-halving + round trip, made explicit by passing ``collar`` straight through to + :func:`md_eval.evaluate` (which uses NIST half-width semantics + natively). Equivalently, NeMo at ``collar=0.025`` must match the + external lib at ``collar=0.05``. + """ + # NeMo collar=0.025 <==> ext.lib collar=0.05 (RAW=0.043638) + r = _score(self._REF, self._HYP_RAW, collar=0.025, ignore_overlap=False) + DER, _CER, _FA, _MISS = r[2] + assert_der(DER, 0.043638, tol=1e-5) + # NeMo collar=0.025 <==> ext.lib collar=0.05 (POST=0.016077) + r = _score(self._REF, self._HYP_PP, collar=0.025, ignore_overlap=False) + DER, _CER, _FA, _MISS = r[2] + assert_der(DER, 0.016077, tol=1e-5) + + @pytest.mark.unit + def test_collar_lhotse_path_matches_string_path(self): + """The lhotse-backed ``score_labels`` collar semantics must agree with ``score_labels_from_rttm_labels``.""" + for collar, expected_raw, expected_post in [ + (0.05, 0.026093, 0.001410), + (0.025, 0.043638, 0.016077), + ]: + r_raw = _score_lhotse(self._REF, self._HYP_RAW, collar=collar, ignore_overlap=False) + assert_der(r_raw[2][0], expected_raw, tol=1e-5) + r_post = _score_lhotse(self._REF, self._HYP_PP, collar=collar, ignore_overlap=False) + assert_der(r_post[2][0], expected_post, tol=1e-5) + + @pytest.mark.unit + def test_default_uem_helper_builds_union(self): + """The internal ``_default_uem_from_ref_sys`` builds the right span.""" + from nemo.collections.asr.metrics.der import _default_uem_from_ref_sys + + ref_data = _merge_rttm_dicts([_labels_to_rttm_data("file1", _labels(*self._REF))]) + sys_data = _merge_rttm_dicts([_labels_to_rttm_data("file1", _labels(*self._HYP_RAW))]) + uem = _default_uem_from_ref_sys(ref_data, sys_data) + assert "file1" in uem + # The ref ends at 5.147, sys ends at 5.190 — auto-union picks 5.190. + # The ref starts at 0.299, sys starts at 0.400 — auto-union picks 0.299. + seg = uem["file1"]["1"][0] + assert abs(seg["TBEG"] - 0.299) < 1e-9 + assert abs(seg["TEND"] - 5.190) < 1e-9 + + +# ─── Tests: lhotse-based replacement for the external annotation lib ───── + + +class TestLhotseShimHelpers: + """Unit tests for the lhotse-based shim helpers in der.py. + + These helpers (``make_diar_segment``, ``make_diar_annotation``, + ``make_uem_timeline``, ``unique_speakers``, ``write_supervisions_to_rttm``) + replace the ``Annotation`` / ``Segment`` / ``Timeline`` types from the + external annotation library that NeMo previously depended on. 
+ """ + + @pytest.mark.unit + def test_make_diar_segment_basic(self): + seg = make_diar_segment(1.5, 4.0, "spk0", recording_id="rec1") + assert isinstance(seg, SupervisionSegment) + assert seg.start == 1.5 + assert seg.duration == 2.5 + assert seg.end == 4.0 + assert seg.speaker == "spk0" + assert seg.recording_id == "rec1" + + @pytest.mark.unit + def test_make_diar_segment_zero_duration_clamped(self): + """Inverted/zero spans clamp to 0 duration (no negative durations).""" + seg = make_diar_segment(5.0, 5.0, "A") + assert seg.duration == 0.0 + seg2 = make_diar_segment(5.0, 4.0, "A") + assert seg2.duration == 0.0 + + @pytest.mark.unit + def test_make_diar_segment_auto_id(self): + """When ``segment_id`` is None, a deterministic id is generated.""" + s1 = make_diar_segment(0.0, 1.0, "A", recording_id="r") + s2 = make_diar_segment(0.0, 1.0, "A", recording_id="r") + assert s1.id == s2.id + s3 = make_diar_segment(0.0, 2.0, "A", recording_id="r") + assert s1.id != s3.id + + @pytest.mark.unit + def test_make_diar_annotation_from_labels(self): + labels = ["0.0 5.0 A", "5.0 10.0 B", "10.0 12.5 A"] + ann = make_diar_annotation(labels, uniq_name="rec42") + assert isinstance(ann, list) + assert len(ann) == 3 + assert all(isinstance(s, SupervisionSegment) for s in ann) + assert all(s.recording_id == "rec42" for s in ann) + assert [s.speaker for s in ann] == ["A", "B", "A"] + assert [s.start for s in ann] == [0.0, 5.0, 10.0] + assert [s.end for s in ann] == [5.0, 10.0, 12.5] + + @pytest.mark.unit + def test_make_diar_annotation_skips_malformed(self): + """Lines with fewer than 3 tokens are ignored (defensive).""" + labels = ["0.0 5.0 A", "garbage", "", "5.0 10.0 B"] + ann = make_diar_annotation(labels, uniq_name="r") + assert len(ann) == 2 + assert [s.speaker for s in ann] == ["A", "B"] + + @pytest.mark.unit + def test_make_uem_timeline_basic(self): + uem = make_uem_timeline([[0.0, 5.0], [10.0, 12.0]], uniq_id="rec1") + assert len(uem) == 2 + assert all(isinstance(s, SupervisionSegment) for s in uem) + assert all(s.speaker == "UEM" for s in uem) + assert all(s.recording_id == "rec1" for s in uem) + assert (uem[0].start, uem[0].end) == (0.0, 5.0) + assert (uem[1].start, uem[1].end) == (10.0, 12.0) + + @pytest.mark.unit + def test_make_uem_timeline_empty(self): + assert make_uem_timeline([], uniq_id="r") == [] + + @pytest.mark.unit + def test_unique_speakers_preserves_first_seen_order(self): + ann = make_diar_annotation(["0 1 B", "1 2 A", "2 3 B", "3 4 C", "4 5 A"], uniq_name="r") + # First-seen order: B, A, C + assert unique_speakers(ann) == ["B", "A", "C"] + + @pytest.mark.unit + def test_unique_speakers_on_supervision_set(self): + ann = make_diar_annotation(["0 1 A", "1 2 B"], uniq_name="r") + ss = SupervisionSet.from_segments(ann) + assert sorted(unique_speakers(ss)) == ["A", "B"] + + @pytest.mark.unit + def test_unique_speakers_on_empty(self): + assert unique_speakers([]) == [] + + @pytest.mark.unit + def test_write_supervisions_to_rttm_format(self): + ann = make_diar_annotation(["0.0 1.5 A", "1.5 3.0 B"], uniq_name="rec1") + buf = io.StringIO() + write_supervisions_to_rttm(ann, buf) + lines = [ln for ln in buf.getvalue().splitlines() if ln.strip()] + assert len(lines) == 2 + # Each line follows: SPEAKER + for ln in lines: + parts = ln.split() + assert parts[0] == "SPEAKER" + assert parts[1] == "rec1" + assert parts[2] == "1" + assert parts[5] == "" and parts[6] == "" + assert parts[8] == "" and parts[9] == "" + # Verify start/dur/speaker on the first line + p0 = lines[0].split() + assert 
float(p0[3]) == pytest.approx(0.0) + assert float(p0[4]) == pytest.approx(1.5) + assert p0[7] == "A" + + @pytest.mark.unit + def test_write_supervisions_to_rttm_skips_zero_duration(self): + ann = [ + make_diar_segment(0.0, 1.0, "A", recording_id="rec1"), + make_diar_segment(2.0, 2.0, "B", recording_id="rec1"), # zero-duration + make_diar_segment(3.0, 4.5, "C", recording_id="rec1"), + ] + buf = io.StringIO() + write_supervisions_to_rttm(ann, buf) + lines = [ln for ln in buf.getvalue().splitlines() if ln.strip()] + assert len(lines) == 2 + speakers = [ln.split()[7] for ln in lines] + assert speakers == ["A", "C"] + + @pytest.mark.unit + def test_write_supervisions_to_rttm_explicit_recording_id_override(self): + """Explicit ``recording_id`` overrides per-segment ids.""" + ann = make_diar_annotation(["0 1 A"], uniq_name="orig") + buf = io.StringIO() + write_supervisions_to_rttm(ann, buf, recording_id="overridden") + line = buf.getvalue().strip() + assert line.split()[1] == "overridden" + + @pytest.mark.unit + def test_write_supervisions_to_rttm_round_trip(self): + """Write annotations to RTTM, then read them back via lhotse. + + Verifies our RTTM output is parseable by lhotse's RTTM reader, + confirming we follow the same format conventions. + """ + ann = make_diar_annotation(["0.0 2.5 alice", "2.5 5.0 bob", "5.0 7.25 alice"], uniq_name="conv1") + # Write to a temp file (lhotse only reads from path objects). + import tempfile + + with tempfile.NamedTemporaryFile("w", suffix=".rttm", delete=False) as fh: + write_supervisions_to_rttm(ann, fh) + tmp_path = fh.name + try: + parsed = SupervisionSet.from_rttm(tmp_path) + parsed_segs = sorted(list(parsed), key=lambda s: s.start) + finally: + import os + + os.unlink(tmp_path) + assert len(parsed_segs) == 3 + assert [s.speaker for s in parsed_segs] == ["alice", "bob", "alice"] + assert [s.start for s in parsed_segs] == pytest.approx([0.0, 2.5, 5.0]) + assert [s.end for s in parsed_segs] == pytest.approx([2.5, 5.0, 7.25]) + + +class TestIterAnnotationSegments: + """Verify ``md_eval._iter_annotation_segments`` accepts every supported type.""" + + @pytest.mark.unit + def test_iter_list_of_supervision_segments(self): + ann = make_diar_annotation(["0 1 A", "1 3 B"], uniq_name="r") + out = list(_iter_annotation_segments(ann)) + assert out == [(0.0, 1.0, "A"), (1.0, 3.0, "B")] + + @pytest.mark.unit + def test_iter_supervision_set(self): + ann = make_diar_annotation(["0 1 A", "1 3 B"], uniq_name="r") + ss = SupervisionSet.from_segments(ann) + out = list(_iter_annotation_segments(ss)) + assert sorted(out) == [(0.0, 1.0, "A"), (1.0, 3.0, "B")] + + @pytest.mark.unit + def test_iter_duck_typed_objects_with_end(self): + """Plain dataclass-like objects with ``.start``, ``.end``, ``.speaker``.""" + + class _DT: + def __init__(self, start, end, speaker): + self.start = start + self.end = end + self.speaker = speaker + + ann = [_DT(0.0, 2.0, "X"), _DT(2.0, 5.0, "Y")] + assert list(_iter_annotation_segments(ann)) == [(0.0, 2.0, "X"), (2.0, 5.0, "Y")] + + @pytest.mark.unit + def test_iter_duck_typed_objects_with_duration(self): + """Objects exposing ``.duration`` (no ``.end``) are also accepted.""" + + class _DT: + def __init__(self, start, duration, speaker): + self.start = start + self.duration = duration + self.speaker = speaker + self.end = None + + ann = [_DT(0.0, 2.0, "X"), _DT(2.0, 3.0, "Y")] + assert list(_iter_annotation_segments(ann)) == [(0.0, 2.0, "X"), (2.0, 5.0, "Y")] + + @pytest.mark.unit + def test_iter_legacy_itertracks_object(self): + """Objects 
exposing ``.itertracks(yield_label=True)`` (legacy path). + + This duck-typed fallback keeps backwards compatibility with the + external annotation library's ``Annotation`` API. + """ + + class _Seg: + def __init__(self, s, e): + self.start = s + self.end = e + + class _Ann: + def __init__(self, items): + self._items = items + + def itertracks(self, yield_label=True): + for s, e, spk in self._items: + yield _Seg(s, e), "track", spk + + ann = _Ann([(0.0, 1.5, "A"), (1.5, 4.0, "B")]) + assert list(_iter_annotation_segments(ann)) == [(0.0, 1.5, "A"), (1.5, 4.0, "B")] + + @pytest.mark.unit + def test_iter_missing_end_and_duration_raises(self): + class _Bad: + def __init__(self): + self.start = 0.0 + self.speaker = "A" + + with pytest.raises(TypeError, match="end.*duration"): + list(_iter_annotation_segments([_Bad()])) + + @pytest.mark.unit + def test_iter_missing_speaker_raises(self): + class _Bad: + def __init__(self): + self.start = 0.0 + self.end = 1.0 + + with pytest.raises(TypeError, match="speaker"): + list(_iter_annotation_segments([_Bad()])) + + +class TestLhotseAnnotation: + """End-to-end DER tests using lhotse SupervisionSegment annotations. + + Every scenario here is also covered by the legacy label-string tests + above (``TestScoreLabelsFromRttmLabels``); we re-run them through the + new lhotse pipeline (``score_labels`` + ``make_diar_annotation`` + + ``make_uem_timeline``) and assert **bit-identical** DER/CER/FA/MISS. + Any divergence here means the lhotse adapter has regressed. + """ + + @pytest.mark.unit + def test_perfect_match(self): + metric, mapping, (DER, CER, FA, MISS) = _score_lhotse( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 10, "B")], + ) + assert isinstance(metric, DiarizationErrorResult) + assert isinstance(mapping, dict) + assert_der(DER, 0.0) + assert_der(CER, 0.0) + assert_der(FA, 0.0) + assert_der(MISS, 0.0) + + @pytest.mark.unit + def test_complete_miss(self): + _, _, (DER, _, _, MISS) = _score_lhotse([(0, 5, "A"), (5, 10, "B")], []) + assert_der(DER, 1.0) + assert_der(MISS, 1.0) + + @pytest.mark.unit + def test_speaker_swap_optimal_mapping(self): + _, _, (DER, _, _, _) = _score_lhotse( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "B"), (5, 10, "A")], + ) + assert_der(DER, 0.0) + + @pytest.mark.unit + def test_partial_miss(self): + _, _, (DER, _, _, MISS) = _score_lhotse([(0, 10, "A")], [(0, 5, "A")]) + assert_der(DER, 0.5) + assert_der(MISS, 0.5) + + @pytest.mark.unit + def test_partial_false_alarm(self): + """Hyp extends past ref (FA region) — scored only when UEM covers it. + + With an explicit UEM covering [0, 10], the [5, 10] hyp region becomes + a false alarm: FA = 5 / scored(5) = 1.0, DER = 1.0. + Without a UEM, md-eval restricts evaluation to the reference extent + [0, 5] and the extra hyp is not scored — so this test makes the UEM + explicit to keep the FA assertion meaningful. 
+ """ + _, _, (DER, _, FA, _) = _score_lhotse( + [(0, 5, "A")], + [(0, 5, "A"), (5, 10, "A")], + uem_segs=[[0, 10]], + ) + assert_der(DER, 1.0) + assert_der(FA, 1.0) + + @pytest.mark.unit + def test_confusion(self): + _, _, (DER, CER, _, _) = _score_lhotse( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 10, "A")], + ) + assert_der(DER, 0.5) + assert_der(CER, 0.5) + + @pytest.mark.unit + def test_collar_eliminates_boundary_error(self): + _, _, (DER, _, _, _) = _score_lhotse( + [(0, 5, "A"), (5, 10, "B")], + [(0, 4.9, "A"), (5.1, 10, "B")], + collar=0.25, + ) + assert_der(DER, 0.0) + + @pytest.mark.unit + def test_collar_partial(self): + _, _, (DER, _, _, MISS) = _score_lhotse( + [(0, 5, "A"), (5, 10, "B")], + [(0, 4, "A"), (6, 10, "B")], + collar=0.25, + ) + assert_der(DER, 1.5 / 9.0) + assert_der(MISS, 1.5 / 9.0) + + @pytest.mark.unit + def test_uem_restricts(self): + _, _, (DER, CER, _, _) = _score_lhotse( + [(0, 10, "A")], + [(0, 5, "A"), (5, 10, "B")], + uem_segs=[[2, 8]], + ) + assert_der(DER, 0.5) + assert_der(CER, 0.5) + + @pytest.mark.unit + def test_extra_hyp_speaker(self): + _, _, (DER, CER, _, _) = _score_lhotse( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 8, "B"), (8, 10, "C")], + ) + assert_der(DER, 0.2) + assert_der(CER, 0.2) + + @pytest.mark.unit + def test_missing_hyp_speaker(self): + _, _, (DER, CER, _, _) = _score_lhotse( + [(0, 5, "A"), (5, 8, "B"), (8, 10, "C")], + [(0, 5, "A"), (5, 10, "B")], + ) + assert_der(DER, 0.2) + assert_der(CER, 0.2) + + @pytest.mark.unit + def test_ignore_overlap(self): + """``ignore_overlap=True`` should suppress overlap-region scoring.""" + _, _, (DER_no, _, _, _) = _score_lhotse( + [(0, 5, "A"), (3, 7, "B")], + [(0, 5, "A"), (3, 7, "B")], + ignore_overlap=False, + ) + _, _, (DER_yes, _, _, _) = _score_lhotse( + [(0, 5, "A"), (3, 7, "B")], + [(0, 5, "A"), (3, 7, "B")], + ignore_overlap=True, + ) + assert_der(DER_no, 0.0) + assert_der(DER_yes, 0.0) + + @pytest.mark.unit + def test_accepts_supervision_set(self): + """``score_labels`` should accept a ``SupervisionSet`` directly.""" + ref = SupervisionSet.from_segments(make_diar_annotation(["0 5 A", "5 10 B"], uniq_name="f1")) + hyp = SupervisionSet.from_segments(make_diar_annotation(["0 5 A", "5 10 B"], uniq_name="f1")) + result = score_labels( + {"f1": {}}, + [("f1", ref)], + [("f1", hyp)], + collar=0.0, + ignore_overlap=False, + verbose=False, + ) + assert result is not None + _, _, (DER, _, _, _) = result + assert_der(DER, 0.0) + + @pytest.mark.unit + def test_multi_file_scoring(self): + """Two files, one perfect and one with confusion → averaged DER.""" + f1_ref = make_diar_annotation(["0 5 A"], uniq_name="f1") + f1_hyp = make_diar_annotation(["0 5 A"], uniq_name="f1") + f2_ref = make_diar_annotation(["0 4 A", "4 8 B"], uniq_name="f2") + f2_hyp = make_diar_annotation(["0 8 A"], uniq_name="f2") + result = score_labels( + {"f1": {}, "f2": {}}, + [("f1", f1_ref), ("f2", f2_ref)], + [("f1", f1_hyp), ("f2", f2_hyp)], + collar=0.0, + ignore_overlap=False, + verbose=False, + ) + assert result is not None + metric, _, (DER, _, _, _) = result + # f1: perfect (0/5). f2: B confused with A across [4,8] → 4/8. + # Combined: confusion=4 / scored=13 = 4/13. + assert_der(DER, 4.0 / 13.0) + assert len(metric.results_) == 2 + + +class TestLhotseStringEquivalence: + """The lhotse path and the legacy label-string path must agree on every metric. 
+ + Same reference + hypothesis fed through both ``score_labels`` (lhotse + annotations) and ``score_labels_from_rttm_labels`` (label strings) must + produce bit-identical (DER, CER, FA, MISS). + """ + + @staticmethod + def _both(ref_segs, hyp_segs, **kw): + string_path = _score(ref_segs, hyp_segs, **kw) + lhotse_path = _score_lhotse(ref_segs, hyp_segs, **kw) + return string_path[2], lhotse_path[2] # itemized errors + + @pytest.mark.unit + def test_perfect(self): + string_path, lhotse_path = self._both([(0, 5, "A"), (5, 10, "B")], [(0, 5, "A"), (5, 10, "B")]) + assert string_path == pytest.approx(lhotse_path) + + @pytest.mark.unit + def test_complete_miss(self): + string_path, lhotse_path = self._both([(0, 5, "A"), (5, 10, "B")], []) + assert string_path == pytest.approx(lhotse_path) + + @pytest.mark.unit + def test_speaker_swap(self): + string_path, lhotse_path = self._both( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "B"), (5, 10, "A")], + ) + assert string_path == pytest.approx(lhotse_path) + + @pytest.mark.unit + def test_collar(self): + string_path, lhotse_path = self._both( + [(0, 5, "A"), (5, 10, "B")], + [(0, 4.9, "A"), (5.1, 10, "B")], + collar=0.25, + ) + assert string_path == pytest.approx(lhotse_path) + + @pytest.mark.unit + def test_uem(self): + string_path, lhotse_path = self._both( + [(0, 10, "A")], + [(0, 5, "A"), (5, 10, "B")], + uem_segs=[[2, 8]], + ) + assert string_path == pytest.approx(lhotse_path) + + @pytest.mark.unit + def test_ignore_overlap(self): + string_path, lhotse_path = self._both( + [(0, 5, "A"), (3, 7, "B")], + [(0, 4, "A"), (3, 7, "B")], + ignore_overlap=True, + ) + assert string_path == pytest.approx(lhotse_path) + + @pytest.mark.unit + def test_three_speakers(self): + string_path, lhotse_path = self._both( + [(0, 4, "A"), (4, 7, "B"), (7, 10, "C")], + [(0, 4, "A"), (4, 7, "B"), (7, 10, "C")], + collar=0.5, + ) + assert string_path == pytest.approx(lhotse_path) + + @pytest.mark.unit + def test_extra_hyp_speaker(self): + string_path, lhotse_path = self._both( + [(0, 5, "A"), (5, 10, "B")], + [(0, 5, "A"), (5, 8, "B"), (8, 10, "C")], + ) + assert string_path == pytest.approx(lhotse_path) diff --git a/tests/collections/speaker_tasks/utils/test_vad_utils_speaker.py b/tests/collections/speaker_tasks/utils/test_vad_utils_speaker.py index a7672e1aa43d..46afc4d3b013 100644 --- a/tests/collections/speaker_tasks/utils/test_vad_utils_speaker.py +++ b/tests/collections/speaker_tasks/utils/test_vad_utils_speaker.py @@ -14,17 +14,17 @@ import numpy as np import pytest -from pyannote.core import Annotation, Segment +from lhotse import SupervisionSegment from nemo.collections.asr.parts.utils.vad_utils import ( align_labels_to_frames, convert_labels_to_speech_segments, - frame_vad_construct_pyannote_object_per_file, + frame_vad_construct_supervisions_per_file, get_frame_labels, get_nonspeech_segments, load_speech_overlap_segments_from_rttm, load_speech_segments_from_rttm, - read_rttm_as_pyannote_object, + read_rttm_as_supervisions, ) @@ -101,26 +101,41 @@ def test_convert_labels_to_speech_segments(self, test_data_dir): assert speech_segments_new == speech_segments @pytest.mark.unit - def test_read_rttm_as_pyannote_object(self, test_data_dir): + def test_read_rttm_as_supervisions(self, test_data_dir): rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test6.rttm") - pyannote_object = read_rttm_as_pyannote_object(rttm_file) - pyannote_object_gt = Annotation() - pyannote_object_gt[Segment(0.0, 2.0)] = 'speech' - assert pyannote_object == 
pyannote_object_gt + annotation = read_rttm_as_supervisions(rttm_file) + assert _annotation_equals(annotation, [(0.0, 2.0, 'speech')]) @pytest.mark.unit - def test_frame_vad_construct_pyannote_object_per_file(self, test_data_dir): + def test_frame_vad_construct_supervisions_per_file(self, test_data_dir): rttm_file, speech_segments = get_simple_rttm_without_overlap(test_data_dir + "/test7.rttm") # test for rttm input - ref, hyp = frame_vad_construct_pyannote_object_per_file(rttm_file, rttm_file) - pyannote_object_gt = Annotation() - pyannote_object_gt[Segment(0.0, 2.0)] = 'speech' - assert ref == hyp == pyannote_object_gt + ref, hyp = frame_vad_construct_supervisions_per_file(rttm_file, rttm_file) + expected = [(0.0, 2.0, 'speech')] + assert _annotation_equals(ref, expected) + assert _annotation_equals(hyp, expected) # test for list input speech_segments = load_speech_segments_from_rttm(rttm_file) frame_labels = get_frame_labels(speech_segments, 0.02, 0.0, 3.0, as_str=False) speech_segments_new = convert_labels_to_speech_segments(frame_labels, 0.02) assert speech_segments_new == speech_segments - ref, hyp = frame_vad_construct_pyannote_object_per_file(frame_labels, frame_labels, 0.02) - assert ref == hyp == pyannote_object_gt + ref, hyp = frame_vad_construct_supervisions_per_file(frame_labels, frame_labels, 0.02) + assert _annotation_equals(ref, expected) + assert _annotation_equals(hyp, expected) + + +def _annotation_equals(annotation, expected_segments, *, atol=1e-6): + """Compare a list of :class:`lhotse.SupervisionSegment` to expected ``(start, end, speaker)`` tuples.""" + assert isinstance(annotation, list) + assert all(isinstance(s, SupervisionSegment) for s in annotation) + if len(annotation) != len(expected_segments): + return False + for seg, (exp_start, exp_end, exp_spk) in zip(annotation, expected_segments): + if abs(float(seg.start) - exp_start) > atol: + return False + if abs(float(seg.end) - exp_end) > atol: + return False + if seg.speaker != exp_spk: + return False + return True diff --git a/tutorials/speaker_tasks/End_to_End_Diarization_Inference.ipynb b/tutorials/speaker_tasks/End_to_End_Diarization_Inference.ipynb index c2c46674582a..155dac4c974f 100644 --- a/tutorials/speaker_tasks/End_to_End_Diarization_Inference.ipynb +++ b/tutorials/speaker_tasks/End_to_End_Diarization_Inference.ipynb @@ -266,8 +266,8 @@ "import os\n", "import wget\n", "import pandas as pd\n", - "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object\n", - "from pyannote.metrics.diarization import DiarizationErrorRate\n", + "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels\n", + "from nemo.collections.asr.metrics.der import score_labels_from_rttm_labels\n", "\n", "ROOT = os.getcwd()\n", "data_dir = os.path.join(ROOT,'data')\n", @@ -365,16 +365,22 @@ "# Get the refernce labels from ground-truth RTTM file\n", "ref_labels = rttm_to_labels(an4_rttm)\n", "\n", - "reference = labels_to_pyannote_object(ref_labels, uniq_name=\"binarize\")\n", - "hypothesis1 = labels_to_pyannote_object(pred_list_bn[0], uniq_name=\"binarize\")\n", - "der_metric1 = DiarizationErrorRate(collar=0, skip_overlap=False)\n", - "der_metric1(reference, hypothesis1, detailed=True)\n", + "der_metric1, _, _ = score_labels_from_rttm_labels(\n", + " ref_labels_list=[(\"binarize\", ref_labels)],\n", + " hyp_labels_list=[(\"binarize\", pred_list_bn[0])],\n", + " collar=0.0,\n", + " ignore_overlap=False,\n", + " verbose=False,\n", + ")\n", "print(f\"Raw Binarization DER: 
{abs(der_metric1):.6f}\")\n", "\n", - "reference = labels_to_pyannote_object(ref_labels, uniq_name=\"post_processing\")\n", - "hypothesis2 = labels_to_pyannote_object(pred_list_pp[0], uniq_name=\"post_processing\")\n", - "der_metric2 = DiarizationErrorRate(collar=0, skip_overlap=False)\n", - "der_metric2(reference, hypothesis2, detailed=True)\n", + "der_metric2, _, _ = score_labels_from_rttm_labels(\n", + " ref_labels_list=[(\"post_processing\", ref_labels)],\n", + " hyp_labels_list=[(\"post_processing\", pred_list_pp[0])],\n", + " collar=0.0,\n", + " ignore_overlap=False,\n", + " verbose=False,\n", + ")\n", "print(f\"Post-Processing DER: {abs(der_metric2):.6f}\")" ] }, @@ -388,7 +394,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "nemo012826", "language": "python", "name": "python3" }, @@ -402,7 +408,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.19" }, "pycharm": { "stem_cell": { @@ -412,11 +418,6 @@ }, "source": [] } - }, - "vscode": { - "interpreter": { - "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" - } } }, "nbformat": 4, diff --git a/tutorials/speaker_tasks/End_to_End_Diarization_Training.ipynb b/tutorials/speaker_tasks/End_to_End_Diarization_Training.ipynb index b4124032df24..354cf1208672 100644 --- a/tutorials/speaker_tasks/End_to_End_Diarization_Training.ipynb +++ b/tutorials/speaker_tasks/End_to_End_Diarization_Training.ipynb @@ -604,7 +604,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import librosa\n", - "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object\n", + "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_supervisions\n", "\n", "ROOT = os.getcwd()\n", "data_dir = os.path.join(ROOT,'simulated_train')\n", @@ -642,7 +642,7 @@ "source": [ "# display speaker labels for reference\n", "labels = rttm_to_labels(rttm)\n", - "reference = labels_to_pyannote_object(labels)\n", + "reference = labels_to_supervisions(labels)\n", "reference" ] }, diff --git a/tutorials/tools/Multispeaker_Simulator.ipynb b/tutorials/tools/Multispeaker_Simulator.ipynb index 234c2441c8ac..2c83da82914a 100644 --- a/tutorials/tools/Multispeaker_Simulator.ipynb +++ b/tutorials/tools/Multispeaker_Simulator.ipynb @@ -268,7 +268,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import librosa\n", - "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object\n", + "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_supervisions\n", "\n", "ROOT = os.getcwd()\n", "data_dir = os.path.join(ROOT,'simulated_data')\n", @@ -306,7 +306,7 @@ "source": [ "#display speaker labels for reference\n", "labels = rttm_to_labels(rttm)\n", - "reference = labels_to_pyannote_object(labels)\n", + "reference = labels_to_supervisions(labels)\n", "reference" ] }, diff --git a/uv.lock b/uv.lock index dfe671decabb..f658bb0ce8e1 100644 --- a/uv.lock +++ b/uv.lock @@ -4726,8 +4726,6 @@ all = [ { name = "peft" }, { name = "pesq", marker = "platform_machine != 'x86_64' or sys_platform != 'darwin' or (extra == 'extra-12-nemo-toolkit-cu12' and extra == 'extra-12-nemo-toolkit-cu13')" }, { name = "progress" }, - { name = "pyannote-core" }, - { name = "pyannote-metrics" }, { name = "pydub" }, { name = "pyloudnorm" }, { name = "pyopenjtalk" }, @@ -4789,8 +4787,6 @@ asr = [ { name = "packaging" }, { 
name = "pandas" }, { name = "peft" }, - { name = "pyannote-core" }, - { name = "pyannote-metrics" }, { name = "pydub" }, { name = "pyloudnorm" }, { name = "resampy" }, @@ -4821,8 +4817,6 @@ asr-only = [ { name = "marshmallow" }, { name = "optuna" }, { name = "packaging" }, - { name = "pyannote-core" }, - { name = "pyannote-metrics" }, { name = "pydub" }, { name = "pyloudnorm" }, { name = "resampy" }, @@ -4963,8 +4957,6 @@ slu = [ { name = "pandas" }, { name = "peft" }, { name = "progress" }, - { name = "pyannote-core" }, - { name = "pyannote-metrics" }, { name = "pydub" }, { name = "pyloudnorm" }, { name = "resampy" }, @@ -5021,8 +5013,6 @@ speechlm2 = [ { name = "packaging" }, { name = "pandas" }, { name = "peft" }, - { name = "pyannote-core" }, - { name = "pyannote-metrics" }, { name = "pydub" }, { name = "pyloudnorm" }, { name = "pyopenjtalk" }, @@ -5133,8 +5123,6 @@ tts = [ { name = "packaging" }, { name = "pandas" }, { name = "peft" }, - { name = "pyannote-core" }, - { name = "pyannote-metrics" }, { name = "pydub" }, { name = "pyloudnorm" }, { name = "pyopenjtalk" }, @@ -5453,18 +5441,6 @@ requires-dist = [ { name = "pesq", marker = "(platform_machine != 'x86_64' and extra == 'audio') or (sys_platform != 'darwin' and extra == 'audio')" }, { name = "progress", marker = "extra == 'all'", specifier = ">=1.5" }, { name = "progress", marker = "extra == 'slu'", specifier = ">=1.5" }, - { name = "pyannote-core", marker = "extra == 'all'" }, - { name = "pyannote-core", marker = "extra == 'asr'" }, - { name = "pyannote-core", marker = "extra == 'asr-only'" }, - { name = "pyannote-core", marker = "extra == 'slu'" }, - { name = "pyannote-core", marker = "extra == 'speechlm2'" }, - { name = "pyannote-core", marker = "extra == 'tts'" }, - { name = "pyannote-metrics", marker = "extra == 'all'" }, - { name = "pyannote-metrics", marker = "extra == 'asr'" }, - { name = "pyannote-metrics", marker = "extra == 'asr-only'" }, - { name = "pyannote-metrics", marker = "extra == 'slu'" }, - { name = "pyannote-metrics", marker = "extra == 'speechlm2'" }, - { name = "pyannote-metrics", marker = "extra == 'tts'" }, { name = "pydub", marker = "extra == 'all'" }, { name = "pydub", marker = "extra == 'asr'" }, { name = "pydub", marker = "extra == 'asr-only'" }, @@ -7145,55 +7121,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, ] -[[package]] -name = "pyannote-core" -version = "6.0.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-12-nemo-toolkit-cu12' and extra == 'extra-12-nemo-toolkit-cu13')" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-12-nemo-toolkit-cu12' and extra == 'extra-12-nemo-toolkit-cu13')" }, - { name = "pandas" }, - { name = "sortedcontainers" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a3/be/4a35ea31c685aef801f7f35c193e7766ca1bb948ae497a625cbfaa8c31ba/pyannote_core-6.0.1.tar.gz", hash = "sha256:4b4ada3276f6df4e073fa79166636e3597d0dcb5a0fe26014a3477867cc033fb", size = 327540, upload-time = "2025-09-16T09:24:39.081Z" } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/ea/57/ecf62344b9b81debd0ca95ed987135e93d1b039507f8174f52d1d19d8c6b/pyannote_core-6.0.1-py3-none-any.whl", hash = "sha256:924550d6ecf6b05ad13bf3f66f59c29fc740cf1c62a6fca860ac2e66908203e5", size = 57505, upload-time = "2025-09-16T09:24:37.798Z" }, -] - -[[package]] -name = "pyannote-database" -version = "6.1.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pandas" }, - { name = "pyannote-core" }, - { name = "pyyaml" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/65/45/6210274c187cc457e854be8b56c6819fa14376f27e7e2b6021b2aa02449a/pyannote_database-6.1.1.tar.gz", hash = "sha256:bbe76da738257a9e64061123d9694ad7e949c4f171d91a9269606d873528cd10", size = 112225, upload-time = "2025-12-07T06:33:10.296Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/04/bf/6a6f5abaa4d9f803f34c9883ef5e316624eac6be0eaa87720216be9bba12/pyannote_database-6.1.1-py3-none-any.whl", hash = "sha256:36460c70ce9f50ff25c9ea365bc83ad625bb6b2494deccf6bd3fc750686ae684", size = 53735, upload-time = "2025-12-07T06:33:11.578Z" }, -] - -[[package]] -name = "pyannote-metrics" -version = "4.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-12-nemo-toolkit-cu12' and extra == 'extra-12-nemo-toolkit-cu13')" }, - { name = "numpy", version = "2.4.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-12-nemo-toolkit-cu12' and extra == 'extra-12-nemo-toolkit-cu13')" }, - { name = "pandas" }, - { name = "pyannote-core" }, - { name = "pyannote-database" }, - { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-12-nemo-toolkit-cu12' and extra == 'extra-12-nemo-toolkit-cu13')" }, - { name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-12-nemo-toolkit-cu12' and extra == 'extra-12-nemo-toolkit-cu13')" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' or (extra == 'extra-12-nemo-toolkit-cu12' and extra == 'extra-12-nemo-toolkit-cu13')" }, - { name = "scipy", version = "1.17.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' or (extra == 'extra-12-nemo-toolkit-cu12' and extra == 'extra-12-nemo-toolkit-cu13')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/3a/98/f8962bb2f5826c9798212797b0fa96ff02f81573a2d7cf1f5b678d6c55a2/pyannote_metrics-4.0.0.tar.gz", hash = "sha256:aec037eb7ca4c0ad5c5bbcc19bc04e9acf24ba42c95f025497378e31db6a0ff4", size = 879283, upload-time = "2025-09-09T14:38:27.073Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b9/d5/637f67578fd704e27ca2edc45c5b0ad6433684916c08cd7fa54d07482407/pyannote_metrics-4.0.0-py3-none-any.whl", hash = "sha256:618dd4c778cb6a92b809c9aa79ee9b93f12dbe3b11e273431b094b10c53c8dd9", size = 49749, upload-time = "2025-09-09T14:38:24.592Z" }, -] - [[package]] name = "pyarrow" version = "21.0.0"