Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ class LhotseDataLoadingConfig:
seed: int | str = 0
num_workers: int = 0
pin_memory: bool = False
# Number of batches each DataLoader worker prefetches (only effective when num_workers > 0),
# i.e. up to prefetch_factor * num_workers batches are buffered in total.
# Higher values build a deeper buffer that can absorb I/O latency spikes at the cost of
# increased CPU memory. When None, the argument is not passed and PyTorch's default of 2 is used.
prefetch_factor: int | None = None
channel_selector: int | str | None = None

# 4. Optional Lhotse data augmentation.
Expand Down Expand Up @@ -369,12 +373,15 @@ def get_lhotse_dataloader_from_single_config(
# reads only light-weight JSON objects; it samples mini-batches and passes
# the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method.
dloader_kwargs = dict(dataset=dataset, sampler=sampler)
dloader = torch.utils.data.DataLoader(
**dloader_kwargs,
dloader_kwargs.update(
batch_size=None,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
)
if config.num_workers > 0 and config.get("prefetch_factor") is not None:
dloader_kwargs["prefetch_factor"] = config.prefetch_factor

dloader = torch.utils.data.DataLoader(**dloader_kwargs)

return dloader

Expand Down Expand Up @@ -493,12 +500,15 @@ def gather_shared_opts():
# reads only light-weight JSON objects; it samples mini-batches and passes
# the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method.
dloader_kwargs = dict(dataset=dataset, sampler=sampler)
dloader = torch.utils.data.DataLoader(
**dloader_kwargs,
dloader_kwargs.update(
batch_size=None,
num_workers=shared_opts.num_workers,
pin_memory=shared_opts.pin_memory,
)
if shared_opts.num_workers > 0 and shared_opts.get("prefetch_factor") is not None:
dloader_kwargs["prefetch_factor"] = shared_opts.prefetch_factor
Comment on lines +508 to +509

dloader = torch.utils.data.DataLoader(**dloader_kwargs)

return dloader

Expand Down
Loading