Skip to content
57 changes: 53 additions & 4 deletions responses_api_models/vllm_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class VLLMModelConfig(BaseResponsesAPIModelConfig):
model: str
return_token_id_information: bool

uses_reasoning_parser: bool

def model_post_init(self, context):
if isinstance(self.base_url, str):
self.base_url = [self.base_url]
Expand All @@ -80,7 +82,9 @@ def model_post_init(self, context):

self._session_id_to_client: Dict[str, NeMoGymAsyncOpenAI] = dict()

self._converter = VLLMConverter(return_token_id_information=self.config.return_token_id_information)
self._converter = VLLMConverter(
return_token_id_information=self.config.return_token_id_information,
)

return super().model_post_init(context)

Expand Down Expand Up @@ -142,6 +146,9 @@ async def chat_completions(
client = self._session_id_to_client[session_id]

create_params = body_dict
# Always disable skip_special_tokens to preserve <think> </think> tags for reasoning parsing
create_params |= dict(skip_special_tokens=False)

if self.config.return_token_id_information:
create_params |= dict(
logprobs=True,
Expand All @@ -154,11 +161,53 @@ async def chat_completions(
# prompt_logprobs=0,
)

if self.config.uses_reasoning_parser:
for message_dict in body_dict["messages"]:
if message_dict.get("role") != "assistant" or "content" not in message_dict:
continue

content = message_dict["content"]
if isinstance(content, str):
reasoning_matches, remaining_content = self._converter._extract_reasoning_from_content(content)
message_dict["content"] = remaining_content
if reasoning_matches:
message_dict["reasoning_content"] = reasoning_matches[0]
elif isinstance(content, list):
reasoning_content = None
for content_item_dict in content:
reasoning_matches, remaining_content = self._converter._extract_reasoning_from_content(
content_item_dict["text"]
)
assert reasoning_content is None or not reasoning_matches, (
f"Found multiple reasoning matches in a single assistant message content item list!\nMessage: {message_dict}"
)

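# Keep iterating after the first match so the assert above can reject a second reasoning block in a later content item.
# NOTE(review): the local `reasoning_content` is never reassigned in the lines shown here, so that assert cannot fire as written - confirm and set it when a match is found.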
content_item_dict["text"] = remaining_content
if reasoning_matches:
message_dict["reasoning_content"] = reasoning_matches[0]
elif not content:
# Empty or None content needs no processing.
pass
else:
raise NotImplementedError

chat_completion_dict = await client.create_chat_completion(**create_params)

choice_dict = chat_completion_dict["choices"][0]
assert not choice_dict["message"].get("reasoning_content"), (
"Please do not use a reasoning parser in vLLM! There is one source of truth for handling data (including reasoning), which is NeMo Gym!"
)
if self.config.uses_reasoning_parser:
reasoning_content = choice_dict["message"].get("reasoning_content")
if reasoning_content:
choice_dict["message"].pop("reasoning_content")

# We wrap this here in think tags for Gym's sake and to return a valid OpenAI Chat Completions response.
choice_dict["message"]["content"] = self._converter._wrap_reasoning_in_think_tags(
[reasoning_content]
) + (choice_dict["message"]["content"] or "")
else:
assert not choice_dict["message"].get("reasoning_content"), (
"Please do not use a reasoning parser in vLLM! There is one source of truth for handling data (including reasoning), which is NeMo Gym!"
)

if self.config.return_token_id_information:
log_probs = choice_dict["logprobs"]["content"]
Expand Down
1 change: 1 addition & 0 deletions responses_api_models/vllm_model/configs/vllm_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ policy_model:
api_key: ${policy_api_key}
model: ${policy_model_name}
return_token_id_information: false
uses_reasoning_parser: true
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ policy_model:
api_key: ${policy_api_key}
model: ${policy_model_name}
return_token_id_information: true
uses_reasoning_parser: true
Loading