40 changes: 29 additions & 11 deletions src/prime_rl/inference/vllm/serving_chat_with_tokens.py
@@ -51,6 +51,33 @@ class ChatCompletionRequestWithTokens(ChatCompletionRequest):
class OpenAIServingChatWithTokens(OpenAIServingChat):
"""OpenAI-compatible generate API that allows token-in and routed experts capture."""

def _validate_prompt_has_generation_room(self, engine_prompt) -> int:
prompt_len = self._extract_prompt_len(engine_prompt)
max_model_len = self.model_config.max_model_len
if prompt_len >= max_model_len:
raise VLLMValidationError(
f"This model's maximum context length is "
f"{max_model_len} tokens. However, your request has "
f"{prompt_len} input tokens. Please reduce the length of "
"the input messages.",
parameter="input_tokens",
value=prompt_len,
)
return prompt_len

async def render_chat_request(self, request: ChatCompletionRequest):
result = await super().render_chat_request(request)
if isinstance(result, ErrorResponse) or isinstance(request, ChatCompletionRequestWithTokens):
return result

_, engine_prompts = result
try:
for engine_prompt in engine_prompts:
self._validate_prompt_has_generation_room(engine_prompt)
except ValueError as e:
return self.create_error_response(e)
return result

async def chat_completion_full_generator(
self,
request: ChatCompletionRequest,
@@ -159,21 +186,12 @@ async def create_chat_completion_with_tokens(
# have unique request ids.
sub_request_id = request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"

-            prompt_len = self._extract_prompt_len(engine_prompt)
-            if prompt_len >= max_model_len:
-                raise VLLMValidationError(
-                    f"This model's maximum context length is "
-                    f"{max_model_len} tokens. However, your request has "
-                    f"{prompt_len} input tokens. Please reduce the length of "
-                    "the input messages.",
-                    parameter="input_tokens",
-                    value=prompt_len,
-                )
+            prompt_len = self._validate_prompt_has_generation_room(engine_prompt)

max_tokens = get_max_tokens(
max_model_len,
request.max_completion_tokens if request.max_completion_tokens is not None else request.max_tokens,
-                self._extract_prompt_len(engine_prompt),
+                prompt_len,
self.default_sampling_params,
self.override_max_tokens,
)
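For context on why the guard now runs in `render_chat_request`: once a prompt already fills the context window, the completion budget derived from `max_model_len - prompt_len` drops to zero or below, so it is better to reject the request up front than to ask the engine for a zero-token completion. The snippet below is a standalone sketch of that arithmetic, not the vLLM implementation; the helper name and the clamping behaviour are assumptions made for the example.

# Standalone sketch (not vLLM code): why prompt-length validation has to happen
# before the completion budget is computed.
def remaining_completion_budget(max_model_len: int, prompt_len: int, requested: int | None = None) -> int:
    """Tokens left for generation once the prompt is accounted for."""
    room = max_model_len - prompt_len
    return room if requested is None else min(requested, room)

# A 3-token prompt in a 4-token context leaves exactly one token of room,
# matching the "allows one token margin" unit test below.
assert remaining_completion_budget(max_model_len=4, prompt_len=3) == 1

# A 4-token prompt leaves no room at all, which is the case the new
# render_chat_request override turns into an error response.
assert remaining_completion_budget(max_model_len=4, prompt_len=4) == 0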
71 changes: 71 additions & 0 deletions tests/unit/inference/test_serving_chat_with_tokens.py
@@ -1 +1,72 @@
import asyncio
from types import MethodType, SimpleNamespace

import pytest
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.exceptions import VLLMValidationError

from prime_rl.inference.vllm.serving_chat_with_tokens import (
ChatCompletionRequestWithTokens,
OpenAIServingChatWithTokens,
)


def _handler(max_model_len: int) -> OpenAIServingChatWithTokens:
handler = object.__new__(OpenAIServingChatWithTokens)
handler.model_config = SimpleNamespace(max_model_len=max_model_len)

def _extract_prompt_len(self, prompt):
return len(prompt["prompt_token_ids"])

handler._extract_prompt_len = MethodType(_extract_prompt_len, handler)
return handler


def test_validate_prompt_has_generation_room_allows_one_token_margin():
handler = _handler(max_model_len=4)

assert handler._validate_prompt_has_generation_room({"prompt_token_ids": [1, 2, 3]}) == 3


def test_validate_prompt_has_generation_room_rejects_full_context():
handler = _handler(max_model_len=4)

with pytest.raises(VLLMValidationError) as exc_info:
handler._validate_prompt_has_generation_room({"prompt_token_ids": [1, 2, 3, 4]})

message = str(exc_info.value)
assert "maximum context length is 4 tokens" in message
assert "4 input tokens" in message


def test_render_chat_request_returns_context_error_before_max_tokens_underflows(monkeypatch):
async def fake_render_chat_request(self, request):
return "conversation", [{"prompt_token_ids": [1, 2, 3, 4]}]

monkeypatch.setattr(OpenAIServingChat, "render_chat_request", fake_render_chat_request)

handler = _handler(max_model_len=4)

def create_error_response(self, error):
return {"error": str(error)}

handler.create_error_response = MethodType(create_error_response, handler)

result = asyncio.run(handler.render_chat_request(object()))

assert "maximum context length is 4 tokens" in result["error"]
assert "4 input tokens" in result["error"]


def test_render_chat_request_defers_token_endpoint_validation_until_tokens_are_installed(monkeypatch):
async def fake_render_chat_request(self, request):
return "conversation", [{"prompt_token_ids": [1, 2, 3, 4]}]

monkeypatch.setattr(OpenAIServingChat, "render_chat_request", fake_render_chat_request)

handler = _handler(max_model_len=4)
request = object.__new__(ChatCompletionRequestWithTokens)

result = asyncio.run(handler.render_chat_request(request))

assert result == ("conversation", [{"prompt_token_ids": [1, 2, 3, 4]}])
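
A side note on the test scaffolding above: `object.__new__` bypasses `OpenAIServingChatWithTokens.__init__`, so the `_handler` helper only stubs the attributes and methods the code under test actually touches, and `MethodType` binds a plain function as an instance method. Below is a minimal standalone illustration of that pattern; the `Service` class and its attributes are invented for the example.

# Standalone illustration of the object.__new__ / MethodType pattern used in
# _handler above. The Service class here is hypothetical.
from types import MethodType, SimpleNamespace


class Service:
    def __init__(self, config_path: str):
        raise RuntimeError("heavy initialization we want to avoid in unit tests")

    def limit(self) -> int:
        return self.config.max_len


# object.__new__ allocates the instance without running __init__ ...
svc = object.__new__(Service)
# ... so we attach only the attribute that limit() reads.
svc.config = SimpleNamespace(max_len=8)
assert svc.limit() == 8

# MethodType binds a free function as a method of this one instance, the same
# way _extract_prompt_len is stubbed onto the handler in the tests above.
svc.describe = MethodType(lambda self: f"limit={self.limit()}", svc)
assert svc.describe() == "limit=8"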