Skip to content

Commit 8341afe

Browse files
authored
fix(tests): clean up sys.modules pollution in training fixtures (#2168)
Two cooperating fixture bugs caused order-dependent failures in tests/unit_tests/training/: 1. test_timers.py:67 — `sys.modules["torch.distributed"] = dist_stub` bypassed pytest's monkeypatch finalizer, leaking the stub for the rest of the session. Any later test that imported torchao failed with `ModuleNotFoundError: No module named 'torch.distributed.<X>'; 'torch.distributed' is not a package`. Fixed by switching to `monkeypatch.setitem(sys.modules, ...)` which is auto-reverted. 2. test_train_ft_mlflow_logging.py:_install_fake_wandb — the fake wandb stub had `__spec__ = None`, so accelerate's `importlib.util.find_spec("wandb")` raised `ValueError: wandb.__spec__ is None`, breaking the test in isolation. Fixed by setting a valid ModuleSpec via `importlib.util.spec_from_loader`. Verified: `pytest tests/unit_tests/training/` 55 passed (previously 2 failed, both fixture bugs visible). Signed-off-by: Robert Luke <code@robertluke.net>
1 parent 660ed94 commit 8341afe

2 files changed

Lines changed: 4 additions & 1 deletion

File tree

tests/unit_tests/training/test_timers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _all_gather(dest: torch.Tensor, src: torch.Tensor): # noqa: D401
6464
dist_stub._all_gather_base = _all_gather
6565

6666
monkeypatch.setattr(torch, "distributed", dist_stub, raising=False)
67-
sys.modules["torch.distributed"] = dist_stub
67+
monkeypatch.setitem(sys.modules, "torch.distributed", dist_stub)
6868

6969
# Import the module *after* stubs are in place so it picks them up.
7070
# Force reload in case it was imported by another test module

tests/unit_tests/training/test_train_ft_mlflow_logging.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import importlib.util
1516
import sys
1617
import types
1718
from unittest.mock import Mock
@@ -26,6 +27,8 @@ def _install_fake_wandb():
2627
Provide a minimal 'wandb' package so train_ft can be imported without the real dependency.
2728
"""
2829
wandb = types.ModuleType("wandb")
30+
# accelerate calls importlib.util.find_spec("wandb"), which raises if __spec__ is None
31+
wandb.__spec__ = importlib.util.spec_from_loader("wandb", loader=None)
2932
wandb.run = None
3033

3134
class Settings:

0 commit comments

Comments
 (0)