Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions mlpstorage_py/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,20 @@ def parse_arguments():
if hasattr(parsed_args, 'config_file') and parsed_args.config_file:
parsed_args = apply_yaml_config_overrides(parsed_args)

# Consolidate the data access protocol into a single field
if parsed_args.file:
parsed_args.data_access_protocol = "file"
else:
parsed_args.data_access_protocol = parsed_args.object
del parsed_args.file
del parsed_args.object
# Consolidate the data access protocol into a single field.
# The --file / --object flags are only defined on benchmark subcommands
# that call add_storage_type_arguments() (training, checkpointing,
# vectordb, kvcache). Other subcommands (reports, history, lockfile)
# do not define them, so guard the consolidation on attribute presence.
if hasattr(parsed_args, "file") or hasattr(parsed_args, "object"):
if getattr(parsed_args, "file", False):
parsed_args.data_access_protocol = "file"
else:
parsed_args.data_access_protocol = getattr(parsed_args, "object", None)
# Clean up the raw flags so downstream code uses data_access_protocol.
for _attr in ("file", "object"):
if hasattr(parsed_args, _attr):
delattr(parsed_args, _attr)

"""
print(f"Arguments found: {parsed_args}")
Expand Down
87 changes: 87 additions & 0 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,3 +722,90 @@ def test_skips_none_values(self, tmp_path):
result = apply_yaml_config_overrides(args)
assert result.debug is True # Should not be overwritten
assert result.loops == 5

class TestParseArgumentsStorageFlagConsolidation:
    """Regression tests for issue #367.

    Subcommands that do not define --file / --object (reports, history,
    lockfile) must parse without crashing, while benchmark subcommands
    that do define them (training, checkpointing, vectordb, kvcache)
    must still have those flags folded into data_access_protocol.
    """

    @staticmethod
    def _run(monkeypatch, argv):
        """Patch sys.argv with *argv* and return the result of parse_arguments()."""
        # Imported lazily so the import happens at test time, not collection time.
        from mlpstorage_py.cli_parser import parse_arguments
        monkeypatch.setattr(sys, "argv", argv)
        return parse_arguments()

    @staticmethod
    def _training_run_argv(tmp_path, *storage_flags):
        """Build a `training run` argv ending with the given storage flag tokens."""
        return [
            "mlpstorage", "training", "run",
            "--model", "unet3d",
            "--hosts", "localhost",
            "--num-accelerators", "1",
            "--accelerator-type", "h100",
            "--client-host-memory-in-gb", "64",
            "--data-dir", str(tmp_path / "data"),
            "--results-dir", str(tmp_path / "results"),
            *storage_flags,
        ]

    @staticmethod
    def _assert_raw_flags_removed(parsed):
        """Assert that neither raw storage flag is present on the namespace."""
        assert not hasattr(parsed, "file")
        assert not hasattr(parsed, "object")

    # --- non-benchmark subcommands: must not raise AttributeError ---

    def test_reportgen_does_not_crash_without_storage_flags(self, monkeypatch, tmp_path):
        """Regression test for #367: `reports reportgen` must parse cleanly."""
        argv = ["mlpstorage", "reports", "reportgen", "--results-dir", str(tmp_path)]
        parsed = self._run(monkeypatch, argv)
        assert parsed.program == "reports"
        assert parsed.command == "reportgen"
        self._assert_raw_flags_removed(parsed)

    def test_history_does_not_crash_without_storage_flags(self, monkeypatch):
        """`history show` must parse cleanly (no --file/--object on this parser)."""
        parsed = self._run(monkeypatch, ["mlpstorage", "history", "show"])
        assert parsed.program == "history"
        assert parsed.command == "show"
        self._assert_raw_flags_removed(parsed)

    def test_lockfile_does_not_crash_without_storage_flags(self, monkeypatch):
        """`lockfile generate` must parse cleanly (no --file/--object on this parser)."""
        parsed = self._run(monkeypatch, ["mlpstorage", "lockfile", "generate"])
        assert parsed.program == "lockfile"
        self._assert_raw_flags_removed(parsed)

    # --- benchmark subcommands: existing consolidation must still work ---

    def test_training_run_consolidates_file_flag(self, monkeypatch, tmp_path):
        """`training run --file` should set data_access_protocol='file'."""
        argv = self._training_run_argv(tmp_path, "--file")
        parsed = self._run(monkeypatch, argv)
        assert parsed.data_access_protocol == "file"
        self._assert_raw_flags_removed(parsed)

    def test_training_run_consolidates_object_flag(self, monkeypatch, tmp_path):
        """`training run --object s3` should set data_access_protocol='s3'."""
        argv = self._training_run_argv(tmp_path, "--object", "s3")
        parsed = self._run(monkeypatch, argv)
        assert parsed.data_access_protocol == "s3"
        self._assert_raw_flags_removed(parsed)
Loading