Skip to content

Commit 08e384a

Browse files
committed
Merge NVIDIA-NeMo/Gym PR #129 into tkonuk/compat-openai-199
2 parents d873f71 + 314600e commit 08e384a

5 files changed

Lines changed: 326 additions & 16 deletions

File tree

.pre-commit-config.yaml

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,8 @@ repos:
5050
exclude: '^\.github/'
5151
types: [file]
5252
- id: update-readme-table
53-
name: "Update resource server list in README"
54-
entry: |
55-
bash -c '
56-
if git diff --cached --name-only --diff-filter=ACMR | grep -q "^resources_servers/.*/configs/.*\.yaml$"; then
57-
echo "[pre-commit] Saw staged config changes; updating resource servers in README." >&2
58-
python scripts/update_resource_servers.py
59-
else
60-
echo [pre-commit] "No staged config changes; skipping README update." >&2
61-
fi
62-
'
63-
language: system
64-
files: ^README\.md$
53+
name: Update resource server list in README
54+
language: python
55+
entry: python scripts/update_resource_servers.py
56+
additional_dependencies: [pyyaml]
57+
files: ^README\.md$|^resources_servers/.*/configs/.*\.yaml$

responses_api_models/vllm_model/app.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ class VLLMModelConfig(BaseResponsesAPIModelConfig):
6060
model: str
6161
return_token_id_information: bool
6262

63+
uses_reasoning_parser: bool
64+
6365
def model_post_init(self, context):
6466
if isinstance(self.base_url, str):
6567
self.base_url = [self.base_url]
@@ -80,7 +82,9 @@ def model_post_init(self, context):
8082

8183
self._session_id_to_client: Dict[str, NeMoGymAsyncOpenAI] = dict()
8284

83-
self._converter = VLLMConverter(return_token_id_information=self.config.return_token_id_information)
85+
self._converter = VLLMConverter(
86+
return_token_id_information=self.config.return_token_id_information,
87+
)
8488

8589
return super().model_post_init(context)
8690

@@ -154,11 +158,50 @@ async def chat_completions(
154158
# prompt_logprobs=0,
155159
)
156160

161+
if self.config.uses_reasoning_parser:
162+
for message_dict in body_dict["messages"]:
163+
if message_dict.get("role") != "assistant" or "content" not in message_dict:
164+
continue
165+
166+
content = message_dict["content"]
167+
if isinstance(content, str):
168+
reasoning_matches, remaining_content = self._converter._extract_reasoning_from_content(content)
169+
message_dict["content"] = remaining_content
170+
if reasoning_matches:
171+
message_dict["reasoning_content"] = reasoning_matches[0]
172+
elif isinstance(content, list):
173+
reasoning_content = None
174+
for content_item_dict in content:
175+
reasoning_matches, remaining_content = self._converter._extract_reasoning_from_content(
176+
content_item_dict["text"]
177+
)
178+
assert reasoning_content is None or not reasoning_matches, (
179+
f"Found multiple reasoning matches in a single assistant message content item list!\nMessage: {message_dict}"
180+
)
181+
182+
# Even though we set the reasoning content already here, we still loop through all the content item dicts for the assert above.
183+
content_item_dict["text"] = remaining_content
184+
if reasoning_matches:
185+
message_dict["reasoning_content"] = reasoning_matches[0]
186+
else:
187+
raise NotImplementedError
188+
157189
chat_completion_dict = await client.create_chat_completion(**create_params)
158190
choice_dict = chat_completion_dict["choices"][0]
159-
assert not choice_dict["message"].get("reasoning_content"), (
160-
"Please do not use a reasoning parser in vLLM! There is one source of truth for handling data (including reasoning), which is NeMo Gym!"
161-
)
191+
if self.config.uses_reasoning_parser:
192+
reasoning_content = choice_dict["message"].get("reasoning_content")
193+
if reasoning_content:
194+
choice_dict["message"].pop("reasoning_content")
195+
196+
# We wrap this here in think tags for Gym's sake and to return a valid OpenAI Chat Completions response.
197+
choice_dict["message"]["content"] = (
198+
self._converter._wrap_reasoning_in_think_tags([reasoning_content])
199+
+ choice_dict["message"]["content"]
200+
)
201+
else:
202+
assert not choice_dict["message"].get("reasoning_content"), (
203+
"Please do not use a reasoning parser in vLLM! There is one source of truth for handling data (including reasoning), which is NeMo Gym!"
204+
)
162205

163206
if self.config.return_token_id_information:
164207
log_probs = choice_dict["logprobs"]["content"]

responses_api_models/vllm_model/configs/vllm_model.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ policy_model:
66
api_key: ${policy_api_key}
77
model: ${policy_model_name}
88
return_token_id_information: false
9+
uses_reasoning_parser: true

responses_api_models/vllm_model/configs/vllm_model_for_training.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ policy_model:
66
api_key: ${policy_api_key}
77
model: ${policy_model_name}
88
return_token_id_information: true
9+
uses_reasoning_parser: true

responses_api_models/vllm_model/tests/test_app.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,7 @@ def _setup_server(self, monkeypatch: MonkeyPatch):
670670
entrypoint="",
671671
name="",
672672
return_token_id_information=False,
673+
uses_reasoning_parser=False,
673674
)
674675

675676
get_global_config_dict_mock = MagicMock()
@@ -1477,6 +1478,7 @@ def test_client_session_routing(self, monkeypatch: MonkeyPatch):
14771478
entrypoint="",
14781479
name="",
14791480
return_token_id_information=False,
1481+
uses_reasoning_parser=False,
14801482
)
14811483
server = VLLMModel(config=config, server_client=MagicMock(spec=ServerClient))
14821484
app = server.setup_webserver()
@@ -1586,6 +1588,276 @@ def test_client_session_routing(self, monkeypatch: MonkeyPatch):
15861588
data = response_2_2.json()
15871589
assert data["output"][0]["content"][0]["text"] == "2"
15881590

1591+
def test_responses_reasoning_parser(self, monkeypatch: MonkeyPatch):
1592+
server = self._setup_server(monkeypatch)
1593+
server.config.uses_reasoning_parser = True
1594+
1595+
app = server.setup_webserver()
1596+
client = TestClient(app)
1597+
1598+
mock_chat_completion = NeMoGymChatCompletion(
1599+
id="chtcmpl-123",
1600+
object="chat.completion",
1601+
created=FIXED_TIME,
1602+
model="dummy_model",
1603+
choices=[
1604+
NeMoGymChoice(
1605+
index=0,
1606+
finish_reason="tool_calls",
1607+
message=NeMoGymChatCompletionMessage(
1608+
role="assistant",
1609+
content=" hello hello",
1610+
tool_calls=[
1611+
NeMoGymChatCompletionMessageToolCall(
1612+
id="call_123",
1613+
function=NeMoGymFunction(
1614+
name="get_order_status",
1615+
arguments='{"order_id": "123"}',
1616+
),
1617+
type="function",
1618+
),
1619+
NeMoGymChatCompletionMessageToolCall(
1620+
id="call_234",
1621+
function=NeMoGymFunction(
1622+
name="get_delivery_date",
1623+
arguments='{"order_id": "234"}',
1624+
),
1625+
type="function",
1626+
),
1627+
],
1628+
reasoning_content="Gathering order status and delivery info...",
1629+
),
1630+
)
1631+
],
1632+
)
1633+
1634+
input_messages = [
1635+
NeMoGymEasyInputMessage(
1636+
type="message",
1637+
role="user",
1638+
content=[NeMoGymResponseInputText(text="Check my order status", type="input_text")],
1639+
status="completed",
1640+
),
1641+
NeMoGymResponseReasoningItem(
1642+
id="rs_123",
1643+
status="completed",
1644+
type="reasoning",
1645+
summary=[
1646+
NeMoGymSummary(
1647+
type="summary_text",
1648+
text="First reasoning item",
1649+
)
1650+
],
1651+
),
1652+
NeMoGymEasyInputMessage(
1653+
type="message",
1654+
role="assistant",
1655+
content=[NeMoGymResponseInputText(text="Sure, one sec.", type="input_text")],
1656+
status="completed",
1657+
),
1658+
NeMoGymEasyInputMessage(
1659+
type="message",
1660+
role="user",
1661+
content=[NeMoGymResponseInputText(text="cool", type="input_text")],
1662+
status="completed",
1663+
),
1664+
NeMoGymEasyInputMessage(
1665+
type="message",
1666+
role="assistant",
1667+
content=[NeMoGymResponseInputText(text="I'm still checking", type="input_text")],
1668+
status="completed",
1669+
),
1670+
NeMoGymEasyInputMessage(
1671+
type="message",
1672+
role="user",
1673+
content=[NeMoGymResponseInputText(text="ok", type="input_text")],
1674+
status="completed",
1675+
),
1676+
]
1677+
1678+
input_tools = [
1679+
NeMoGymFunctionToolParam(
1680+
name="get_order_status",
1681+
parameters={
1682+
"type": "object",
1683+
"properties": {
1684+
"order_id": {
1685+
"type": "string",
1686+
"description": "The ID of the order",
1687+
},
1688+
},
1689+
"required": ["order_id"],
1690+
},
1691+
type="function",
1692+
description="Get the current status for a given order",
1693+
strict=True,
1694+
),
1695+
NeMoGymFunctionToolParam(
1696+
name="get_delivery_date",
1697+
parameters={
1698+
"type": "object",
1699+
"properties": {
1700+
"order_id": {
1701+
"type": "string",
1702+
"description": "The ID of the order",
1703+
},
1704+
},
1705+
"required": ["order_id"],
1706+
},
1707+
type="function",
1708+
description="Get the estimated delivery date for a given order",
1709+
strict=True,
1710+
),
1711+
]
1712+
1713+
expected_response = NeMoGymResponse(
1714+
**COMMON_RESPONSE_PARAMS,
1715+
id="resp_123",
1716+
object="response",
1717+
tools=input_tools,
1718+
created_at=FIXED_TIME,
1719+
model="dummy_model",
1720+
output=[
1721+
NeMoGymResponseReasoningItem(
1722+
id="rs_123",
1723+
status="completed",
1724+
type="reasoning",
1725+
summary=[
1726+
NeMoGymSummary(
1727+
type="summary_text",
1728+
text="Gathering order status and delivery info...",
1729+
)
1730+
],
1731+
),
1732+
NeMoGymResponseOutputMessage(
1733+
id="msg_123",
1734+
status="completed",
1735+
type="message",
1736+
content=[
1737+
NeMoGymResponseOutputText(
1738+
type="output_text",
1739+
text=" hello hello",
1740+
annotations=[],
1741+
logprobs=None,
1742+
)
1743+
],
1744+
),
1745+
NeMoGymResponseFunctionToolCall(
1746+
type="function_call",
1747+
name="get_order_status",
1748+
arguments='{"order_id": "123"}',
1749+
call_id="call_123",
1750+
status="completed",
1751+
id="call_123",
1752+
),
1753+
NeMoGymResponseFunctionToolCall(
1754+
type="function_call",
1755+
name="get_delivery_date",
1756+
arguments='{"order_id": "234"}',
1757+
call_id="call_234",
1758+
status="completed",
1759+
id="call_234",
1760+
),
1761+
],
1762+
)
1763+
1764+
mock_method = AsyncMock(return_value=mock_chat_completion.model_dump())
1765+
monkeypatch.setattr(
1766+
server._clients[0].__class__,
1767+
"create_chat_completion",
1768+
mock_method,
1769+
)
1770+
1771+
monkeypatch.setattr("responses_api_models.vllm_model.app.time", lambda: FIXED_TIME)
1772+
monkeypatch.setattr("responses_api_models.vllm_model.app.uuid4", lambda: FakeUUID())
1773+
1774+
request_body = NeMoGymResponseCreateParamsNonStreaming(
1775+
input=input_messages,
1776+
tools=input_tools,
1777+
)
1778+
1779+
response = client.post(
1780+
"/v1/responses",
1781+
json=request_body.model_dump(exclude_unset=True, mode="json"),
1782+
)
1783+
assert response.status_code == 200
1784+
1785+
data = response.json()
1786+
1787+
expected_dict = expected_response.model_dump()
1788+
assert data == expected_dict
1789+
1790+
expected_messages = [
1791+
{"content": [{"text": "Check my order status", "type": "text"}], "role": "user"},
1792+
{
1793+
"role": "assistant",
1794+
"content": "Sure, one sec.",
1795+
"tool_calls": [],
1796+
"reasoning_content": "First reasoning item",
1797+
},
1798+
{"content": [{"text": "cool", "type": "text"}], "role": "user"},
1799+
{
1800+
"role": "assistant",
1801+
"content": "I'm still checking",
1802+
"tool_calls": [],
1803+
},
1804+
{"content": [{"text": "ok", "type": "text"}], "role": "user"},
1805+
]
1806+
actual_messages = mock_method.call_args.kwargs["messages"]
1807+
assert expected_messages == actual_messages
1808+
1809+
request_body = NeMoGymResponseCreateParamsNonStreaming(
1810+
input=input_messages + data["output"],
1811+
tools=input_tools,
1812+
)
1813+
1814+
response = client.post(
1815+
"/v1/responses",
1816+
json=request_body.model_dump(exclude_unset=True, mode="json"),
1817+
)
1818+
assert response.status_code == 200
1819+
1820+
data = response.json()
1821+
1822+
expected_dict = expected_response.model_dump()
1823+
assert data == expected_dict
1824+
1825+
expected_messages = [
1826+
{"content": [{"text": "Check my order status", "type": "text"}], "role": "user"},
1827+
{
1828+
"role": "assistant",
1829+
"content": "Sure, one sec.",
1830+
"tool_calls": [],
1831+
"reasoning_content": "First reasoning item",
1832+
},
1833+
{"content": [{"text": "cool", "type": "text"}], "role": "user"},
1834+
{
1835+
"role": "assistant",
1836+
"content": "I'm still checking",
1837+
"tool_calls": [],
1838+
},
1839+
{"content": [{"text": "ok", "type": "text"}], "role": "user"},
1840+
{
1841+
"role": "assistant",
1842+
"content": " hello hello",
1843+
"tool_calls": [
1844+
{
1845+
"id": "call_123",
1846+
"function": {"arguments": '{"order_id": "123"}', "name": "get_order_status"},
1847+
"type": "function",
1848+
},
1849+
{
1850+
"id": "call_234",
1851+
"function": {"arguments": '{"order_id": "234"}', "name": "get_delivery_date"},
1852+
"type": "function",
1853+
},
1854+
],
1855+
"reasoning_content": "Gathering order status and delivery info...",
1856+
},
1857+
]
1858+
actual_messages = mock_method.call_args.kwargs["messages"]
1859+
assert expected_messages == actual_messages
1860+
15891861

15901862
class TestVLLMConverter:
15911863
def setup_method(self, _):

0 commit comments

Comments
 (0)