From 7453ba86aef1286cf63702110104f0a167c607d0 Mon Sep 17 00:00:00 2001
From: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Date: Fri, 1 May 2026 17:18:09 +0000
Subject: [PATCH] Lhotse: add prefetch_factor option to LhotseDataLoadingConfig

Add a configurable prefetch_factor for the PyTorch DataLoader, allowing
users to increase the per-worker prefetch buffer depth to absorb I/O
latency spikes from network filesystems. Applies to both the single-config
and multi-config dataloader paths. When unset (None), PyTorch's default
of 2 is used, preserving existing behavior.

Usage: model.train_ds.prefetch_factor=4

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
---
 .../common/data/lhotse/dataloader.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py
index 5af5f5d004d7..03ae7a3ba3fa 100644
--- a/nemo/collections/common/data/lhotse/dataloader.py
+++ b/nemo/collections/common/data/lhotse/dataloader.py
@@ -159,6 +159,10 @@ class LhotseDataLoadingConfig:
     seed: int | str = 0
     num_workers: int = 0
     pin_memory: bool = False
+    # Number of batches each DataLoader worker prefetches (only effective when num_workers > 0).
+    # Higher values build a deeper buffer that can absorb I/O latency spikes at the cost of
+    # increased CPU memory. When None, PyTorch's default of 2 is used.
+    prefetch_factor: int | None = None
     channel_selector: int | str | None = None
 
     # 4. Optional Lhotse data augmentation.
@@ -369,12 +373,15 @@ def get_lhotse_dataloader_from_single_config(
     # reads only light-weight JSON objects; it samples mini-batches and passes
     # the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method.
     dloader_kwargs = dict(dataset=dataset, sampler=sampler)
-    dloader = torch.utils.data.DataLoader(
-        **dloader_kwargs,
+    dloader_kwargs.update(
         batch_size=None,
         num_workers=config.num_workers,
         pin_memory=config.pin_memory,
     )
+    if config.num_workers > 0 and config.prefetch_factor is not None:
+        dloader_kwargs["prefetch_factor"] = config.prefetch_factor
+
+    dloader = torch.utils.data.DataLoader(**dloader_kwargs)
     return dloader
 
 
@@ -493,12 +500,15 @@ def gather_shared_opts():
     # reads only light-weight JSON objects; it samples mini-batches and passes
     # the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method.
     dloader_kwargs = dict(dataset=dataset, sampler=sampler)
-    dloader = torch.utils.data.DataLoader(
-        **dloader_kwargs,
+    dloader_kwargs.update(
         batch_size=None,
         num_workers=shared_opts.num_workers,
         pin_memory=shared_opts.pin_memory,
     )
+    if shared_opts.num_workers > 0 and shared_opts.prefetch_factor is not None:
+        dloader_kwargs["prefetch_factor"] = shared_opts.prefetch_factor
+
+    dloader = torch.utils.data.DataLoader(**dloader_kwargs)
     return dloader
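
Note for reviewers (not part of the commit): below is a minimal,
self-contained sketch of the DataLoader behavior this patch relies on,
runnable outside NeMo. ToyDataset and all the values here are
hypothetical illustrations, not part of the patch; only
torch.utils.data.DataLoader and its num_workers/prefetch_factor
parameters come from PyTorch itself.

import torch


class ToyDataset(torch.utils.data.Dataset):
    # Hypothetical stand-in for the Lhotse dataset: pretend each read is I/O-heavy.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(idx + 1)


if __name__ == "__main__":  # needed for num_workers > 0 under the spawn start method
    num_workers = 4
    prefetch_factor = 8  # deeper per-worker buffer, e.g. for a slow network filesystem

    kwargs = dict(dataset=ToyDataset(), batch_size=None, num_workers=num_workers)
    # Mirror the patch's guard: PyTorch rejects an explicit prefetch_factor when
    # num_workers == 0, and when it is left unset each worker prefetches 2 batches.
    if num_workers > 0 and prefetch_factor is not None:
        kwargs["prefetch_factor"] = prefetch_factor

    for batch in torch.utils.data.DataLoader(**kwargs):
        pass  # each worker keeps up to prefetch_factor items buffered ahead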
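
And a hedged sketch of how the new field surfaces through a structured
config, matching the Usage line above. The merge below follows OmegaConf's
generic structured-config pattern rather than NeMo's exact call site; the
dataclass is a trimmed-down stand-in for LhotseDataLoadingConfig, and
Optional[int] is used in place of the PEP 604 spelling for broad OmegaConf
compatibility.

from dataclasses import dataclass
from typing import Optional

from omegaconf import OmegaConf


@dataclass
class DataLoadingOpts:
    # Trimmed-down stand-in for LhotseDataLoadingConfig: only the fields used here.
    num_workers: int = 0
    pin_memory: bool = False
    prefetch_factor: Optional[int] = None  # None -> PyTorch's default of 2


base = OmegaConf.structured(DataLoadingOpts)
# Equivalent of passing model.train_ds.prefetch_factor=4 on the command line.
user = OmegaConf.from_dotlist(["num_workers=4", "prefetch_factor=4"])
config = OmegaConf.merge(base, user)

assert config.prefetch_factor == 4
# Left unset, prefetch_factor stays None and the DataLoader keeps
# PyTorch's default of 2 prefetched batches per worker.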