|
14 | 14 | import asyncio |
15 | 15 | import json |
16 | 16 | import shlex |
| 17 | +import tomllib |
17 | 18 | from glob import glob |
18 | 19 | from os import environ, makedirs |
19 | 20 | from os.path import exists |
|
23 | 24 | from time import sleep |
24 | 25 | from typing import Dict, List, Optional |
25 | 26 |
|
| 27 | +import rich |
26 | 28 | import uvicorn |
27 | 29 | from devtools import pprint |
28 | 30 | from omegaconf import DictConfig, OmegaConf |
29 | | -from pydantic import BaseModel |
| 31 | +from pydantic import BaseModel, Field |
30 | 32 | from tqdm.auto import tqdm |
31 | 33 |
|
32 | 34 | from nemo_gym import PARENT_DIR |
| 35 | +from nemo_gym.config_types import BaseNeMoGymCLIConfig |
33 | 36 | from nemo_gym.global_config import ( |
34 | 37 | NEMO_GYM_CONFIG_DICT_ENV_VAR_NAME, |
35 | 38 | NEMO_GYM_CONFIG_PATH_ENV_VAR_NAME, |
@@ -60,23 +63,32 @@ def _run_command(command: str, working_directory: Path) -> Popen: # pragma: no |
60 | 63 | return Popen(command, executable="/bin/bash", shell=True, env=custom_env) |
61 | 64 |
|
62 | 65 |
|
63 | | -class RunConfig(BaseModel): |
64 | | - entrypoint: str |
| 66 | +class RunConfig(BaseNeMoGymCLIConfig): |
| 67 | + entrypoint: str = Field( |
| 68 | + description="Entrypoint for this command. This must be a relative path with 2 parts. Should look something like `responses_api_agents/simple_agent`." |
| 69 | + ) |
65 | 70 |
|
66 | 71 |
|
67 | 72 | class TestConfig(RunConfig): |
68 | | - should_validate_data: bool = False |
| 73 | + should_validate_data: bool = Field( |
| 74 | + default=False, |
| 75 | + description="Whether or not to validate the example data (examples, metrics, rollouts, etc) for this server.", |
| 76 | + ) |
69 | 77 |
|
70 | | - dir_path: Path = None # initialized in model_post_init |
| 78 | + _dir_path: Path # initialized in model_post_init |
71 | 79 |
|
72 | 80 | def model_post_init(self, context): # pragma: no cover |
73 | 81 | # TODO: This currently only handles relative entrypoints. Later on we can resolve the absolute path. |
74 | | - self.dir_path = Path(self.entrypoint) |
| 82 | + self._dir_path = Path(self.entrypoint) |
75 | 83 | assert not self.dir_path.is_absolute() |
76 | 84 | assert len(self.dir_path.parts) == 2 |
77 | 85 |
|
78 | 86 | return super().model_post_init(context) |
79 | 87 |
|
| 88 | + @property |
| 89 | + def dir_path(self) -> Path: |
| 90 | + return self._dir_path |
| 91 | + |
80 | 92 |
|
81 | 93 | class ServerInstanceDisplayConfig(BaseModel): |
82 | 94 | process_name: str |
@@ -274,6 +286,10 @@ def check_http_server_statuses(self) -> List[ServerStatus]: |
274 | 286 | def run( |
275 | 287 | global_config_dict_parser_config: Optional[GlobalConfigDictParserConfig] = None, |
276 | 288 | ): # pragma: no cover |
| 289 | + global_config_dict = get_global_config_dict(global_config_dict_parser_config=global_config_dict_parser_config) |
| 290 | + # Just here for help |
| 291 | + BaseNeMoGymCLIConfig.model_validate(global_config_dict) |
| 292 | + |
277 | 293 | rh = RunHelper() |
278 | 294 | rh.start(global_config_dict_parser_config) |
279 | 295 | rh.run_forever() |
@@ -386,8 +402,11 @@ def _format_pct(count: int, total: int) -> str: # pragma: no cover |
386 | 402 | return f"{count} / {total} ({100 * count / total:.2f}%)" |
387 | 403 |
|
388 | 404 |
|
389 | | -class TestAllConfig(BaseModel): |
390 | | - fail_on_total_and_test_mismatch: bool = False |
| 405 | +class TestAllConfig(BaseNeMoGymCLIConfig): |
| 406 | + fail_on_total_and_test_mismatch: bool = Field( |
| 407 | + default=False, |
| 408 | + description="There may be situations where there are an un-equal number of servers that exist vs have tests. This flag will fail the test job if this mismatch exists.", |
| 409 | + ) |
391 | 410 |
|
392 | 411 |
|
393 | 412 | def test_all(): # pragma: no cover |
@@ -474,6 +493,10 @@ def test_all(): # pragma: no cover |
474 | 493 |
|
475 | 494 |
|
476 | 495 | def dev_test(): # pragma: no cover |
| 496 | + global_config_dict = get_global_config_dict() |
| 497 | + # Just here for help |
| 498 | + BaseNeMoGymCLIConfig.model_validate(global_config_dict) |
| 499 | + |
477 | 500 | proc = Popen("pytest --cov=. --durations=10", shell=True) |
478 | 501 | exit(proc.wait()) |
479 | 502 |
|
@@ -592,4 +615,28 @@ def init_resources_server(): # pragma: no cover |
592 | 615 |
|
593 | 616 | def dump_config(): # pragma: no cover |
594 | 617 | global_config_dict = get_global_config_dict() |
| 618 | + # Just here for help |
| 619 | + BaseNeMoGymCLIConfig.model_validate(global_config_dict) |
| 620 | + |
595 | 621 | print(OmegaConf.to_yaml(global_config_dict, resolve=True)) |
| 622 | + |
| 623 | + |
| 624 | +def display_help(): # pragma: no cover |
| 625 | + global_config_dict = get_global_config_dict() |
| 626 | + # Just here for help |
| 627 | + BaseNeMoGymCLIConfig.model_validate(global_config_dict) |
| 628 | + |
| 629 | + pyproject_path = Path(PARENT_DIR) / "pyproject.toml" |
| 630 | + with pyproject_path.open("rb") as f: |
| 631 | + pyproject_data = tomllib.load(f) |
| 632 | + |
| 633 | + project_scripts = pyproject_data["project"]["scripts"] |
| 634 | + rich.print("""Run a command with `+h=true` or `+help=true` to see more detailed information! |
| 635 | +
|
| 636 | +[bold]Available CLI scripts[/bold] |
| 637 | +-----------------""") |
| 638 | + for script in project_scripts: |
| 639 | + if not script.startswith("ng_"): |
| 640 | + continue |
| 641 | + |
| 642 | + print(script) |
0 commit comments