diff --git a/src/prime_rl/configs/inference.py b/src/prime_rl/configs/inference.py
index d7efbdc92..01d68bb79 100644
--- a/src/prime_rl/configs/inference.py
+++ b/src/prime_rl/configs/inference.py
@@ -1,97 +1,94 @@
-from argparse import Namespace
 from pathlib import Path
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Literal, TypeAlias
 
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 
-from prime_rl.configs.shared import BaseModelConfig, SlurmConfig
-from prime_rl.utils.pydantic_config import BaseConfig, BaseSettings
-from prime_rl.utils.utils import rgetattr, rsetattr
+from prime_rl.configs.shared import SlurmConfig
+from prime_rl.utils.pydantic_config import BaseSettings
 
-# TODO: Set thinking/ solution budget
-
-
-class ServerConfig(BaseConfig):
-    """Configures the inference server."""
+# Valid vLLM max_lora_rank values (from vllm/config/lora.py)
+# TODO: on newer vLLM, can import via `get_args(vllm.config.lora.MaxLoRARanks)`
+VALID_VLLM_LORA_RANKS = (8, 16, 32, 64, 128, 256, 320, 512)
 
-    host: Annotated[str | None, Field(description="The host to bind to.")] = None
-    port: Annotated[int, Field(description="The port to bind to.")] = 8000
 
+class VLLMConfig(BaseModel):
+    """Configures vLLM. Arguments must match exactly with vLLM's CLI arguments."""
 
-class ParallelConfig(BaseConfig):
-    """Configures multi-node and multi-GPU setups through different types of parallelism (TP, DP, PP)."""
+    model_config = ConfigDict(extra="allow")
 
-    tp: Annotated[
+    model_name: Annotated[str, Field(description="The name of the model to use.")] = "Qwen/Qwen3-0.6B"
+    tool_call_parser: Annotated[str | None, Field(description="The tool call parser to use.")] = None
+    reasoning_parser: Annotated[str | None, Field(description="The reasoning parser to use.")] = None
+    data_parallel_size: Annotated[int, Field(description="The data parallel size to use.")] = 1
+    data_parallel_size_local: Annotated[
+        int | None, Field(description="The data parallel size to use on this node.")
+    ] = None
+    tensor_parallel_size: Annotated[int, Field(description="The tensor parallel size to use.")] = 1
+    enable_lora: Annotated[bool, Field(description="Whether to enable LoRA.")] = False
+    max_lora_rank: Annotated[int | None, Field(description="The maximum LoRA rank to use.")] = None
+    max_loras: Annotated[
         int,
         Field(
-            description="The tensor parallel size. It is passed to vLLM as `--tensor-parallel-size`",
+            description="The maximum number of LoRAs to use.",
         ),
-    ] = 1
-
-    dp: Annotated[
+    ] = 8
+    max_cpu_loras: Annotated[
         int,
         Field(
-            ge=1,
-            description="The data parallel size. It is passed to vLLM as `--data-parallel-size`",
+            description="The maximum number of LoRAs to use on CPU.",
        ),
-    ] = 1
-
-    def __str__(self) -> str:
-        return f"tp={self.tp} dp={self.dp}"
-
-
-class ModelConfig(BaseModelConfig):
-    """Configures the inference model. Most arguments are passed directly to the vLLM LLM class (https://docs.vllm.ai/en/latest/api/vllm.LLM.html)."""
+    ] = 100
+    api_server_count: Annotated[int, Field(description="The number of API servers to use.")] = 1
 
-    dtype: Annotated[
-        Literal["auto", "float16", "bfloat16", "float32"],
-        Field(
-            description="Data type for model weights and activations. If 'auto' will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. Passed to vLLM as `--dtype`",
-        ),
-    ] = "auto"
+    @model_validator(mode="after")
+    def validate_valid_vllm_arg(self):
+        # TODO
+        return self
 
-    max_model_len: Annotated[
-        int | None,
-        Field(
-            description="Maximum model context length. If None, will use the maximum context length from model config. Passed to vLLM as `--max-model-len`",
-        ),
-    ] = None
+    @model_validator(mode="after")
+    def auto_setup_tool_call_parser(self):
+        # TODO
+        return self
 
-    enforce_eager: Annotated[
-        bool,
-        Field(
-            description="Whether to enforce eager mode. If False, will use PyTorch eager and cuda graphs in hybrid for maximal performance. Passed to vLLM as `--enforce-eager`",
-        ),
-    ] = False
+    @model_validator(mode="after")
+    def auto_setup_reasoning_parser(self):
+        # TODO
+        return self
 
-    trust_remote_code: Annotated[
-        bool,
-        Field(
-            description="Whether to trust remote code. Passed to vLLM engine init",
-        ),
-    ] = False
+    @model_validator(mode="after")
+    def auto_setup_max_lora_rank(self):
+        """
+        Auto-setup max_lora_rank by rounding up to the nearest valid vLLM value.
 
-    tool_call_parser: Annotated[
-        str | None,
-        Field(
-            description="The tool call parser to use. Passed to vLLM as `--tool-call-parser`. "
-            'Set to "auto" to infer from the model name.',
-        ),
-    ] = None
+        vLLM only accepts specific values for max_lora_rank (see VALID_VLLM_LORA_RANKS).
+        This validator ensures that any configured rank is rounded up to the minimum valid value
+        that can serve adapters of the requested rank.
+        """
+        if self.max_lora_rank is not None:
+            original_rank = self.max_lora_rank
+            for valid_rank in VALID_VLLM_LORA_RANKS:
+                if valid_rank >= self.max_lora_rank:
+                    self.max_lora_rank = valid_rank
+                    break
+            else:
+                raise ValueError(f"max_lora_rank={original_rank} exceeds vLLM maximum of {VALID_VLLM_LORA_RANKS[-1]}")
+        return self
 
-    reasoning_parser: Annotated[
-        str | None,
-        Field(
-            description="Parser for extracting reasoning content from model outputs. Passed to vLLM as `--reasoning-parser`. Setting this enables reasoning mode.",
-        ),
-    ] = None
+    @model_validator(mode="after")
+    def auto_setup_api_server_count(self):
+        """
+        Ensure that there are at least as many API servers as the data parallel
+        size, unless LoRA is enabled, in which case only one API server is
+        supported (a vLLM limitation).
+        """
+        if "api_server_count" not in self.model_fields_set:
+            min_api_server_count = self.data_parallel_size_local or self.data_parallel_size
+            if self.api_server_count < min_api_server_count:
+                self.api_server_count = min_api_server_count
 
-    rope_scaling: Annotated[
-        dict[str, Any] | str | None,
-        Field(
-            description='RoPE scaling configuration as a dict. For YaRN, use: {rope_type="yarn", factor=4.0, original_max_position_embeddings=32768} or. Passed to vLLM as `--rope-scaling`.',
-        ),
-    ] = None
+        if self.enable_lora:
+            self.api_server_count = 1  # LoRA requires only one API server
+        return self
 
 
 class WeightBroadcastConfig(BaseSettings):
@@ -102,21 +99,6 @@ class WeightBroadcastConfig(BaseSettings):
     )
 
 
-# Valid vLLM max_lora_rank values (from vllm/config/lora.py)
-# TODO: on newer vLLM, can import via `get_args(vllm.config.lora.MaxLoRARanks)`
-VALID_VLLM_LORA_RANKS = (8, 16, 32, 64, 128, 256, 320, 512)
-
-# vLLM all2all backend options for expert-parallel deployments.
-All2AllBackend = Literal[
-    "allgather_reducescatter",
-    "deepep_high_throughput",
-    "deepep_low_latency",
-    "flashinfer_all2allv",
-    "naive",
-    "pplx",
-]
-
-
 class BaseInferenceDeploymentConfig(BaseModel):
     """Base deployment config for inference."""
 
@@ -147,126 +129,12 @@ class MultiNodeInferenceDeploymentConfig(BaseInferenceDeploymentConfig):
 class InferenceConfig(BaseSettings):
     """Configures inference."""
 
-    # The server configuration
-    server: ServerConfig = ServerConfig()
-
-    # The model configuration
-    model: ModelConfig = Field(default_factory=ModelConfig)
-
-    # The parallel configuration
-    parallel: ParallelConfig = ParallelConfig()
-
-    enable_lora: Annotated[
-        bool,
-        Field(
-            description="Whether to enable LORA. Passed to vLLM as `--enable-lora`",
-        ),
-    ] = False
-
-    max_loras: Annotated[
-        int,
-        Field(
-            description="The maximum number of LoRAs to use. Passed to vLLM as `--max-loras`",
-        ),
-    ] = 8
-
-    # TODO: The default value is very high because our areal impl for lora isn't ideal
-    # We add a lora with the same name instead of changing weights inplace
-    # Because we dont cancel requests that are past max_async, these requests could be using a LoRA that gets unloaded which will crash the inference server
-    max_cpu_loras: Annotated[
-        int,
-        Field(
-            description="The maximum number of LoRAs to use on CPU. Passed to vLLM as `--max-cpu-loras`",
-        ),
-    ] = 100
-
-    max_lora_rank: Annotated[
-        int | None,
-        Field(
-            description="The maximum LoRA rank to use. Passed to vLLM as `--max-lora-rank`",
-        ),
-    ] = None
-
-    enable_prefix_caching: Annotated[
-        bool | None,
-        Field(
-            description="Whether to enable prefix caching. Passed to vLLM as `--enable-prefix-caching`",
-        ),
-    ] = None
-
-    gpu_memory_utilization: Annotated[
-        float,
-        Field(
-            description="The GPU memory utilization to use. Passed to vLLM as `--gpu-memory-utilization`",
-        ),
-    ] = 0.9
-
-    api_server_count: Annotated[
-        int,
-        Field(
-            ge=1,
-            description="The number of API servers to use. Passed to vLLM as `--api-server-count`",
-        ),
-    ] = 1
-
-    data_parallel_size_local: Annotated[
-        int | None,
-        Field(
-            ge=1,
-            description="Number of data parallel replicas to run on this node. Passed to vLLM as `--data-parallel-size-local`.",
-        ),
-    ] = None
-
-    data_parallel_rpc_port: Annotated[
-        int,
-        Field(
-            ge=1,
-            le=65535,
-            description="RPC port for data parallel communication. Passed to vLLM as `--data-parallel-rpc-port`.",
-        ),
-    ] = 13345
-
-    seed: Annotated[
-        int,
-        Field(
-            description="Seed the inference components. Passed to vLLM as `--seed`",
-        ),
-    ] = 0
-
-    enable_expert_parallel: Annotated[
-        bool,
-        Field(
-            description="Enable expert parallelism for MoE models. Passed to vLLM as `--enable-expert-parallel`.",
-        ),
-    ] = False
-
-    all2all_backend: Annotated[
-        All2AllBackend,
-        Field(
-            description="All-to-all backend for expert parallel communication. Passed to vLLM as `--all2all-backend`.",
-        ),
-    ] = "allgather_reducescatter"
-
-    enable_eplb: Annotated[
-        bool,
-        Field(
-            description="Enable expert parallel load balancer (EPLB). Passed to vLLM as `--enable-eplb`.",
-        ),
-    ] = False
+    vllm: VLLMConfig = VLLMConfig()
 
     weight_broadcast: Annotated[WeightBroadcastConfig, Field(description="The weight broadcast config.")] = (
         WeightBroadcastConfig()
     )
 
-    enable_return_routed_experts: Annotated[
-        bool,
-        Field(
-            description="Whether to enable return routed experts. Passed to vLLM as `--enable-return-routed-experts`",
-        ),
-    ] = False
-
-    # Launcher-only fields
-
     deployment: Annotated[
         InferenceDeploymentConfig,
         Field(
@@ -299,87 +167,3 @@ def auto_setup_slurm_template(self):
         templates_dir = Path(prime_rl.__file__).parent / "templates"
         self.slurm.template_path = templates_dir / "inference.sbatch.j2"
         return self
-
-    @model_validator(mode="after")
-    def auto_setup_max_lora_rank(self):
-        """Auto-setup max_lora_rank by rounding up to the nearest valid vLLM value.
-
-        vLLM only accepts specific values for max_lora_rank: (1, 8, 16, 32, 64, 128, 256, 320, 512).
-        This validator ensures that any configured rank is rounded up to the minimum valid value
-        that can serve adapters of the requested rank.
-        """
-        if self.max_lora_rank is not None:
-            original_rank = self.max_lora_rank
-            for valid_rank in VALID_VLLM_LORA_RANKS:
-                if valid_rank >= self.max_lora_rank:
-                    self.max_lora_rank = valid_rank
-                    break
-            else:
-                raise ValueError(f"max_lora_rank={original_rank} exceeds vLLM maximum of {VALID_VLLM_LORA_RANKS[-1]}")
-        return self
-
-    @model_validator(mode="after")
-    def auto_setup_api_server_count(self):
-        """
-        Ensures that we have at least as many API servers as data parallel
-        size. Unless LoRA is enabled, in which case only one API server is
-        supported (vLLM limitation).
-        """
-        if "api_server_count" not in self.model_fields_set:
-            min_api_server_count = self.data_parallel_size_local or self.parallel.dp
-            if self.api_server_count < min_api_server_count:
-                self.api_server_count = min_api_server_count
-
-        if self.enable_lora:
-            self.api_server_count = 1  # LoRA requires only one API server
-        return self
-
-    def to_vllm(self) -> Namespace:
-        """Convert InferenceConfig to vLLM-compatible Namespace."""
-        namespace = Namespace()
-        to_vllm = {
-            "server.host": "host",
-            "server.port": "port",
-            "model.name": "model",
-            "model.dtype": "dtype",
-            "model.max_model_len": "max_model_len",
-            "model.enforce_eager": "enforce_eager",
-            "model.trust_remote_code": "trust_remote_code",
-            "model.tool_call_parser": "tool_call_parser",
-            "model.reasoning_parser": "reasoning_parser",
-            "model.rope_scaling": "rope_scaling",
-            "parallel.tp": "tensor_parallel_size",
-            "parallel.dp": "data_parallel_size",
-            "data_parallel_size_local": "data_parallel_size_local",
-            "data_parallel_rpc_port": "data_parallel_rpc_port",
-            "enable_lora": "enable_lora",
-            "enable_prefix_caching": "enable_prefix_caching",
-            "max_loras": "max_loras",
-            "max_cpu_loras": "max_cpu_loras",
-            "max_lora_rank": "max_lora_rank",
-            "gpu_memory_utilization": "gpu_memory_utilization",
-            "api_server_count": "api_server_count",
-            "enable_return_routed_experts": "enable_return_routed_experts",
-            "enable_expert_parallel": "enable_expert_parallel",
-            "all2all_backend": "all2all_backend",
-            "enable_eplb": "enable_eplb",
-            "seed": "seed",
-        }
-
-        for config_key, vllm_key in to_vllm.items():
-            value = rgetattr(self, config_key.replace("-", "_"))
-            rsetattr(namespace, vllm_key, value)
-
-        # Set `logprobs_mode` to `processed_logprobs` by default
-        rsetattr(namespace, "logprobs_mode", "processed_logprobs")
-
-        # Remove reasoning_parser if not set (vLLM doesn't accept None)
-        if namespace.reasoning_parser is None:
-            delattr(namespace, "reasoning_parser")
-
-        # Remove rope_scaling if not set (vLLM doesn't accept None)
-        if hasattr(namespace, "rope_scaling"):
-            if namespace.rope_scaling is None:
-                delattr(namespace, "rope_scaling")
-
-        return namespace
diff --git a/src/prime_rl/entrypoints/inference.py b/src/prime_rl/entrypoints/inference.py
index ff205f868..8f9005147 100644
--- a/src/prime_rl/entrypoints/inference.py
+++ b/src/prime_rl/entrypoints/inference.py
@@ -91,9 +91,7 @@ def inference_local(config: InferenceConfig):
         logger.success("Dry run complete. To start inference locally, remove --dry-run from your command.")
         return
 
-    host = config.server.host or "0.0.0.0"
-    port = config.server.port
-    logger.info(f"Starting inference on http://{host}:{port}/v1\n")
+    logger.info("Starting inference\n")
 
     setup_vllm_env(config)
 
diff --git a/src/prime_rl/inference/server.py b/src/prime_rl/inference/server.py
index a7ebae342..0b96dc206 100644
--- a/src/prime_rl/inference/server.py
+++ b/src/prime_rl/inference/server.py
@@ -10,7 +10,7 @@ def setup_vllm_env(config: InferenceConfig):
     # spawn is more robust in vLLM nightlies and Qwen3-VL (fork can deadlock with multithreaded processes)
     os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
 
-    if config.enable_lora:
+    if config.vllm.enable_lora:
         os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
 
 
diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py
index 202de2761..f3b7bcdb1 100644
--- a/src/prime_rl/inference/vllm/server.py
+++ b/src/prime_rl/inference/vllm/server.py
@@ -319,7 +319,7 @@ def server(config: InferenceConfig, vllm_args: list[str]):
 
     parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)
-    args = parser.parse_args(args=vllm_args, namespace=config.to_vllm())
+    args = parser.parse_args(args=vllm_args, namespace=Namespace(**config.vllm.model_dump(exclude_none=True)))
     assert args is not None
     validate_parsed_serve_args(args)
 
@@ -328,6 +328,9 @@ def server(config: InferenceConfig, vllm_args: list[str]):
     if args.tool_call_parser is not None:
         logger.info(f"Using tool_call_parser='{args.tool_call_parser}' for model '{args.model}'")
 
+    # Set `logprobs_mode` to `processed_logprobs` by default
+    args.logprobs_mode = "processed_logprobs"
+
     # Set the worker extension class based on the broadcast backend
     args.worker_extension_cls = WORKER_EXTENSION_CLS[config.weight_broadcast.type]
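
Reviewer note (not part of the patch): the new `VLLMConfig` relies on pydantic v2's `extra="allow"`, so any vLLM CLI flag that is not declared as a field still survives `model_dump(exclude_none=True)` and lands on the `Namespace` handed to vLLM's argument parser. That is why the field names must match vLLM's argument names exactly, and why `None` defaults such as `max_lora_rank` are dropped instead of being passed. A minimal, self-contained sketch of that behaviour, with illustrative field names only (not the full `VLLMConfig`):

    # Sketch only (assumes pydantic v2); field names are illustrative.
    from argparse import Namespace

    from pydantic import BaseModel, ConfigDict


    class VLLMConfigSketch(BaseModel):
        model_config = ConfigDict(extra="allow")

        tensor_parallel_size: int = 1
        max_lora_rank: int | None = None


    # `gpu_memory_utilization` is not declared above, but extra="allow" keeps it.
    cfg = VLLMConfigSketch(tensor_parallel_size=2, gpu_memory_utilization=0.85)
    ns = Namespace(**cfg.model_dump(exclude_none=True))  # max_lora_rank=None is dropped here
    print(ns.tensor_parallel_size, ns.gpu_memory_utilization)  # -> 2 0.85
    print(hasattr(ns, "max_lora_rank"))  # -> False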