Skip to content

Commit 40bf7d4

Browse files
authored
Misc infra 20251024 (#234)
Signed-off-by: Brian Yu <bxyu@nvidia.com>
1 parent 901164a commit 40bf7d4

7 files changed

Lines changed: 29 additions & 16 deletions

File tree

nemo_gym/openai_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,10 +347,10 @@ class NeMoGymChatCompletionMessageToolCallParam(ChatCompletionMessageToolCallPar
347347
function: NeMoGymChatCompletionMessageToolCallFunctionParam
348348

349349

350-
class NeMoGymChatCompletionAssistantMessageParam(ChatCompletionAssistantMessageParam):
350+
class NeMoGymChatCompletionAssistantMessageParam(ChatCompletionAssistantMessageParam, total=False):
351351
# Override the iterable which is annoying to work with.
352352
content: Union[str, List[ContentArrayOfContentPart], None]
353-
tool_calls: List[NeMoGymChatCompletionMessageToolCallParam]
353+
tool_calls: Optional[List[NeMoGymChatCompletionMessageToolCallParam]] = None
354354

355355

356356
class NeMoGymChatCompletionAssistantMessageForTrainingParam(
@@ -434,7 +434,10 @@ async def _request(self, **request_kwargs: Dict) -> ClientResponse:
434434
tries += 1
435435
response = await request(**request_kwargs)
436436
# See https://platform.openai.com/docs/guides/error-codes/api-errors
437-
if response.status in (429, 500, 503):
437+
# 500 is internal server error, which may sporadically occur
438+
# 502 is Bad gateway (when the endpoint is overloaded)
439+
# 504 is Gateway timeout (when the endpoint config has too low of a gateway timeout setting for the model to finish generating)
440+
if response.status in (429, 500, 502, 503, 504):
438441
content = (await response.content.read()).decode()
439442
print(
440443
f"Hit a {response.status} trying to query an OpenAI endpoint (try {tries}). Sleeping 0.5s. Error message: {content}"

nemo_gym/server_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from aiohttp import ClientResponse, ClientSession, ClientTimeout, DummyCookieJar, ServerDisconnectedError, TCPConnector
3535
from aiohttp.client import _RequestOptions
3636
from fastapi import FastAPI, Request, Response
37+
from fastapi.exception_handlers import request_validation_exception_handler
38+
from fastapi.exceptions import RequestValidationError
3739
from fastapi.responses import JSONResponse
3840
from omegaconf import DictConfig, OmegaConf
3941
from pydantic import BaseModel, ConfigDict
@@ -464,6 +466,15 @@ def run_webserver(cls) -> None: # pragma: no cover
464466
server.set_ulimit()
465467
server.setup_exception_middleware(app)
466468

469+
@app.exception_handler(RequestValidationError)
470+
async def validation_exception_handler(request: Request, exc):
471+
print(
472+
f"""Hit validation exception! Errors: {json.dumps(exc.errors(), indent=4)}
473+
Full body: {json.dumps(exc.body, indent=4)}
474+
"""
475+
)
476+
return await request_validation_exception_handler(request, exc)
477+
467478
profiling_config = ProfilingMiddlewareConfig.model_validate(global_config_dict)
468479
if profiling_config.profiling_enabled:
469480
server.setup_profiling(app, profiling_config)

responses_api_models/azure_openai_model/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def responses(
5959
async with self._semaphore:
6060
chat_completion_create_params = self._converter.responses_to_chat_completion_create_params(body)
6161
chat_completion_params_dict = chat_completion_create_params.model_dump(exclude_unset=True)
62-
chat_completion_params_dict.setdefault("model", self.config.openai_model)
62+
chat_completion_params_dict["model"] = self.config.openai_model
6363
chat_completion_response = await self._client.chat.completions.create(**chat_completion_params_dict)
6464

6565
choice = chat_completion_response.choices[0]
@@ -95,7 +95,7 @@ async def chat_completions(
9595
self, request: Request, body: NeMoGymChatCompletionCreateParamsNonStreaming = Body()
9696
) -> NeMoGymChatCompletion:
9797
body_dict = body.model_dump(exclude_unset=True)
98-
body_dict.setdefault("model", self.config.openai_model)
98+
body_dict["model"] = self.config.openai_model
9999
openai_response_dict = await self._client.chat.completions.create(**body_dict)
100100
return NeMoGymChatCompletion.model_validate(openai_response_dict)
101101

responses_api_models/azure_openai_model/tests/test_app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ async def mock_create_chat(**kwargs):
108108
},
109109
)
110110
assert chat_with_model.status_code == 200
111-
assert called_args_chat.get("model") == "override_model"
111+
assert called_args_chat.get("model") == "dummy_model"
112112

113113
server._client.chat.completions.create.assert_any_await(
114114
messages=[{"role": "user", "content": "hi"}],
115-
model="override_model",
115+
model="dummy_model",
116116
)
117117

118118
async def test_responses(self, monkeypatch: MonkeyPatch) -> None:

responses_api_models/openai_model/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ def model_post_init(self, context):
4545

4646
async def responses(self, body: NeMoGymResponseCreateParamsNonStreaming = Body()) -> NeMoGymResponse:
4747
body_dict = body.model_dump(exclude_unset=True)
48-
body_dict.setdefault("model", self.config.openai_model)
48+
body_dict["model"] = self.config.openai_model
4949
openai_response_dict = await self._client.create_response(**body_dict)
5050
return NeMoGymResponse.model_validate(openai_response_dict)
5151

5252
async def chat_completions(
5353
self, body: NeMoGymChatCompletionCreateParamsNonStreaming = Body()
5454
) -> NeMoGymChatCompletion:
5555
body_dict = body.model_dump(exclude_unset=True)
56-
body_dict.setdefault("model", self.config.openai_model)
56+
body_dict["model"] = self.config.openai_model
5757
openai_response_dict = await self._client.create_chat_completion(**body_dict)
5858
return NeMoGymChatCompletion.model_validate(openai_response_dict)
5959

responses_api_models/openai_model/tests/test_app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ async def mock_create_chat(**kwargs):
8787
},
8888
)
8989
assert chat_with_model.status_code == 200
90-
assert called_args_chat.get("model") == "override_model"
90+
assert called_args_chat.get("model") == "dummy_model"
9191

9292
server._client.create_chat_completion.assert_any_await(
9393
messages=[{"role": "user", "content": "hi"}],
94-
model="override_model",
94+
model="dummy_model",
9595
)
9696

9797
async def test_responses(self, monkeypatch: MonkeyPatch) -> None:
@@ -142,6 +142,6 @@ async def mock_create_response(**kwargs):
142142
# model provided should override config
143143
res_with_model = client.post("/v1/responses", json={"input": "hello", "model": "override_model"})
144144
assert res_with_model.status_code == 200
145-
assert called_args_response.get("model") == "override_model"
145+
assert called_args_response.get("model") == "dummy_model"
146146

147-
server._client.create_response.assert_any_await(input="hello", model="override_model")
147+
server._client.create_response.assert_any_await(input="hello", model="dummy_model")

responses_api_models/vllm_model/app.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ async def responses(
9393
) -> NeMoGymResponse:
9494
# Response Create Params -> Chat Completion Create Params
9595
chat_completion_create_params = self._converter.responses_to_chat_completion_create_params(body)
96-
if not body.model:
97-
body.model = self.config.model
96+
body.model = self.config.model
9897

9998
# Chat Completion Create Params -> Chat Completion
10099
chat_completion_response = await self.chat_completions(request, chat_completion_create_params)
@@ -135,7 +134,7 @@ async def chat_completions(
135134
self, request: Request, body: NeMoGymChatCompletionCreateParamsNonStreaming = Body()
136135
) -> NeMoGymChatCompletion:
137136
body_dict = body.model_dump(exclude_unset=True)
138-
body_dict.setdefault("model", self.config.model)
137+
body_dict["model"] = self.config.model
139138

140139
session_id = request.session[SESSION_ID_KEY]
141140
if session_id not in self._session_id_to_client:

0 commit comments

Comments
 (0)