[core] Unify validation_step_outputs to always return list-of-lists #15470
XuesongYang wants to merge 4 commits into NVIDIA-NeMo:main
validation_step_outputs and test_step_outputs now always return a list of lists (one inner list per dataloader), eliminating if/else branching in every subclass that handles single-vs-multi dataloader shapes.

- validation_step_outputs property: returns [[] for _ in range(num_dl)]
- on_validation/test_epoch_end: len()==1 dispatch, all(len(o)==0 ...) empty guard, skip empty DL buckets in multi-DL loop
- Normalize _validation_dl to Optional[List[DataLoader]] in resolver
- 15 model files: self.validation_step_outputs[dataloader_idx].append()
- TTS models: RuntimeError guard for single-DL assumption
- Test models: override multi_validation_epoch_end, not on_*_epoch_end
- Bug fix: ssl_models test_step appended to wrong outputs list
- New test: empty outputs skip multi_epoch_end

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Made-with: Cursor
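A minimal sketch of the shape this commit describes, assuming a lazily built cache. `OutputCacheSketch` and `all_buckets_empty` are hypothetical names for illustration and are not taken from NeMo; the point is only that the append path is identical for one or many dataloaders and that `[[]]` needs an explicit emptiness check.

```python
from typing import Any, Dict, List, Optional


class OutputCacheSketch:
    """Hypothetical stand-in for the output cache (not NeMo code)."""

    def __init__(self, num_dataloaders: int = 1) -> None:
        self._num_dl = num_dataloaders
        self._validation_step_outputs: Optional[List[List[Dict[str, Any]]]] = None

    @property
    def validation_step_outputs(self) -> List[List[Dict[str, Any]]]:
        # Build one empty bucket per dataloader on first access; one dataloader is N=1.
        if self._validation_step_outputs is None:
            self._validation_step_outputs = [[] for _ in range(self._num_dl)]
        return self._validation_step_outputs

    def validation_step(self, loss: float, dataloader_idx: int = 0) -> None:
        # Same append path for single and multiple dataloaders: no branching on shape.
        self.validation_step_outputs[dataloader_idx].append({"val_loss": loss})


def all_buckets_empty(outputs: List[List[Any]]) -> bool:
    # `[[]]` is truthy, so `if not outputs` would not detect "nothing collected".
    return all(len(bucket) == 0 for bucket in outputs)


single = OutputCacheSketch(num_dataloaders=1)
multi = OutputCacheSketch(num_dataloaders=3)
single.validation_step(0.5)
multi.validation_step(0.1, dataloader_idx=2)
```

With this layout a subclass written for one dataloader keeps working unchanged when a second dataloader is added, because `dataloader_idx` defaults to 0.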
Same wrapping pattern as _validation_dl: wrap bare DataLoader into [DataLoader] at both single-value paths in resolve_test_dataloaders. Simplify isinstance guards in test_step_outputs and _get_num_dataloaders.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
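The wrapping step can be sketched as below. This is an illustration under assumptions, not the resolver code: `FakeLoader` stands in for `torch.utils.data.DataLoader`, and `as_loader_list` is a hypothetical helper name.

```python
from typing import Any, List, Optional


class FakeLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader."""

    def __init__(self, name: str) -> None:
        self.name = name


def as_loader_list(value: Any) -> Optional[List[Any]]:
    # Nothing configured: keep None so callers can still tell "unset" apart.
    if value is None:
        return None
    # Already a container of loaders: return a plain list copy.
    if isinstance(value, (list, tuple)):
        return list(value)
    # Bare loader: wrap it so downstream code always sees a list.
    return [value]


bare = FakeLoader("dev")
wrapped = as_loader_list(bare)
pair = as_loader_list((FakeLoader("dev"), FakeLoader("test")))
```

Once every storage path goes through a wrapper like this, `len()` on the stored value always means "number of dataloaders".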
Pull request overview
This PR standardizes ModelPT.validation_step_outputs / test_step_outputs to use a consistent “list-of-lists” shape, simplifying subclass logic by removing single-vs-multi-dataloader branching and improving epoch-end dispatch/guards.
Changes:
- Updated `ModelPT` epoch-end logic to dispatch based on `len(outputs)` (single vs multi dataloader) and to skip/guard empty per-dataloader outputs.
- Normalized validation dataloader storage to `List[DataLoader]` in `resolve_validation_dataloaders()` and refactored many model `validation_step`/`test_step` implementations to always append via `[dataloader_idx]`.
- Updated unit tests and added a regression test to ensure `multi_validation_epoch_end`/`multi_test_epoch_end` are not called when all outputs are empty.
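The dispatch and guard behaviour summarized above can be sketched roughly as follows. `run_epoch_end` and `mean_loss` are hypothetical names, and the returned dictionary keyed by dataloader index is an assumption for illustration; the real base class also handles logging prefixes, which this sketch omits.

```python
from typing import Any, Callable, Dict, List


def run_epoch_end(
    outputs: List[List[Dict[str, float]]],
    multi_epoch_end: Callable[[List[Dict[str, float]], int], Any],
) -> Dict[int, Any]:
    """Hypothetical epoch-end loop over per-dataloader buckets."""
    # Guard: if every bucket is empty, never call the per-dataloader hook.
    if all(len(bucket) == 0 for bucket in outputs):
        return {}

    # Single dataloader is the N=1 case, detected by length rather than by type.
    if len(outputs) == 1:
        result = {0: multi_epoch_end(outputs[0], 0)}
        outputs[0].clear()
        return result

    results: Dict[int, Any] = {}
    for idx, bucket in enumerate(outputs):
        if len(bucket) == 0:
            continue  # skip dataloaders that produced nothing this epoch
        results[idx] = multi_epoch_end(bucket, idx)
        bucket.clear()  # free memory once the bucket has been reduced
    return results


calls: List[int] = []


def mean_loss(bucket: List[Dict[str, float]], dataloader_idx: int) -> float:
    calls.append(dataloader_idx)
    return sum(o["val_loss"] for o in bucket) / len(bucket)


outputs = [[{"val_loss": 1.0}, {"val_loss": 3.0}], [], [{"val_loss": 2.0}]]
results = run_epoch_end(outputs, mean_loss)
empty_results = run_epoch_end([[], []], mean_loss)
```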
Reviewed changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| tests/core_ptl/test_ptl_stateless_timer.py | Updates test model hooks to the new list-of-lists output shape and adds an empty-epoch regression test. |
| tests/core_ptl/check_for_ranks.py | Switches test model to append outputs via validation_step_outputs[dataloader_idx] and uses multi_validation_epoch_end. |
| tests/collections/common/test_ema.py | Updates validation/test steps to append via [dataloader_idx] and uses multi_validation_epoch_end. |
| nemo/utils/model_utils.py | Wraps single validation dataloaders into a list to normalize _validation_dl shape. |
| nemo/core/classes/modelPT.py | Implements the unified list-of-lists output cache, updates epoch-end dispatch and empty-output guards. |
| nemo/collections/tts/models/magpietts_preference_optimization.py | Removes single-vs-multi branching in validation output accumulation; adjusts epoch-end logic for the new shape. |
| nemo/collections/tts/models/magpietts.py | Updates validation accumulation and epoch-end collection to use validation_step_outputs[0] consistently. |
| nemo/collections/tts/models/fastpitch.py | Updates validation accumulation and epoch-end processing to use the new output structure. |
| nemo/collections/tts/g2p/models/t5.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/tts/g2p/models/ctc.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/audio/models/audio_to_audio.py | Removes branching on dataloader count; simplifies callback setup in line with _validation_dl normalization. |
| nemo/collections/asr/models/transformer_bpe_models.py | Simplifies multi-epoch-end logic to assume per-dataloader outputs (base class iterates dataloaders). |
| nemo/collections/asr/models/ssl_models.py | Removes branching on dataloader count and fixes test_step to append to test_step_outputs. |
| nemo/collections/asr/models/sortformer_diar_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/slu_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/rnnt_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/label_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/hybrid_rnnt_ctc_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/ctc_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/classification_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/aed_multitask_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
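For the TTS rows in the table, the commit message mentions a `RuntimeError` guard for the single-dataloader assumption. A rough sketch of such a guard follows; the function name and the error message are hypothetical and the actual model code may differ.

```python
from typing import Any, List, Optional


def single_dataloader_outputs(validation_step_outputs: List[List[Any]]) -> Optional[List[Any]]:
    """Hypothetical guard for models that support exactly one validation dataloader."""
    if len(validation_step_outputs) != 1:
        # Fail loudly instead of silently ignoring the extra dataloaders.
        raise RuntimeError(
            f"Expected exactly 1 validation dataloader, got {len(validation_step_outputs)}"
        )
    outputs = validation_step_outputs[0]
    if len(outputs) == 0:
        return None  # early return: nothing to aggregate this epoch
    return outputs


try:
    single_dataloader_outputs([[1], [2]])
    raised = False
except RuntimeError:
    raised = True
```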
nemo/core/classes/modelPT.py

    num_dl = len(self._validation_dl) if self._validation_dl else 1
    self._validation_step_outputs = [[] for _ in range(num_dl)]
validation_step_outputs computes num_dl using len(self._validation_dl) whenever _validation_dl is truthy. If a caller uses setup_validation_data() directly (common across models) then _validation_dl is typically a single DataLoader, so len(DataLoader) equals the number of batches. That will incorrectly create one output bucket per batch and can force on_validation_epoch_end() into the multi-dataloader branch (which expects _validation_names to be set), leading to crashes. Consider computing num_dl from the number of dataloaders only (e.g., len(_validation_dl) only when _validation_dl is a list/tuple of dataloaders; otherwise treat as 1, and keep the empty-list case consistent with the new [[]] semantics).
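The pitfall in this comment can be reproduced without PyTorch. In the sketch below, `FakeLoader` is a hypothetical stand-in whose `__len__` returns the number of batches, as a real `DataLoader` does; `naive_num_dl` mirrors the expression in the snippet and `safe_num_dl` is one possible reading of the suggestion, not the fix that was merged.

```python
from typing import Any


class FakeLoader:
    """Hypothetical stand-in: len() is the number of batches, like a DataLoader."""

    def __init__(self, num_batches: int) -> None:
        self.num_batches = num_batches

    def __len__(self) -> int:
        return self.num_batches


def naive_num_dl(validation_dl: Any) -> int:
    # Mirrors the reviewed expression: len() of whatever is stored, if truthy.
    return len(validation_dl) if validation_dl else 1


def safe_num_dl(validation_dl: Any) -> int:
    # Only trust len() when the value is a non-empty container of dataloaders.
    if isinstance(validation_dl, (list, tuple)) and len(validation_dl) > 0:
        return len(validation_dl)
    return 1  # bare loader, None, or empty list: keep the [[]] semantics


bare = FakeLoader(num_batches=250)
wrapped = [FakeLoader(num_batches=250)]
```

With a bare loader the naive version would allocate 250 buckets, one per batch, while the guarded version allocates one.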
nemo/core/classes/modelPT.py

    num_dl = len(self._test_dl) if self._test_dl else 1
    self._test_step_outputs = [[] for _ in range(num_dl)]
test_step_outputs uses len(self._test_dl) when _test_dl is a list/tuple. Since ModelPT.test_dataloader() sets _test_dl = [] when unset, this property can return [] (not a list-of-lists), which conflicts with the PR’s stated invariant and differs from validation_step_outputs (which returns [[]] for an empty list). Consider using the same logic as validation (treat empty list as the N=1 case, and only take len() when the container is non-empty).
      def _get_num_dataloaders(self, tag: str = 'val'):
          if tag == 'val':
    -         num_dataloaders = len(self._validation_dl) if isinstance(self._validation_dl, List) else 1
    +         num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
          elif tag == 'test':
    -         num_dataloaders = len(self._test_dl) if isinstance(self._test_dl, List) else 1
    +         num_dataloaders = len(self._test_dl) if self._test_dl else 1
          else:
_get_num_dataloaders() now returns 1 when _validation_dl is an empty list. This changes the meaning from “number of configured dataloaders” to “at least 1”, which can cause _setup_metrics() to initialize metrics for a non-existent dataloader. Also, isinstance(self._test_dl, List) uses typing.List, which raises TypeError at runtime for isinstance checks; this should be replaced with a runtime type like (list, tuple) (and likely the same empty-list handling as for validation).
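A small sketch of a count based on runtime types, in the spirit of this comment. `count_dataloaders` is a hypothetical helper; treating any non-container value as one bare dataloader is an assumption that mirrors the old `else 1` branch. The snippet also checks how `isinstance` treats `typing.List`: as far as this sketch assumes for current CPython, the subscripted form raises `TypeError` while the bare alias is accepted, so the switch to `(list, tuple)` is best read as a clarity and robustness measure.

```python
from typing import Any, List


def count_dataloaders(dl: Any) -> int:
    """Hypothetical helper: number of configured dataloaders."""
    # Container of loaders: the count is its length, so an empty list gives 0.
    if isinstance(dl, (list, tuple)):
        return len(dl)
    # Anything else is treated as a single bare dataloader (assumption).
    return 1


# Subscripted generics are rejected by isinstance at runtime.
try:
    isinstance([], List[int])
    subscripted_raises = False
except TypeError:
    subscripted_raises = True

# The bare alias is accepted by isinstance (assumed for current CPython).
bare_alias_ok = isinstance([], List)
```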
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
c422e93 to 53da4f2
What does this PR do ?
Unify `ModelPT.validation_step_outputs` (and `test_step_outputs`) to always return a list of lists, so a single dataloader is simply the N=1 case and subclasses no longer need to branch on the output shape. Normalize both `_validation_dl` and `_test_dl` to `Optional[List[DataLoader]]` via their respective resolvers.

Collection: Core, ASR, TTS, Audio
Changelog
- `modelPT.py`: `validation_step_outputs`/`test_step_outputs` properties always return `[[] for _ in range(num_dl)]`; `on_validation_epoch_end`/`on_test_epoch_end` use `len() == 1` instead of `isinstance(..., dict)` for single-vs-multi dispatch; empty-output guard updated to `all(len(o) == 0 for o in ...)` since `[[]]` is truthy; empty dataloader buckets skipped in multi-DL loop
- `model_utils.py`: `resolve_validation_dataloaders` and `resolve_test_dataloaders` wrap bare `DataLoader` into `[DataLoader]` at both single-value paths, normalizing `_validation_dl` and `_test_dl` to `Optional[List[DataLoader]]`
- `modelPT.py` (`setup_multiple_validation_data`): type annotation updated; `isinstance` guard simplified to truthiness check after normalization
- Model files: remove `if`/`else` branching in `validation_step`/`test_step`; always use `self.validation_step_outputs[dataloader_idx].append(...)`
- `transformer_bpe_models.py`: remove `isinstance(outputs[0], dict)` normalization loop in `multi_validation_epoch_end`; the base class now iterates dataloaders and calls it once per DL
- `audio_to_audio.py`: simplify `_get_num_dataloaders` (both val and test) and logging callback setup after normalization
- `fastpitch.py`, `magpietts.py`, `magpietts_preference_optimization.py`: add `RuntimeError` guard for `len(validation_step_outputs) != 1`; add early-return on empty outputs; use `self.validation_step_outputs[0]` consistently
- `ssl_models.py`: fix `EncDecMaskedTokenPredModel.test_step`, which was appending to `validation_step_outputs` instead of `test_step_outputs`
- Test models (`test_ema.py`, `check_for_ranks.py`, `test_ptl_stateless_timer.py`): override `multi_validation_epoch_end` instead of `on_validation_epoch_end`; base class handles iteration, clearing, and per-DL prefix
- `test_empty_epoch_outputs_skip_multi_epoch_end`: verify `multi_validation`/`test_epoch_end` is never called when all outputs are empty

Usage
No API changes for single-dataloader models: `dataloader_idx=0` is the default. Subclasses should use the `[dataloader_idx]` indexing pattern: