
Commit ee4d767

bxyu-nvidia and abhibha-nvidia authored and committed

VLLMModel propagates token IDs (#11)

Signed-off-by: Brian Yu <bxyu@nvidia.com>
Signed-off-by: Abhibha Gupta <abhibhag@nvidia.com>
1 parent aca5f74 commit ee4d767

8 files changed

Lines changed: 233 additions & 8431 deletions


README.md

Lines changed: 13 additions & 0 deletions
@@ -23,6 +23,7 @@
 - [FAQ: Why NeMo Gym?](#faq-why-nemo-gym)
 - [FAQ: Error: Found files with missing copyright](#faq-error-found-files-with-missing-copyright)
 - [FAQ: build-docs / Build docs CI failures](#faq-build-docs--build-docs-ci-failures)
+- [FAQ: NeMo Gym, training frameworks, and token IDs](#faq-nemo-gym-training-frameworks-and-token-ids)
 
 
 # NeMo-Gym
@@ -874,3 +875,15 @@ pickling environment... done
 checking consistency... done
 ```
 You may need to reformat some of your docstrings to Napoleon format docstrings: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/
+
+
+# FAQ: NeMo Gym, training frameworks, and token IDs
+One of the goals of NeMo Gym is to act as a rollout tool for LLM post-training, either as synthetic data generation for SFT or as training environments for RL.
+
+RL training frameworks don't typically operate on the OpenAI schema; they operate on token IDs. It is especially critical to always have the correct token IDs during training, both so that we stay on-policy and so that what we think the model sees is what the model actually sees. However, by providing this OpenAI-schema-compatible interface to training environment developers, we lose track of the token IDs in Gym.
+
+For example, say we are training a Qwen 3 family model. During rollouts, the model may sample from the entire token distribution. The token IDs are then decoded into text, converted to the OpenAI schema, and returned to the training environment developer. In multi-step and multi-turn scenarios, the training environment developer will at some point call the model again with the previously output OpenAI schema. This re-tokenization causes problems, since a single string may map to multiple possible sequences of token IDs. If the model generates token ID sequence 1 but re-tokenization produces token ID sequence 2, the rollout silently goes off-policy when the Gym result is consumed by the RL training framework.
+
+So the OpenAI-compatible model server in a training framework needs to be able to handle this discrepancy. To do that, Gym needs a handle on the ground-truth token IDs, and it needs to provide that information back to the training framework's OpenAI-compatible server.
+
+TODO @bxyu-nvidia: expand on this later.
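The decode/re-encode drift described in the FAQ can be reproduced with a toy tokenizer. The vocabulary and greedy longest-match scheme below are purely illustrative, not NeMo Gym or Qwen code:

```python
# Toy vocabulary: "ab" has its own merged token, as BPE merges often do.
VOCAB = {"ab": 2, "a": 0, "b": 1}
INV = {v: k for k, v in VOCAB.items()}

def encode(text: str) -> list:
    # Greedy longest-match tokenization; assumes the vocab covers the text.
    ids, i = [], 0
    while i < len(text):
        for piece in sorted(VOCAB, key=len, reverse=True):
            if text.startswith(piece, i):
                ids.append(VOCAB[piece])
                i += len(piece)
                break
    return ids

def decode(ids: list) -> str:
    return "".join(INV[t] for t in ids)

# Suppose the model sampled "a" then "b" as two separate tokens.
sampled_ids = [0, 1]
text = decode(sampled_ids)          # "ab"
retokenized = encode(text)          # [2] -- a different token ID sequence
assert decode(retokenized) == text  # same text either way...
assert retokenized != sampled_ids   # ...but training on [2] would be off-policy
```

Real BPE tokenizers exhibit the same many-to-one property at a much larger scale, which is why the ground-truth IDs must be carried alongside the text.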

nemo_gym/openai_utils.py

Lines changed: 9 additions & 0 deletions
@@ -197,6 +197,15 @@ class NeMoGymResponseReasoningItemForTraining(NeMoGymResponseReasoningItem, Toke
     pass
 
 
+RESPONSES_TO_TRAIN = {
+    NeMoGymEasyInputMessage: NeMoGymEasyInputMessageForTraining,
+    NeMoGymMessage: NeMoGymMessageForTraining,
+    NeMoGymResponseOutputMessage: NeMoGymResponseOutputMessageForTraining,
+    NeMoGymResponseFunctionToolCall: NeMoGymResponseFunctionToolCallForTraining,
+    NeMoGymResponseReasoningItem: NeMoGymResponseReasoningItemForTraining,
+}
+
+
 NeMoGymResponseInputItem = Union[
     NeMoGymEasyInputMessage,
     NeMoGymMessage,
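`RESPONSES_TO_TRAIN` maps each base output class to its `*ForTraining` counterpart, so an output item can be upgraded to a token-ID-carrying version by class lookup. A minimal sketch of that pattern, using hypothetical dataclass stand-ins rather than the real pydantic models:

```python
from dataclasses import dataclass, asdict
from typing import List

# Hypothetical stand-ins for the pydantic models in nemo_gym.openai_utils.
@dataclass
class Message:
    role: str
    content: str

@dataclass
class MessageForTraining(Message):
    prompt_token_ids: List[int]
    generation_token_ids: List[int]
    generation_log_probs: List[float]

# Base class -> training class, mirroring RESPONSES_TO_TRAIN.
TO_TRAIN = {Message: MessageForTraining}

msg = Message(role="assistant", content="hi")
train_cls = TO_TRAIN[msg.__class__]
# Re-create the item as its training variant, attaching token information.
train_msg = train_cls(
    **asdict(msg),
    prompt_token_ids=[1, 2],
    generation_token_ids=[3],
    generation_log_probs=[-0.5],
)
assert isinstance(train_msg, Message)
```

The lookup keeps the upgrade logic in one place: adding a new output type only requires registering a new pair in the mapping.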

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -226,7 +226,7 @@ ng_dump_config = "nemo_gym.cli:dump_config"
 
 [tool.setuptools.packages.find]
 where = ["."]
-include = ["resources_servers", "responses_api_agents", "responses_api_models", "nemo_gym"]
+include = ["resources_servers", "responses_api_agents", "responses_api_models", "nemo_gym", "penguin"]
 
 ################################################
 # Testing

responses_api_models/vllm_model/app.py

Lines changed: 90 additions & 53 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import re
 from time import time
-from typing import List, Tuple
+from typing import ClassVar, List, Optional, Tuple
 from uuid import uuid4
 
 from openai import BaseModel as OpenAIBaseModel
@@ -25,8 +25,10 @@
     SimpleResponsesAPIModel,
 )
 from nemo_gym.openai_utils import (
+    RESPONSES_TO_TRAIN,
     NeMoGymAsyncOpenAI,
     NeMoGymChatCompletion,
+    NeMoGymChatCompletionAssistantMessageForTrainingParam,
     NeMoGymChatCompletionAssistantMessageParam,
     NeMoGymChatCompletionCreateParamsNonStreaming,
     NeMoGymChatCompletionDeveloperMessageParam,
@@ -47,13 +49,15 @@
     NeMoGymResponseOutputText,
     NeMoGymResponseReasoningItem,
     NeMoGymSummary,
+    TokenIDLogProbMixin,
 )
 
 
 class VLLMModelConfig(BaseResponsesAPIModelConfig):
     base_url: str
     api_key: str
     model: str
+    return_token_id_information: bool
 
 
 # This needs to be OpenAI BaseModel since it is casted to below by the OpenAI client.
@@ -69,7 +73,7 @@ def model_post_init(self, context):
             base_url=self.config.base_url,
             api_key=self.config.api_key,
         )
-        self._converter = VLLMConverter()
+        self._converter = VLLMConverter(return_token_id_information=self.config.return_token_id_information)
         return super().model_post_init(context)
 
     async def responses(self, body: NeMoGymResponseCreateParamsNonStreaming = Body()) -> NeMoGymResponse:
@@ -82,21 +86,10 @@ async def responses(self, body: NeMoGymResponseCreateParamsNonStreaming = Body()
         chat_completion_response = await self.chat_completions(chat_completion_create_params)
 
         choice = chat_completion_response.choices[0]
-        message = choice.message
 
         response_output = self._converter.postprocess_chat_response(choice)
         response_output_dicts = [item.model_dump() for item in response_output]
 
-        last_response_output_item = response_output_dicts[-1]
-        if hasattr(message, "prompt_token_ids"):
-            last_response_output_item.update(
-                dict(
-                    prompt_token_ids=message.prompt_token_ids,
-                    generation_token_ids=message.generation_token_ids,
-                    generation_log_probs=message.generation_log_probs,
-                )
-            )
-
         # Chat Completion -> Response
         return NeMoGymResponse(
             id=f"resp_{uuid4().hex}",
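The rewritten `chat_completions` below builds `create_params` with Python's in-place dict merge operator. One subtlety worth knowing: `create_params = body_dict` creates an alias rather than a copy, so `|=` also mutates `body_dict` (which is reused later as the tokenize request body). A quick sketch of the operator's semantics:

```python
body_dict = {"model": "my-model", "messages": []}

create_params = body_dict            # an alias, not a copy
create_params |= dict(logprobs=True)

# `|=` updates the dict in place, so the original name sees the change too.
assert create_params is body_dict
assert body_dict == {"model": "my-model", "messages": [], "logprobs": True}
```

Use `create_params = body_dict | extra` (the non-mutating merge) if the original dict must stay untouched.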
@@ -130,72 +123,97 @@ async def chat_completions(
         body_dict = body.model_dump(exclude_unset=True)
         body_dict.setdefault("model", self.config.model)
 
-        openai_response = await self._client.chat.completions.create(
-            **body_dict,
-            logprobs=True,
-            # The extra body below is VLLM specific to get the generation log probs associated with generation token IDs.
-            extra_body={
-                "return_tokens_as_token_ids": True,
-            },
-        )
+        create_params = body_dict
+        if self.config.return_token_id_information:
+            create_params |= dict(
+                logprobs=True,
+                # The extra body below is VLLM specific to get the generation log probs associated with generation token IDs.
+                extra_body={
+                    "return_tokens_as_token_ids": True,
+                },
+            )
+
+        openai_response = await self._client.chat.completions.create(**create_params)
         assert not getattr(openai_response.choices[0].message, "reasoning_content", None), (
             "Please do not use a reasoning parser in vLLM! There is one source of truth for handling data (including reasoning), which is NeMo Gym!"
         )
         openai_response: NeMoGymChatCompletion
 
-        log_probs = openai_response.choices[0].logprobs.content
-        generation_token_ids = []
-        generation_log_probs = []
-        for log_prob in log_probs:
-            # Looks like `"token_id:151667"`
-            generation_token_ids.append(int(log_prob.token.removeprefix("token_id:")))
-            generation_log_probs.append(log_prob.logprob)
-
-        # The base url has /v1 at the end but vLLM's tokenize endpoint does not have v1, hence the ..
-        # I can't believe the path is resolved correctly LOL
-        tokenize_response = await self._client.post(
-            "../tokenize",
-            cast_to=VLLMTokenizeResponse,
-            body=body_dict,
-        )
-
         chat_completion_dict = openai_response.model_dump()
-        message_dict = chat_completion_dict["choices"][0]["message"]
-        message_dict.update(
-            dict(
-                prompt_token_ids=tokenize_response.tokens,
-                generation_token_ids=generation_token_ids,
-                generation_log_probs=generation_log_probs,
-            )
-        )
+
+        if self.config.return_token_id_information:
+            log_probs = openai_response.choices[0].logprobs.content
+            generation_token_ids = []
+            generation_log_probs = []
+            for log_prob in log_probs:
+                # Looks like `"token_id:151667"`
+                generation_token_ids.append(int(log_prob.token.removeprefix("token_id:")))
+                generation_log_probs.append(log_prob.logprob)
+
+            # The base url has /v1 at the end but vLLM's tokenize endpoint does not have v1, hence the ..
+            # I can't believe the path is resolved correctly LOL
+            tokenize_response = await self._client.post(
+                "../tokenize",
+                cast_to=VLLMTokenizeResponse,
+                body=body_dict,
+            )
+
+            message_dict = chat_completion_dict["choices"][0]["message"]
+            message_dict.update(
+                dict(
+                    prompt_token_ids=tokenize_response.tokens,
+                    generation_token_ids=generation_token_ids,
+                    generation_log_probs=generation_log_probs,
+                )
+            )
+
         return NeMoGymChatCompletion(**chat_completion_dict)
 
 
 class VLLMConverterResponsesToChatCompletionsState(BaseModel):
+    return_token_id_information: bool
+
     messages: List[NeMoGymChatCompletionMessageParam] = Field(default_factory=list)
 
     # We are mapping from Response input items to chat completions messages, which is many to one.
     # Our state will accumulate the reasoning, chat, and tool calls for assistant messages.
     content_buffer: str = ""  # Buffer for reasoning and chat
     tool_calls_buffer: List[NeMoGymChatCompletionMessageToolCallParam] = Field(default_factory=list)
 
+    # Will only be populated if return_token_id_information is True.
+    token_information: Optional[TokenIDLogProbMixin] = None
+
    def flush_assistant(self) -> None:
         if not (self.content_buffer or self.tool_calls_buffer):
             return
 
-        self.messages.append(
-            NeMoGymChatCompletionAssistantMessageParam(
-                content=self.content_buffer or None,
-                role="assistant",
-                tool_calls=self.tool_calls_buffer,
-            )
-        )
+        shared_params = dict(
+            content=self.content_buffer or None,
+            role="assistant",
+            tool_calls=self.tool_calls_buffer,
+        )
+        if self.return_token_id_information:
+            message = NeMoGymChatCompletionAssistantMessageForTrainingParam(
+                **shared_params,
+                **self.token_information.model_dump(),
+            )
+        else:
+            message = NeMoGymChatCompletionAssistantMessageParam(**shared_params)
+
+        self.messages.append(message)
+
         self.content_buffer = ""
         self.tool_calls_buffer = []
 
 
-class VLLMConverter:
-    THINK_TAG_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
+class VLLMConverter(BaseModel):
+    return_token_id_information: bool
+
+    # =======================================================
+    # Reasoning handling. This may change across models and model families
+    # =======================================================
+
+    THINK_TAG_PATTERN: ClassVar = re.compile(r"<think>(.*?)</think>", re.DOTALL)
 
     @staticmethod
     def _wrap_reasoning_in_think_tags(texts: List[str]) -> str:
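With `return_tokens_as_token_ids=True`, vLLM reports each sampled token in the logprobs as a string of the form `token_id:<int>`, which the code above strips back to integer IDs. A standalone sketch of that parsing (the tuples are illustrative values, not real model output):

```python
# Each entry mimics one logprob item: (token string, logprob).
logprob_content = [
    ("token_id:151667", -0.03),
    ("token_id:872", -1.25),
]

generation_token_ids = []
generation_log_probs = []
for token, logprob in logprob_content:
    # Strip vLLM's "token_id:" prefix to recover the integer token ID.
    generation_token_ids.append(int(token.removeprefix("token_id:")))
    generation_log_probs.append(logprob)

assert generation_token_ids == [151667, 872]
assert generation_log_probs == [-0.03, -1.25]
```

The real code iterates `choice.logprobs.content` objects with `.token` and `.logprob` attributes; the parsing step is the same.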
@@ -220,7 +238,9 @@ def responses_to_chat_completion_create_params(
         responses_create_params = responses_create_params.model_dump(exclude_unset=True)
 
         # Tracks messages including reasoning for each respective message type helper function
-        state = VLLMConverterResponsesToChatCompletionsState()
+        state = VLLMConverterResponsesToChatCompletionsState(
+            return_token_id_information=self.return_token_id_information
+        )
 
         # Input can be a string. Wrap in a ResponseInput-like
         response_input = responses_create_params["input"]
@@ -255,6 +275,13 @@ def responses_to_chat_completion_create_params(
                 case _:  # pragma: no cover
                     raise NotImplementedError(f"Unsupported message type: {m}")
 
+            if self.return_token_id_information and m.get("prompt_token_ids"):
+                state.token_information = TokenIDLogProbMixin(
+                    prompt_token_ids=m["prompt_token_ids"],
+                    generation_token_ids=m["generation_token_ids"],
+                    generation_log_probs=m["generation_log_probs"],
+                )
+
         state.flush_assistant()
 
         model = responses_create_params.pop("model", None)
@@ -439,6 +466,16 @@ def postprocess_chat_response(self, choice: NeMoGymChoice) -> List[NeMoGymRespon
             )
         )
 
+        if self.return_token_id_information:
+            last_response_output_item = response_output[-1]
+            train_cls = RESPONSES_TO_TRAIN[last_response_output_item.__class__]
+            response_output[-1] = train_cls(
+                **last_response_output_item.model_dump(),
+                prompt_token_ids=raw_message["prompt_token_ids"],
+                generation_token_ids=raw_message["generation_token_ids"],
+                generation_log_probs=raw_message["generation_log_probs"],
+            )
+
         return response_output
 
     def _extract_reasoning_from_content(self, content: str) -> Tuple[List[str], str]:
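The `"../tokenize"` request in `chat_completions` leans on relative URL resolution: the client's base URL ends in `/v1/`, and `..` climbs out of it, landing on vLLM's root-level tokenize endpoint. Standard RFC 3986 resolution, shown with `urllib.parse.urljoin` against an assumed local server (the OpenAI client's own path joining may differ in minor details):

```python
from urllib.parse import urljoin

# Base URL of an assumed local vLLM OpenAI-compatible server.
base_url = "http://localhost:8000/v1/"

# ".." steps out of the /v1/ prefix before appending "tokenize".
assert urljoin(base_url, "../tokenize") == "http://localhost:8000/tokenize"
```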

responses_api_models/vllm_model/configs/vllm_model.yaml

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ openai_model:
       base_url: ${policy_base_url}
       api_key: ${policy_api_key}
       model: ${policy_model_name}
+      return_token_id_information: false
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+openai_model:
+  responses_api_models:
+    vllm_model:
+      entrypoint: app.py
+      base_url: ${policy_base_url}
+      api_key: ${policy_api_key}
+      model: ${policy_model_name}
+      return_token_id_information: true
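The `${policy_base_url}`-style values in these configs are placeholders filled in at load time by NeMo Gym's configuration machinery. A toy resolver illustrating just the substitution semantics (an illustrative sketch, not the actual implementation):

```python
import re

def resolve(template: str, variables: dict) -> str:
    # Replace every ${name} placeholder with its value from `variables`.
    return re.sub(r"\$\{(\w+)\}", lambda m: variables[m.group(1)], template)

variables = {"policy_base_url": "http://localhost:8000/v1"}
assert resolve("${policy_base_url}", variables) == "http://localhost:8000/v1"
```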
