Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo_gym/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ class NeMoGymAsyncOpenAI(AsyncOpenAI):
def __init__(self, **kwargs) -> None:
# TODO: this setup is taken from https://github.com/NVIDIA/NeMo-Skills/blob/80dc78ac758c4cac81c83a43a729e7ca1280857b/nemo_skills/inference/model/base.py#L318
# However, there may still be a lingering issue regarding saturating at 100 max connections
kwargs["http_client"] = get_global_httpx_client()
kwargs["http_client"] = get_global_httpx_client(kwargs["base_url"])
kwargs["timeout"] = None # Enforce no timeout

super().__init__(**kwargs)
31 changes: 21 additions & 10 deletions nemo_gym/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from abc import abstractmethod
from os import getenv
from threading import Thread
from typing import Any, Literal, Optional, Tuple, Type, Union
from typing import Any, Dict, Literal, Optional, Tuple, Type, Union
from uuid import uuid4

import requests
Expand Down Expand Up @@ -65,13 +65,19 @@ def __init__(self, *args, **kwargs) -> None:
# ```
# In order to get the most benefit from connection pooling, make sure you're not instantiating multiple client instances - for example by using async with inside a "hot loop". This can be achieved either by having a single scoped client that's passed throughout wherever it's needed, or by having a single global client instance.
# ```
# In plain language:
# - Let's say we have 10 distinct endpoints we want to call 5 times each.
# - A connection pool as defined by the httpx client is for a single distinct endpoint. All requests to that endpoint should use the same httpx client.
# - So the optimal configuration here is to have 10 total httpx clients, one for each distinct endpoint.
# - Additionally, since the connections are pooled, if we had a single global client for all 10 distinct endpoints, we may run into deadlock situations,
# where requests to two different endpoints are waiting for each other to resolve.
#
# In principle, we use no timeout since various API or model calls may take an indefinite amount of time. Right now, we have no timeout, even for connection errors, which may be problematic. We may want to revisit a more granular httpx.Timeout later on.
#
# Eventually, we may also want to parameterize the max connections. For now, we set the max connections to just some very large number.
#
# It's critical that this client is NOT used before uvicorn.run is called. Under the hood, this async client will start and use an event loop, and store a handle to that specific event loop. When uvicorn.run is called, it will replace the event loop policy with its own. So the handle that the async client has is now outdated.
_GLOBAL_HTTPX_CLIENT = None
_GLOBAL_HTTPX_CLIENTS: Dict[str, NeMoGymGlobalAsyncClient] = dict()


class GlobalHTTPXAsyncClientConfig(BaseModel):
Expand All @@ -80,12 +86,13 @@ class GlobalHTTPXAsyncClientConfig(BaseModel):


def get_global_httpx_client(
base_url: str,
global_config_dict_parser_config: Optional[GlobalConfigDictParserConfig] = None,
global_config_dict_parser_cls: Type[GlobalConfigDictParser] = GlobalConfigDictParser,
) -> NeMoGymGlobalAsyncClient:
global _GLOBAL_HTTPX_CLIENT
if _GLOBAL_HTTPX_CLIENT is not None:
return _GLOBAL_HTTPX_CLIENT
if base_url in _GLOBAL_HTTPX_CLIENTS:
return _GLOBAL_HTTPX_CLIENTS[base_url]

global_config_dict = get_global_config_dict(
global_config_dict_parser_config=global_config_dict_parser_config,
global_config_dict_parser_cls=global_config_dict_parser_cls,
Expand All @@ -99,7 +106,9 @@ def get_global_httpx_client(
transport=AsyncHTTPTransport(retries=cfg.global_httpx_max_retries),
timeout=None,
)
_GLOBAL_HTTPX_CLIENT = client

_GLOBAL_HTTPX_CLIENTS[base_url] = client

return client


Expand Down Expand Up @@ -166,8 +175,9 @@ async def get(

"""
server_config_dict = get_first_server_config_dict(self.global_config_dict, server_name)
return await get_global_httpx_client().get(
f"{self._build_server_base_url(server_config_dict)}{url_path}",
base_url = self._build_server_base_url(server_config_dict)
return await get_global_httpx_client(base_url).get(
f"{base_url}{url_path}",
params=params,
headers=headers,
cookies=cookies,
Expand Down Expand Up @@ -198,8 +208,9 @@ async def post(

"""
server_config_dict = get_first_server_config_dict(self.global_config_dict, server_name)
return await get_global_httpx_client().post(
f"{self._build_server_base_url(server_config_dict)}{url_path}",
base_url = self._build_server_base_url(server_config_dict)
return await get_global_httpx_client(base_url).post(
f"{base_url}{url_path}",
content=content,
data=data,
files=files,
Expand Down
13 changes: 6 additions & 7 deletions tests/unit_tests/test_server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,19 @@ async def test_ServerClient_get_post_sanity(self, monkeypatch: MonkeyPatch) -> N
),
)

httpx_client_mock = AsyncMock()
httpx_client_mock.return_value = "my mock response"
monkeypatch.setattr(nemo_gym.server_utils.get_global_httpx_client(), "get", httpx_client_mock)
httpx_client_mock = MagicMock()
httpx_client_get_post_mock = AsyncMock()
httpx_client_get_post_mock.return_value = "my mock response"
httpx_client_mock.return_value.get = httpx_client_get_post_mock
httpx_client_mock.return_value.post = httpx_client_get_post_mock
monkeypatch.setattr(nemo_gym.server_utils, "get_global_httpx_client", httpx_client_mock)

actual_response = await server_client.get(
server_name="my_server",
url_path="blah blah",
)
assert "my mock response" == actual_response

httpx_client_mock = AsyncMock()
httpx_client_mock.return_value = "my mock response"
monkeypatch.setattr(nemo_gym.server_utils.get_global_httpx_client(), "post", httpx_client_mock)

actual_response = await server_client.post(
server_name="my_server",
url_path="blah blah",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_train_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_load_datasets_missing_train_dataset_shouldnt_download_raises_AssertionE
{
"name": "train",
"type": "train",
"jsonl_fpath": "resources_servers/multineedle/data/train.jsonl",
"jsonl_fpath": "some/nonexiststent/path",
"gitlab_identifier": {
"dataset_name": "multineedle",
"version": "0.0.1",
Expand Down
Loading