diff --git a/mlpstorage_py/cli_parser.py b/mlpstorage_py/cli_parser.py
index af32669f..cafecb98 100755
--- a/mlpstorage_py/cli_parser.py
+++ b/mlpstorage_py/cli_parser.py
@@ -115,13 +115,20 @@ def parse_arguments():
     if hasattr(parsed_args, 'config_file') and parsed_args.config_file:
         parsed_args = apply_yaml_config_overrides(parsed_args)
 
-    # Consolidate the data access protocol into a single field
-    if parsed_args.file:
-        parsed_args.data_access_protocol = "file"
-    else:
-        parsed_args.data_access_protocol = parsed_args.object
-    del parsed_args.file
-    del parsed_args.object
+    # Consolidate the data access protocol into a single field.
+    # The --file / --object flags are only defined on benchmark subcommands
+    # that call add_storage_type_arguments() (training, checkpointing,
+    # vectordb, kvcache). Other subcommands (reports, history, lockfile)
+    # do not define them, so guard the consolidation on attribute presence.
+    if hasattr(parsed_args, "file") or hasattr(parsed_args, "object"):
+        if getattr(parsed_args, "file", False):
+            parsed_args.data_access_protocol = "file"
+        else:
+            parsed_args.data_access_protocol = getattr(parsed_args, "object", None)
+        # Clean up the raw flags so downstream code uses data_access_protocol.
+        for _attr in ("file", "object"):
+            if hasattr(parsed_args, _attr):
+                delattr(parsed_args, _attr)
 
     """
     print(f"Arguments found: {parsed_args}")
diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py
index 236a2f5b..43f5206d 100755
--- a/tests/unit/test_cli.py
+++ b/tests/unit/test_cli.py
@@ -722,3 +722,90 @@ def test_skips_none_values(self, tmp_path):
         result = apply_yaml_config_overrides(args)
         assert result.debug is True  # Should not be overwritten
         assert result.loops == 5
+
+class TestParseArgumentsStorageFlagConsolidation:
+    """Regression tests for issue #367.
+
+    The CLI parser must not crash when a subcommand that doesn't define
+    --file / --object (reports, history, lockfile) is invoked, and must
+    still correctly consolidate those flags into data_access_protocol on
+    benchmark subcommands that do define them (training, checkpointing,
+    vectordb, kvcache).
+    """
+
+    @staticmethod
+    def _run(monkeypatch, argv):
+        """Invoke parse_arguments() with a synthetic sys.argv."""
+        from mlpstorage_py.cli_parser import parse_arguments
+        monkeypatch.setattr(sys, "argv", argv)
+        return parse_arguments()
+
+    # --- non-benchmark subcommands: must not raise AttributeError ---
+
+    def test_reportgen_does_not_crash_without_storage_flags(self, monkeypatch, tmp_path):
+        """Regression test for #367: `reports reportgen` must parse cleanly."""
+        args = self._run(
+            monkeypatch,
+            ["mlpstorage", "reports", "reportgen", "--results-dir", str(tmp_path)],
+        )
+        assert args.program == "reports"
+        assert args.command == "reportgen"
+        assert not hasattr(args, "file")
+        assert not hasattr(args, "object")
+
+    def test_history_does_not_crash_without_storage_flags(self, monkeypatch):
+        """`history show` must parse cleanly (no --file/--object on this parser)."""
+        args = self._run(monkeypatch, ["mlpstorage", "history", "show"])
+        assert args.program == "history"
+        assert args.command == "show"
+        assert not hasattr(args, "file")
+        assert not hasattr(args, "object")
+
+    def test_lockfile_does_not_crash_without_storage_flags(self, monkeypatch):
+        """`lockfile generate` must parse cleanly (no --file/--object on this parser)."""
+        args = self._run(monkeypatch, ["mlpstorage", "lockfile", "generate"])
+        assert args.program == "lockfile"
+        assert not hasattr(args, "file")
+        assert not hasattr(args, "object")
+
+    # --- benchmark subcommands: existing consolidation must still work ---
+
+    def test_training_run_consolidates_file_flag(self, monkeypatch, tmp_path):
+        """`training run --file` should set data_access_protocol='file'."""
+        args = self._run(
+            monkeypatch,
+            [
+                "mlpstorage", "training", "run",
+                "--model", "unet3d",
+                "--hosts", "localhost",
+                "--num-accelerators", "1",
+                "--accelerator-type", "h100",
+                "--client-host-memory-in-gb", "64",
+                "--data-dir", str(tmp_path / "data"),
+                "--results-dir", str(tmp_path / "results"),
+                "--file",
+            ],
+        )
+        assert args.data_access_protocol == "file"
+        assert not hasattr(args, "file")
+        assert not hasattr(args, "object")
+
+    def test_training_run_consolidates_object_flag(self, monkeypatch, tmp_path):
+        """`training run --object s3` should set data_access_protocol='s3'."""
+        args = self._run(
+            monkeypatch,
+            [
+                "mlpstorage", "training", "run",
+                "--model", "unet3d",
+                "--hosts", "localhost",
+                "--num-accelerators", "1",
+                "--accelerator-type", "h100",
+                "--client-host-memory-in-gb", "64",
+                "--data-dir", str(tmp_path / "data"),
+                "--results-dir", str(tmp_path / "results"),
+                "--object", "s3",
+            ],
+        )
+        assert args.data_access_protocol == "s3"
+        assert not hasattr(args, "file")
+        assert not hasattr(args, "object")