Skip to content

Commit 14a3902

Browse files
authored
Fix save_hyperparameters ignore precedence in subclasses (#21490)
1 parent a25515e commit 14a3902

4 files changed

Lines changed: 48 additions & 2 deletions

File tree

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2626

2727
### Fixed
2828

29+
- Fixed `save_hyperparameters(ignore=...)` behavior so subclass ignore rules override base class rules (#[21490](https://github.com/Lightning-AI/pytorch-lightning/pull/21490))
30+
31+
2932
- Fixed `LightningDataModule.load_from_checkpoint` to restore the datamodule subclass and hyperparameters ([#21478](https://github.com/Lightning-AI/pytorch-lightning/pull/21478))
3033

3134

src/lightning/pytorch/core/mixins/hparams_mixin.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,19 @@ def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
138138
else:
139139
self._hparams = hp
140140

141+
def remove_ignored_hparams(self, ignore_list: list[str]) -> None:
142+
"""Remove ignored hyperparameters from the stored state.
143+
144+
This allows derived classes to drop hyperparameters previously saved
145+
by base classes.
146+
147+
Args:
148+
ignore_list: Names of hyperparameters to remove.
149+
150+
"""
151+
for key in ignore_list:
152+
self._hparams.pop(key, None)
153+
141154
@staticmethod
142155
def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]:
143156
if isinstance(hp, Namespace):

src/lightning/pytorch/utilities/parsing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def save_hyperparameters(
204204

205205
# `hparams` are expected here
206206
obj._set_hparams(hp)
207+
# Remove ignored hparams from the stored hyperparameters.
208+
# Allows a derived class to drop hparams previously saved by a base class.
209+
obj.remove_ignored_hparams(ignore)
207210

208211
for k, v in obj._hparams.items():
209212
if isinstance(v, nn.Module):

tests/tests_pytorch/models/test_hparams.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,9 @@ def __init__(self, arg1, arg2, arg3):
838838

839839
# test proper property assignments
840840
assert model.hparams.arg1 == 14
841-
for arg in ignore:
841+
842+
ignore_args = ignore if isinstance(ignore, (list, tuple)) else [ignore]
843+
for arg in ignore_args:
842844
assert arg not in model.hparams
843845

844846
# verify we can train
@@ -854,7 +856,32 @@ def __init__(self, arg1, arg2, arg3):
854856
# verify that model loads correctly
855857
model = LocalModel.load_from_checkpoint(raw_checkpoint_path, arg2=123, arg3=100)
856858
assert model.hparams.arg1 == 14
857-
for arg in ignore:
859+
for arg in ignore_args:
860+
assert arg not in model.hparams
861+
862+
863+
@pytest.mark.parametrize("ignore", ["arg2", ("arg2", "arg3")])
864+
def test_hparams_ignore_in_subclass_overrides_base(tmp_path, ignore):
865+
"""Test that hyperparameters can be ignored when `save_hyperparameters` is called in both a base class and a
866+
subclass, and that ignore rules defined in the subclass override hyperparameters saved by the base class."""
867+
868+
class BaseBoringModel(BoringModel):
869+
def __init__(self, arg1, arg2, arg3):
870+
super().__init__()
871+
self.save_hyperparameters(ignore="arg1")
872+
873+
class LocalModel(BaseBoringModel):
874+
def __init__(self, arg1, arg2, arg3):
875+
super().__init__(arg1=arg1, arg2=arg2, arg3=arg3)
876+
self.save_hyperparameters(ignore=ignore)
877+
878+
model = LocalModel(arg1=14, arg2=90, arg3=50)
879+
880+
# `arg1` was ignored by the base class,
881+
# but the subclass did not ignore it, so it should be present.
882+
assert model.hparams.arg1 == 14
883+
ignore_args = ignore if isinstance(ignore, (list, tuple)) else [ignore]
884+
for arg in ignore_args:
858885
assert arg not in model.hparams
859886

860887

0 commit comments

Comments
 (0)