Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions nemo_gym/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,10 @@ class NeMoGymChatCompletionMessageToolCallParam(ChatCompletionMessageToolCallPar
function: NeMoGymChatCompletionMessageToolCallFunctionParam


class NeMoGymChatCompletionAssistantMessageParam(ChatCompletionAssistantMessageParam):
class NeMoGymChatCompletionAssistantMessageParam(ChatCompletionAssistantMessageParam, total=False):
# Override the base class's iterable-typed fields, which are awkward to work with.
content: Union[str, List[ContentArrayOfContentPart], None]
tool_calls: List[NeMoGymChatCompletionMessageToolCallParam]
tool_calls: Optional[List[NeMoGymChatCompletionMessageToolCallParam]] = None


class NeMoGymChatCompletionAssistantMessageForTrainingParam(
Expand Down Expand Up @@ -434,7 +434,10 @@ async def _request(self, **request_kwargs: Dict) -> ClientResponse:
tries += 1
response = await request(**request_kwargs)
# See https://platform.openai.com/docs/guides/error-codes/api-errors
if response.status in (429, 500, 503):
# 500 is internal server error, which may sporadically occur
# 502 is Bad Gateway (returned when the endpoint is overloaded)
# 504 is Gateway Timeout (returned when the endpoint's gateway timeout setting is too low for the model to finish generating)
if response.status in (429, 500, 502, 503, 504):
content = (await response.content.read()).decode()
print(
f"Hit a {response.status} trying to query an OpenAI endpoint (try {tries}). Sleeping 0.5s. Error message: {content}"
Expand Down
11 changes: 11 additions & 0 deletions nemo_gym/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from aiohttp import ClientResponse, ClientSession, ClientTimeout, DummyCookieJar, ServerDisconnectedError, TCPConnector
from aiohttp.client import _RequestOptions
from fastapi import FastAPI, Request, Response
from fastapi.exception_handlers import request_validation_exception_handler
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, ConfigDict
Expand Down Expand Up @@ -464,6 +466,15 @@ def run_webserver(cls) -> None: # pragma: no cover
server.set_ulimit()
server.setup_exception_middleware(app)

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc):
print(
f"""Hit validation exception! Errors: {json.dumps(exc.errors(), indent=4)}
Full body: {json.dumps(exc.body, indent=4)}
"""
)
return await request_validation_exception_handler(request, exc)

profiling_config = ProfilingMiddlewareConfig.model_validate(global_config_dict)
if profiling_config.profiling_enabled:
server.setup_profiling(app, profiling_config)
Expand Down
4 changes: 2 additions & 2 deletions responses_api_models/azure_openai_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def responses(
async with self._semaphore:
chat_completion_create_params = self._converter.responses_to_chat_completion_create_params(body)
chat_completion_params_dict = chat_completion_create_params.model_dump(exclude_unset=True)
chat_completion_params_dict.setdefault("model", self.config.openai_model)
chat_completion_params_dict["model"] = self.config.openai_model
chat_completion_response = await self._client.chat.completions.create(**chat_completion_params_dict)

choice = chat_completion_response.choices[0]
Expand Down Expand Up @@ -95,7 +95,7 @@ async def chat_completions(
self, request: Request, body: NeMoGymChatCompletionCreateParamsNonStreaming = Body()
) -> NeMoGymChatCompletion:
body_dict = body.model_dump(exclude_unset=True)
body_dict.setdefault("model", self.config.openai_model)
body_dict["model"] = self.config.openai_model
openai_response_dict = await self._client.chat.completions.create(**body_dict)
return NeMoGymChatCompletion.model_validate(openai_response_dict)

Expand Down
4 changes: 2 additions & 2 deletions responses_api_models/azure_openai_model/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ async def mock_create_chat(**kwargs):
},
)
assert chat_with_model.status_code == 200
assert called_args_chat.get("model") == "override_model"
assert called_args_chat.get("model") == "dummy_model"

server._client.chat.completions.create.assert_any_await(
messages=[{"role": "user", "content": "hi"}],
model="override_model",
model="dummy_model",
)

async def test_responses(self, monkeypatch: MonkeyPatch) -> None:
Expand Down
4 changes: 2 additions & 2 deletions responses_api_models/openai_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ def model_post_init(self, context):

async def responses(self, body: NeMoGymResponseCreateParamsNonStreaming = Body()) -> NeMoGymResponse:
body_dict = body.model_dump(exclude_unset=True)
body_dict.setdefault("model", self.config.openai_model)
body_dict["model"] = self.config.openai_model
openai_response_dict = await self._client.create_response(**body_dict)
return NeMoGymResponse.model_validate(openai_response_dict)

async def chat_completions(
self, body: NeMoGymChatCompletionCreateParamsNonStreaming = Body()
) -> NeMoGymChatCompletion:
body_dict = body.model_dump(exclude_unset=True)
body_dict.setdefault("model", self.config.openai_model)
body_dict["model"] = self.config.openai_model
openai_response_dict = await self._client.create_chat_completion(**body_dict)
return NeMoGymChatCompletion.model_validate(openai_response_dict)

Expand Down
8 changes: 4 additions & 4 deletions responses_api_models/openai_model/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ async def mock_create_chat(**kwargs):
},
)
assert chat_with_model.status_code == 200
assert called_args_chat.get("model") == "override_model"
assert called_args_chat.get("model") == "dummy_model"

server._client.create_chat_completion.assert_any_await(
messages=[{"role": "user", "content": "hi"}],
model="override_model",
model="dummy_model",
)

async def test_responses(self, monkeypatch: MonkeyPatch) -> None:
Expand Down Expand Up @@ -142,6 +142,6 @@ async def mock_create_response(**kwargs):
# config model should override the model provided in the request
res_with_model = client.post("/v1/responses", json={"input": "hello", "model": "override_model"})
assert res_with_model.status_code == 200
assert called_args_response.get("model") == "override_model"
assert called_args_response.get("model") == "dummy_model"

server._client.create_response.assert_any_await(input="hello", model="override_model")
server._client.create_response.assert_any_await(input="hello", model="dummy_model")
5 changes: 2 additions & 3 deletions responses_api_models/vllm_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ async def responses(
) -> NeMoGymResponse:
# Response Create Params -> Chat Completion Create Params
chat_completion_create_params = self._converter.responses_to_chat_completion_create_params(body)
if not body.model:
body.model = self.config.model
body.model = self.config.model

# Chat Completion Create Params -> Chat Completion
chat_completion_response = await self.chat_completions(request, chat_completion_create_params)
Expand Down Expand Up @@ -135,7 +134,7 @@ async def chat_completions(
self, request: Request, body: NeMoGymChatCompletionCreateParamsNonStreaming = Body()
) -> NeMoGymChatCompletion:
body_dict = body.model_dump(exclude_unset=True)
body_dict.setdefault("model", self.config.model)
body_dict["model"] = self.config.model

session_id = request.session[SESSION_ID_KEY]
if session_id not in self._session_id_to_client:
Expand Down