Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/scripts/aggregate_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from pydantic import Field

from prime_rl.utils.pydantic_config import BaseSettings, parse_argv
from prime_rl.utils.config import BaseConfig, cli

SHORTENED_ATTN_MAPPING = {
"flash_attention_2": "FA2",
Expand All @@ -23,7 +23,7 @@
DEVICE_NAME_STRIP_WORDS = ["NVIDIA", "RTX", "80GB", "40GB"]


class AggregateConfig(BaseSettings):
class AggregateConfig(BaseConfig):
"""Configuration for aggregating benchmark results."""

artifacts_dir: Annotated[Path, Field(description="Directory containing benchmark JSON artifacts")]
Expand Down Expand Up @@ -197,7 +197,7 @@ def generate_markdown(


def main():
config = parse_argv(AggregateConfig)
config = cli(AggregateConfig)

results = load_json_dir(config.artifacts_dir)
print(f"Loaded {len(results)} benchmark results", file=sys.stderr)
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/scripts/run_single_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from pydantic import Field

from prime_rl.utils.pydantic_config import BaseSettings, parse_argv
from prime_rl.utils.config import BaseConfig, cli

MAX_LORAS = 4

Expand All @@ -46,7 +46,7 @@ def extract_oom_error_reason(output: str) -> str | None:
return None


class BenchmarkConfig(BaseSettings):
class BenchmarkConfig(BaseConfig):
"""Configuration for running a single benchmark."""

type: Annotated[
Expand Down Expand Up @@ -248,7 +248,7 @@ def run_benchmark(config: BenchmarkConfig) -> None:


def main():
config = parse_argv(BenchmarkConfig)
config = cli(BenchmarkConfig)
run_benchmark(config)


Expand Down
3 changes: 1 addition & 2 deletions examples/alphabet_sort/slurm_rl.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
toml_files = ["rl.toml"]

# Usage: uv run rl @ examples/alphabet_sort/rl.toml @ examples/alphabet_sort/slurm_rl.toml
output_dir = "outputs/alphabet-sort"

[slurm]
Expand Down
3 changes: 1 addition & 2 deletions examples/hendrycks_sanity/slurm_rl.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
toml_files = ["rl.toml"]

# Usage: uv run rl @ examples/hendrycks_sanity/rl.toml @ examples/hendrycks_sanity/slurm_rl.toml
output_dir = "outputs/hendrycks-sanity"

[slurm]
Expand Down
3 changes: 1 addition & 2 deletions examples/reverse_text/slurm_rl.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
toml_files = ["rl.toml"]

# Usage: uv run rl @ examples/reverse_text/rl.toml @ examples/reverse_text/slurm_rl.toml
output_dir = "outputs/reverse-text-rl"

[slurm]
Expand Down
3 changes: 1 addition & 2 deletions examples/reverse_text/slurm_sft.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
toml_files = ["sft.toml"]

# Usage: uv run sft @ examples/reverse_text/sft.toml @ examples/reverse_text/slurm_sft.toml
output_dir = "outputs/reverse-text-sft"

[slurm]
Expand Down
3 changes: 1 addition & 2 deletions examples/wiki_search/slurm_rl.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
toml_files = ["rl.toml"]

# Usage: uv run rl @ examples/wiki_search/rl.toml @ examples/wiki_search/slurm_rl.toml
output_dir = "outputs/wiki-search"

[slurm]
Expand Down
3 changes: 1 addition & 2 deletions examples/wordle/slurm_rl.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
toml_files = ["rl.toml"]

# Usage: uv run rl @ examples/wordle/rl.toml @ examples/wordle/slurm_rl.toml
output_dir = "outputs/wordle-rl"

[slurm]
Expand Down
3 changes: 1 addition & 2 deletions examples/wordle/slurm_sft.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
toml_files = ["sft.toml"]

# Usage: uv run sft @ examples/wordle/sft.toml @ examples/wordle/slurm_sft.toml
output_dir = "outputs/wordle-sft"

[slurm]
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies = [
"loguru>=0.7.3",
"pyarrow>=21.0.0",
"pydantic>=1.10.13",
"pydantic-settings>=2.12.0",
"pydantic-config",
"tomli>=2.2.1",
"tomli-w>=1.2.0",
"numpy>=2.2.6",
Expand Down Expand Up @@ -91,6 +91,7 @@ dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" }
transformers = { git = "https://github.com/huggingface/transformers.git", rev = "609e3d5" }
vllm = { url = "https://vllm-wheels.s3.us-west-2.amazonaws.com/7a06e5b05b170d7da31845866da0a99fc65253a1/vllm-0.16.0rc3-cp38-abi3-manylinux_2_31_x86_64.whl" }
flash-attn-cute = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "main" }
pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" }
reverse-text = { index = "primeintellect" }

[tool.uv.extra-build-dependencies]
Expand Down
42 changes: 20 additions & 22 deletions skills/toml-config/SKILL.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ description: How to write and use TOML configs in prime-rl. Use when creating co

# TOML Config

All prime-rl commands use pydantic-settings with TOML configs and CLI overrides.
All prime-rl commands use `pydantic_config` (tyro-backed) with TOML configs and CLI overrides.

## Running with configs

Expand All @@ -19,11 +19,17 @@ uv run rl @ configs/debug/rl/train.toml
uv run inference @ config.toml --model.name Qwen/Qwen3-0.6B --server.port 8001

# Boolean flags: no value needed
uv run inference --model.enforce_eager # sets to true
uv run inference --no-model.enforce_eager # sets to false
uv run inference --model.enforce-eager # sets to true
uv run inference --no-model.enforce-eager # sets to false

# CLI-only (no TOML file)
uv run inference --model.name Qwen/Qwen3-0.6B --model.max_model_len 2048
uv run inference --model.name Qwen/Qwen3-0.6B --model.max-model-len 2048

# Compose multiple config files (later files override earlier ones)
uv run rl @ examples/reverse_text/rl.toml @ examples/reverse_text/slurm_rl.toml

# Nested config files: load a config for a specific section
uv run rl --model @ model.toml --data @ data.toml
```

## TOML structure
Expand All @@ -46,19 +52,6 @@ port = 8000

Putting a top-level field after a section header nests it inside that section, which causes validation errors.

## Config inheritance

Configs can inherit from other TOML files:

```toml
toml_files = ["base.toml"]

[model]
name = "Qwen/Qwen3-0.6B" # overrides base
```

Paths in `toml_files` are relative to the file containing the field.

## Setting None

Use the string `"None"` in TOML to set a field to None:
Expand All @@ -71,6 +64,11 @@ max_model_len = "None"

Both `rl` and `sft` commands support SLURM execution via an optional `[slurm]` section. When present, the run is submitted as a SLURM job instead of running locally.

SLURM configs are composed with the base config via CLI:
```bash
uv run rl @ examples/reverse_text/rl.toml @ examples/reverse_text/slurm_rl.toml
```

### RL SLURM

```toml
Expand Down Expand Up @@ -133,9 +131,9 @@ All accept `@ config.toml` and CLI overrides:

## Key files

- `src/prime_rl/utils/pydantic_config.py` — `parse_argv`, `BaseSettings`, `@` syntax parsing
- `src/prime_rl/rl.py` — unified RL entrypoint (local + SLURM)
- `src/prime_rl/configs/rl.py` — `RLConfig`, `SlurmConfig, DeploymentConfig`, `write_subconfigs`
- `src/prime_rl/trainer/sft/train.py` — unified SFT entrypoint (local + SLURM)
- `src/prime_rl/configs/sft.py` — `SFTConfig`, `SFTSlurmConfig`
- `src/prime_rl/utils/config.py` — `BaseConfig`, `cli`, `get_all_fields`
- `src/prime_rl/entrypoints/rl.py` — unified RL entrypoint (local + SLURM)
- `src/prime_rl/configs/rl.py` — `RLConfig`, `SlurmConfig, DeploymentConfig`
- `src/prime_rl/entrypoints/sft.py` — unified SFT entrypoint (local + SLURM)
- `src/prime_rl/configs/sft.py` — `SFTConfig`
- `configs/` — all config files, organized by task
4 changes: 2 additions & 2 deletions src/prime_rl/configs/env_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from prime_rl.configs.orchestrator import EnvConfig
from prime_rl.configs.shared import LogConfig
from prime_rl.utils.pydantic_config import BaseSettings
from prime_rl.utils.config import BaseConfig


class EnvServerConfig(BaseSettings):
class EnvServerConfig(BaseConfig):
"""Configures an environment server."""

env: EnvConfig = EnvConfig()
Expand Down
13 changes: 10 additions & 3 deletions src/prime_rl/configs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import Field, model_validator

from prime_rl.configs.shared import BaseModelConfig
from prime_rl.utils.pydantic_config import BaseConfig, BaseSettings, get_all_fields
from prime_rl.utils.config import BaseConfig, get_all_fields
from prime_rl.utils.utils import rgetattr, rsetattr

# TODO: Set thinking/ solution budget
Expand Down Expand Up @@ -93,7 +93,7 @@ class ModelConfig(BaseModelConfig):
] = None


class WeightBroadcastConfig(BaseSettings):
class WeightBroadcastConfig(BaseConfig):
"""Configures weight broadcast settings."""

type: Annotated[Literal["nccl", "filesystem"], Field(description="The type of weight broadcast to use.")] = (
Expand All @@ -116,7 +116,7 @@ class WeightBroadcastConfig(BaseSettings):
]


class InferenceConfig(BaseSettings):
class InferenceConfig(BaseConfig):
"""Configures inference."""

# The server configuration
Expand Down Expand Up @@ -237,6 +237,13 @@ class InferenceConfig(BaseSettings):
),
] = False

vllm_extra: Annotated[
dict[str, Any],
Field(
description="Extra arguments to pass to vLLM. These are applied as attributes on the vLLM namespace after config translation.",
),
] = {}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New vllm_extra field leaks into vLLM namespace

Medium Severity

The new vllm_extra field gets included in to_vllm() via the get_all_fields(self) loop, setting namespace.vllm_extra = {} (the raw dict) as a vLLM namespace attribute. This is unintended — the individual entries from the dict are correctly applied in server(), but the raw dict itself persists on the namespace. Since vllm_extra is not a recognized vLLM argument, this spurious attribute could cause issues if vLLM validates namespace attributes or uses vars(args) downstream.

Additional Locations (1)

Fix in Cursor Fix in Web


@model_validator(mode="after")
def round_up_max_lora_rank(self):
"""Round up max_lora_rank to the nearest valid vLLM value.
Expand Down
4 changes: 2 additions & 2 deletions src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
TransportConfig,
WandbWithExtrasConfig,
)
from prime_rl.utils.pydantic_config import BaseConfig, BaseSettings
from prime_rl.utils.config import BaseConfig


class OptimizerConfig(BaseConfig):
Expand Down Expand Up @@ -665,7 +665,7 @@ class TeacherModelConfig(BaseConfig):
] = ModelConfig()


class OrchestratorConfig(BaseSettings):
class OrchestratorConfig(BaseConfig):
"""Configures the orchestrator for RL training."""

# The OAI client configuration
Expand Down
14 changes: 7 additions & 7 deletions src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from prime_rl.configs.trainer import (
NCCLWeightBroadcastConfig as TrainerNCCLWeightBroadcastConfig,
)
from prime_rl.utils.pydantic_config import BaseSettings
from prime_rl.utils.config import BaseConfig
from prime_rl.utils.validation import (
validate_shared_ckpt_config,
validate_shared_max_async_level,
Expand All @@ -49,7 +49,7 @@
)


class SharedLogConfig(BaseSettings):
class SharedLogConfig(BaseConfig):
"""Configures shared logging."""

level: Annotated[str | None, Field(description="The log level to use.")] = "info"
Expand All @@ -62,7 +62,7 @@ class SharedLogConfig(BaseSettings):
] = False


class SharedWandbConfig(BaseSettings):
class SharedWandbConfig(BaseConfig):
"""Configures shared W&B configs."""

project: Annotated[str | None, Field(description="The W&B project to use.")] = "prime-rl"
Expand All @@ -72,7 +72,7 @@ class SharedWandbConfig(BaseSettings):
offline: Annotated[bool | None, Field(description="Whether to run W&B in offline mode.")] = False


class SharedCheckpointConfig(BaseSettings):
class SharedCheckpointConfig(BaseConfig):
"""Configures shared checkpoint configs."""

interval: Annotated[int | None, Field(description="The interval at which to save checkpoints.")] = None
Expand All @@ -98,7 +98,7 @@ class SharedCheckpointConfig(BaseSettings):
] = None


class SharedModelConfig(BaseSettings):
class SharedModelConfig(BaseConfig):
"""Configures shared model settings."""

name: Annotated[
Expand All @@ -107,7 +107,7 @@ class SharedModelConfig(BaseSettings):
] = "Qwen/Qwen3-0.6B"


class SharedWeightBroadcastConfig(BaseSettings):
class SharedWeightBroadcastConfig(BaseConfig):
"""Configures shared weight broadcast settings."""

type: Annotated[Literal["nccl", "filesystem"], Field(description="The type of weight broadcast to use.")] = (
Expand Down Expand Up @@ -174,7 +174,7 @@ def teacher_inference_not_supported(self):
]


class RLConfig(BaseSettings):
class RLConfig(BaseConfig):
"""Configures an RL training run."""

trainer: TrainerConfig
Expand Down
4 changes: 2 additions & 2 deletions src/prime_rl/configs/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
SchedulerConfig,
TokenizerConfig,
)
from prime_rl.utils.pydantic_config import BaseConfig, BaseSettings
from prime_rl.utils.config import BaseConfig


class BaseDataConfig(BaseModel):
Expand Down Expand Up @@ -151,7 +151,7 @@ class MultiNodeDeploymentConfig(BaseDeploymentConfig):
]


class SFTConfig(BaseSettings):
class SFTConfig(BaseConfig):
"""Configures the SFT trainer"""

# The model configuration
Expand Down
2 changes: 1 addition & 1 deletion src/prime_rl/configs/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel, Field, model_validator

from prime_rl.utils.pydantic_config import BaseConfig
from prime_rl.utils.config import BaseConfig


class SlurmConfig(BaseConfig):
Expand Down
4 changes: 2 additions & 2 deletions src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
TransportConfig,
WandbConfig,
)
from prime_rl.utils.pydantic_config import BaseConfig, BaseSettings
from prime_rl.utils.config import BaseConfig

# -- Shared trainer configs (used by both SFT and RL trainers) --

Expand Down Expand Up @@ -646,7 +646,7 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig):
]


class TrainerConfig(BaseSettings):
class TrainerConfig(BaseConfig):
"""Configures the RL trainer"""

# The model configuration
Expand Down
4 changes: 2 additions & 2 deletions src/prime_rl/entrypoints/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import tomli_w

from prime_rl.configs.rl import RLConfig
from prime_rl.utils.config import cli
from prime_rl.utils.logger import setup_logger
from prime_rl.utils.pathing import validate_output_dir
from prime_rl.utils.process import cleanup_processes, cleanup_threads, monitor_process
from prime_rl.utils.pydantic_config import parse_argv
from prime_rl.utils.utils import (
get_free_port,
get_log_dir,
Expand Down Expand Up @@ -448,7 +448,7 @@ def rl(config: RLConfig):


def main():
rl(parse_argv(RLConfig))
rl(cli(RLConfig))


if __name__ == "__main__":
Expand Down
Loading