Skip to content

Commit 901164a

Browse files
authored
CLI help and command help; misc improvements (#229)
Signed-off-by: Brian Yu <bxyu@nvidia.com>
1 parent 3bd86ae commit 901164a

9 files changed

Lines changed: 698 additions & 73 deletions

File tree

nemo_gym/cli.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import asyncio
1515
import json
1616
import shlex
17+
import tomllib
1718
from glob import glob
1819
from os import environ, makedirs
1920
from os.path import exists
@@ -23,13 +24,15 @@
2324
from time import sleep
2425
from typing import Dict, List, Optional
2526

27+
import rich
2628
import uvicorn
2729
from devtools import pprint
2830
from omegaconf import DictConfig, OmegaConf
29-
from pydantic import BaseModel
31+
from pydantic import BaseModel, Field
3032
from tqdm.auto import tqdm
3133

3234
from nemo_gym import PARENT_DIR
35+
from nemo_gym.config_types import BaseNeMoGymCLIConfig
3336
from nemo_gym.global_config import (
3437
NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME,
3538
NEMO_GYM_CONFIG_PATH_ENV_VAR_NAME,
@@ -60,23 +63,32 @@ def _run_command(command: str, working_directory: Path) -> Popen: # pragma: no
6063
return Popen(command, executable="/bin/bash", shell=True, env=custom_env)
6164

6265

63-
class RunConfig(BaseModel):
64-
entrypoint: str
66+
class RunConfig(BaseNeMoGymCLIConfig):
67+
entrypoint: str = Field(
68+
description="Entrypoint for this command. This must be a relative path with 2 parts. Should look something like `responses_api_agents/simple_agent`."
69+
)
6570

6671

6772
class TestConfig(RunConfig):
68-
should_validate_data: bool = False
73+
should_validate_data: bool = Field(
74+
default=False,
75+
description="Whether or not to validate the example data (examples, metrics, rollouts, etc) for this server.",
76+
)
6977

70-
dir_path: Path = None # initialized in model_post_init
78+
_dir_path: Path # initialized in model_post_init
7179

7280
def model_post_init(self, context): # pragma: no cover
7381
# TODO: This currently only handles relative entrypoints. Later on we can resolve the absolute path.
74-
self.dir_path = Path(self.entrypoint)
82+
self._dir_path = Path(self.entrypoint)
7583
assert not self.dir_path.is_absolute()
7684
assert len(self.dir_path.parts) == 2
7785

7886
return super().model_post_init(context)
7987

88+
@property
89+
def dir_path(self) -> Path:
90+
return self._dir_path
91+
8092

8193
class ServerInstanceDisplayConfig(BaseModel):
8294
process_name: str
@@ -274,6 +286,10 @@ def check_http_server_statuses(self) -> List[ServerStatus]:
274286
def run(
275287
global_config_dict_parser_config: Optional[GlobalConfigDictParserConfig] = None,
276288
): # pragma: no cover
289+
global_config_dict = get_global_config_dict(global_config_dict_parser_config=global_config_dict_parser_config)
290+
# Just here for help
291+
BaseNeMoGymCLIConfig.model_validate(global_config_dict)
292+
277293
rh = RunHelper()
278294
rh.start(global_config_dict_parser_config)
279295
rh.run_forever()
@@ -386,8 +402,11 @@ def _format_pct(count: int, total: int) -> str: # pragma: no cover
386402
return f"{count} / {total} ({100 * count / total:.2f}%)"
387403

388404

389-
class TestAllConfig(BaseModel):
390-
fail_on_total_and_test_mismatch: bool = False
405+
class TestAllConfig(BaseNeMoGymCLIConfig):
406+
fail_on_total_and_test_mismatch: bool = Field(
407+
default=False,
408+
description="There may be situations where there are an un-equal number of servers that exist vs have tests. This flag will fail the test job if this mismatch exists.",
409+
)
391410

392411

393412
def test_all(): # pragma: no cover
@@ -474,6 +493,10 @@ def test_all(): # pragma: no cover
474493

475494

476495
def dev_test(): # pragma: no cover
496+
global_config_dict = get_global_config_dict()
497+
# Just here for help
498+
BaseNeMoGymCLIConfig.model_validate(global_config_dict)
499+
477500
proc = Popen("pytest --cov=. --durations=10", shell=True)
478501
exit(proc.wait())
479502

@@ -592,4 +615,28 @@ def init_resources_server(): # pragma: no cover
592615

593616
def dump_config(): # pragma: no cover
594617
global_config_dict = get_global_config_dict()
618+
# Just here for help
619+
BaseNeMoGymCLIConfig.model_validate(global_config_dict)
620+
595621
print(OmegaConf.to_yaml(global_config_dict, resolve=True))
622+
623+
624+
def display_help(): # pragma: no cover
625+
global_config_dict = get_global_config_dict()
626+
# Just here for help
627+
BaseNeMoGymCLIConfig.model_validate(global_config_dict)
628+
629+
pyproject_path = Path(PARENT_DIR) / "pyproject.toml"
630+
with pyproject_path.open("rb") as f:
631+
pyproject_data = tomllib.load(f)
632+
633+
project_scripts = pyproject_data["project"]["scripts"]
634+
rich.print("""Run a command with `+h=true` or `+help=true` to see more detailed information!
635+
636+
[bold]Available CLI scripts[/bold]
637+
-----------------""")
638+
for script in project_scripts:
639+
if not script.startswith("ng_"):
640+
continue
641+
642+
print(script)

nemo_gym/config_types.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from enum import Enum
1515
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
1616

17+
import rich
1718
from omegaconf import DictConfig, OmegaConf
1819
from pydantic import (
1920
BaseModel,
@@ -23,6 +24,60 @@
2324
ValidationError,
2425
model_validator,
2526
)
27+
from rich.text import Text
28+
29+
30+
########################################
31+
# Base CLI configs
32+
########################################
33+
34+
35+
class BaseNeMoGymCLIConfig(BaseModel):
36+
@model_validator(mode="before")
37+
@classmethod
38+
def pre_process(cls, data):
39+
if not (data.get("h") or data.get("help")):
40+
return data
41+
42+
rich.print(f"""Displaying help for [bold]{cls.__name__}[/bold]
43+
""")
44+
# We use __doc__ directly here since inspect.getdoc will inherit the doc from parent classes.
45+
class_doc = cls.__doc__
46+
if class_doc:
47+
rich.print(f"""[bold]Description[/bold]
48+
-----------
49+
{class_doc.strip()}
50+
""")
51+
52+
fields = cls.model_fields.items()
53+
if fields:
54+
rich.print("""[bold]Parameters[/bold]
55+
----------""")
56+
57+
prefixes: List[Text] = []
58+
suffixes: List[Text] = []
59+
for field_name, field in fields:
60+
description_str = field.description if field.description else ""
61+
62+
# Not sure if there is a better way to get this annotation_str, e.g. using typing.get_args or typing.get_origin
63+
annotation_str = (
64+
field.annotation.__name__ if isinstance(field.annotation, type) else str(field.annotation)
65+
)
66+
annotation_str = annotation_str.replace("typing.", "")
67+
68+
prefixes.append(Text.from_markup(f"- [blue]{field_name}[/blue] [yellow]({annotation_str})[/yellow]"))
69+
suffixes.append(description_str)
70+
71+
max_prefix_length = max(map(len, prefixes))
72+
ljust_length = max_prefix_length + 3
73+
for prefix, suffix in zip(prefixes, suffixes):
74+
prefix.align("left", ljust_length)
75+
rich.print(prefix + suffix)
76+
else:
77+
print("There are no arguments to this CLI command!")
78+
79+
# Exit after help is printed.
80+
exit()
2681

2782

2883
########################################
@@ -63,10 +118,14 @@ def is_server_ref(config_dict: DictConfig) -> Optional[ServerRef]:
63118
########################################
64119

65120

66-
class UploadJsonlDatasetGitlabConfig(BaseModel):
67-
dataset_name: str
68-
version: str # Must be x.x.x
69-
input_jsonl_fpath: str
121+
class UploadJsonlDatasetGitlabConfig(BaseNeMoGymCLIConfig):
122+
"""
123+
Upload a local jsonl dataset artifact to Gitlab.
124+
"""
125+
126+
dataset_name: str = Field(description="The dataset name.")
127+
version: str = Field(description="The version of this dataset. Must be in the format `x.x.x`.")
128+
input_jsonl_fpath: str = Field(description="Path to the jsonl file to upload.")
70129

71130

72131
class JsonlDatasetGitlabIdentifer(BaseModel):
@@ -75,8 +134,11 @@ class JsonlDatasetGitlabIdentifer(BaseModel):
75134
artifact_fpath: str
76135

77136

78-
class DownloadJsonlDatasetGitlabConfig(JsonlDatasetGitlabIdentifer):
79-
output_fpath: str
137+
class DownloadJsonlDatasetGitlabConfig(JsonlDatasetGitlabIdentifer, BaseNeMoGymCLIConfig):
138+
dataset_name: str = Field(description="The dataset name.")
139+
version: str = Field(description="The version of this dataset. Must be in the format `x.x.x`.")
140+
artifact_fpath: str = Field(description="The filepath to the artifact to download.")
141+
output_fpath: str = Field(description="Where to save the downloaded dataset.")
80142

81143

82144
DatasetType = Union[Literal["train"], Literal["validation"], Literal["example"]]

nemo_gym/dataset_viewer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323
ResponseInputItemParam,
2424
ResponseReasoningItemParam,
2525
)
26-
from pydantic import BaseModel, ConfigDict
26+
from pydantic import ConfigDict, Field
2727
from tqdm.auto import tqdm
2828

2929
from nemo_gym.base_resources_server import BaseVerifyResponse
30+
from nemo_gym.config_types import BaseNeMoGymCLIConfig
3031
from nemo_gym.server_utils import get_global_config_dict
3132
from nemo_gym.train_data_utils import (
3233
DatasetMetrics,
@@ -203,8 +204,8 @@ def extra_info_to_messages(d: DatasetViewerVerifyResponse) -> List[ChatMessage]:
203204
return messages
204205

205206

206-
class JsonlDatasetViewerConfig(BaseModel):
207-
jsonl_fpath: str
207+
class JsonlDatasetViewerConfig(BaseNeMoGymCLIConfig):
208+
jsonl_fpath: str = Field(description="Filepath to a local jsonl file to view.")
208209

209210

210211
def get_aggregate_metrics(data: List[DatasetViewerVerifyResponse]) -> Dict[str, Any]:

nemo_gym/openai_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ class NeMoGymChatCompletionCreateParamsNonStreaming(BaseModel):
422422
########################################
423423

424424

425-
class NeMoGymAsyncOpenAI(BaseModel):
425+
class NeMoGymAsyncOpenAI(BaseModel): # pragma: no cover
426426
"""This is just a stub class that wraps around aiohttp"""
427427

428428
base_url: str

nemo_gym/rollout_collection.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from pydantic import BaseModel, Field
2323
from tqdm.asyncio import tqdm
2424

25-
from nemo_gym.config_types import BaseServerConfig
25+
from nemo_gym.config_types import BaseNeMoGymCLIConfig, BaseServerConfig
2626
from nemo_gym.server_utils import (
2727
GlobalAIOHTTPAsyncClientConfig,
2828
ServerClient,
@@ -32,14 +32,30 @@
3232
)
3333

3434

35-
class RolloutCollectionConfig(BaseModel):
36-
agent_name: str
37-
input_jsonl_fpath: str
38-
output_jsonl_fpath: str
39-
limit: Optional[int] = None
40-
num_repeats: Optional[int] = None
41-
num_samples_in_parallel: Optional[int] = None
42-
responses_create_params: Dict[str, Any] = Field(default_factory=dict)
35+
class RolloutCollectionConfig(BaseNeMoGymCLIConfig):
36+
"""
37+
Perform a batch of rollout collection.
38+
"""
39+
40+
agent_name: str = Field(description="The agent to collect rollouts from.")
41+
input_jsonl_fpath: str = Field(
42+
description="The input data source to use to collect rollouts, in the form of a file path to a jsonl file."
43+
)
44+
output_jsonl_fpath: str = Field(description="The output data jsonl file path.")
45+
limit: Optional[int] = Field(
46+
default=None, description="Maximum number of examples to load and take from the input dataset."
47+
)
48+
num_repeats: Optional[int] = Field(
49+
default=None,
50+
description="The number of times to repeat each example to run. Useful if you want to calculate mean@k e.g. mean@4 or mean@16.",
51+
)
52+
num_samples_in_parallel: Optional[int] = Field(
53+
default=None, description="Limit the number of concurrent samples running at once."
54+
)
55+
responses_create_params: Dict[str, Any] = Field(
56+
default_factory=dict,
57+
description="Overrides for the responses_create_params e.g. temperature, max_output_tokens, etc.",
58+
)
4359

4460

4561
class RolloutCollectionHelper(BaseModel): # pragma: no cover

nemo_gym/train_data_utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from nemo_gym.config_types import (
3030
AGENT_REF_KEY,
3131
AgentServerRef,
32+
BaseNeMoGymCLIConfig,
3233
DatasetConfig,
3334
DatasetType,
3435
DownloadJsonlDatasetGitlabConfig,
@@ -42,10 +43,15 @@
4243
)
4344

4445

45-
class TrainDataProcessorConfig(BaseModel):
46-
output_dirpath: str
47-
mode: Union[Literal["train_preparation"], Literal["example_validation"]]
48-
should_download: bool = False
46+
class TrainDataProcessorConfig(BaseNeMoGymCLIConfig):
47+
output_dirpath: str = Field(description="Path to the directory to save the outputs.")
48+
mode: Union[Literal["train_preparation"], Literal["example_validation"]] = Field(
49+
description="Whether to do train_preparation or example_validation."
50+
)
51+
should_download: bool = Field(
52+
default=False,
53+
description="Whether or not to download missing datasets. By default, no datasets will be downloaded.",
54+
)
4955

5056
@property
5157
def in_scope_dataset_types(self) -> List[DatasetType]:

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ ng_prepare_data = "nemo_gym.train_data_utils:prepare_data"
246246
nemo_gym_dump_config = "nemo_gym.cli:dump_config"
247247
ng_dump_config = "nemo_gym.cli:dump_config"
248248

249+
# Display help
250+
nemo_gym_help = "nemo_gym.cli:display_help"
251+
ng_help = "nemo_gym.cli:display_help"
252+
249253

250254
[tool.setuptools.packages.find]
251255
where = ["."]

tests/unit_tests/test_cli.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,48 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import sys
15+
import tomllib
16+
from importlib import import_module
17+
from io import StringIO
18+
from pathlib import Path
19+
20+
from omegaconf import OmegaConf
21+
from pytest import MonkeyPatch, raises
22+
23+
import nemo_gym.global_config
24+
from nemo_gym import PARENT_DIR
1425
from nemo_gym.cli import RunConfig
1526

1627

1728
# TODO: Eventually we want to add more tests to ensure that the CLI flows do not break
1829
class TestCLI:
1930
def test_sanity(self) -> None:
2031
RunConfig(entrypoint="", name="")
32+
33+
def test_pyproject_scripts(self) -> None:
34+
pyproject_path = Path(PARENT_DIR) / "pyproject.toml"
35+
with pyproject_path.open("rb") as f:
36+
pyproject_data = tomllib.load(f)
37+
38+
project_scripts = pyproject_data["project"]["scripts"]
39+
40+
for script_name, import_path in project_scripts.items():
41+
# Dedupe `nemo_gym_*` from `ng_*` commands
42+
if not script_name.startswith("ng_"):
43+
continue
44+
45+
# We only test `+h=true` and not `+help=true`
46+
print(f"Running `{script_name} +h=true`")
47+
48+
module, fn = import_path.split(":")
49+
fn = getattr(import_module(module), fn)
50+
51+
with MonkeyPatch.context() as mp:
52+
mp.setattr(nemo_gym.global_config, "_GLOBAL_CONFIG_DICT", OmegaConf.create({"h": True}))
53+
54+
text_trap = StringIO()
55+
mp.setattr(sys, "stdout", text_trap)
56+
57+
with raises(SystemExit):
58+
fn()

0 commit comments

Comments
 (0)