Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions nemo_gym/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from typing import Literal, Optional, Tuple, Type, Union, Unpack
from uuid import uuid4

import ray
import requests
import uvicorn
import yappi
Expand Down Expand Up @@ -309,6 +310,22 @@ class UvicornLoggingConfig(BaseModel):
uvicorn_logging_show_200_ok: bool = False


def initialize_ray() -> None:
    """Initialize Ray for this process, honoring an optional configured head node.

    No-op when Ray is already initialized. When the global config contains a
    ``ray_head_node_address``, connect to that existing cluster; otherwise start
    a local one. ``ignore_reinit_error=True`` guards against racing initializers.
    """
    if ray.is_initialized():
        print("Ray already initialized")
        return

    address = get_global_config_dict().get("ray_head_node_address")
    if address is None:
        print("Starting Ray cluster...")
        ray.init(ignore_reinit_error=True)
    else:
        print(f"Connecting to Ray cluster at specified address: {address}")
        ray.init(address=address, ignore_reinit_error=True)


class SimpleServer(BaseServer):
server_client: ServerClient

Expand Down Expand Up @@ -434,6 +451,8 @@ def set_ulimit(self, target_soft_limit: int = 65535): # pragma: no cover
def run_webserver(cls) -> None: # pragma: no cover
global_config_dict = get_global_config_dict()

initialize_ray()

server_config = cls.load_config_from_global_config()
server_client = ServerClient(
head_server_config=ServerClient.load_head_server_config(),
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ dependencies = [
# Updated Mon Sep 22, 2025 with yappi==1.6.10
# License: MIT https://github.com/sumerc/yappi/blob/1d3f7501701e1f050b6dcd6a86fd36aec08185c7/LICENSE
"yappi",

# Ray: Used for distributed processing
# Updated Fri Oct 18, 2025 with ray[default]==2.46.0
# License: Apache 2.0 https://github.com/ray-project/ray/blob/master/LICENSE
"ray[default]",
]

[dependency-groups]
Expand Down
12 changes: 8 additions & 4 deletions resources_servers/comp_coding/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from time import time
from typing import Any, Dict, List, Optional, Union

from lcb_integration.compute_code_generation_metrics import check_correctness
import ray
from lcb_integration.compute_code_generation_metrics import check_correctness_remote
from lcb_integration.extraction_utils import LMStyle, extract_code
from pydantic import BaseModel

Expand Down Expand Up @@ -124,14 +125,17 @@ async def verify(self, body: CompCodingVerifyRequest) -> CompCodingVerifyRespons

# We can directly measure here since we are inside the semaphore.
start_time = time()
result, metadata = await loop.run_in_executor(
None,
check_correctness,

task_args = (
{"input_output": tests.model_dump_json()}, # sample
code, # generation
self.config.unit_test_timeout_secs, # timeout
self.config.debug, # debug
)

future = check_correctness_remote.remote(*task_args)
result, metadata = await loop.run_in_executor(None, ray.get, future)

unit_tests_time_taken = time() - start_time

return CompCodingVerifyResponse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,49 +16,61 @@
# borrowed and extended from
# https://github.com/Naman-ntc/codescratch/blob/main/evaluation/bigcode-evaluation-harness/lm_eval/tasks/custom_metrics/apps_custom_metrics/utils.py

import os
import sys


sys.set_int_max_str_digits(50000)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
import json
import multiprocessing
import os
import sys
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
import ray
from tqdm import tqdm

from lcb_integration.pass_k_utils import compute_metrics_from_results
from lcb_integration.testing_util import run_test


def _temp_run(sample, generation, debug, result, metadata_list, timeout):
res, metadata = run_test(sample, test=generation, debug=debug, timeout=timeout)
sys.set_int_max_str_digits(50000)
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def _temp_run(in_outs, generation, debug, result, metadata_list, timeout):
    """Child-process entry point: run the unit tests and publish the outcome.

    Appends the per-test results and metadata to the manager-backed shared
    lists so the parent process can read them after join/kill.
    """
    test_results, test_metadata = run_test(in_outs, test=generation, debug=debug, timeout=timeout)
    result.append(test_results)
    metadata_list.append(test_metadata)


# Using SPREAD scheduling so that Ray assigns tasks to as many distinct nodes as possible.
@ray.remote(scheduling_strategy="SPREAD")
def check_correctness_remote(sample, generation, timeout, debug=True):
    """Ray wrapper of check_correctness for remote execution.

    Args:
        sample: dict whose "input_output" key holds JSON-encoded test cases.
        generation: candidate code string to evaluate.
        timeout: per-test timeout in seconds, forwarded to check_correctness.
        debug: whether to enable verbose output in check_correctness.

    Returns:
        Whatever check_correctness returns (results plus metadata).
    """
    return check_correctness(sample, generation, timeout, debug)


def check_correctness(sample, generation, timeout, debug=True):
"""Check correctness of code generation with a global timeout.
The global timeout is to catch some extreme/rare cases not handled by the timeouts
inside `run_test`"""

# Parse JSON once at the beginning to avoid multiple parsing
try:
in_outs = json.loads(sample["input_output"])
except (ValueError, MemoryError):
return [-1], None

manager = multiprocessing.Manager()
result = manager.list()
metadata_list = manager.list()
p = multiprocessing.Process(
target=_temp_run,
args=(sample, generation, debug, result, metadata_list, timeout),
args=(in_outs, generation, debug, result, metadata_list, timeout),
)
p.start()
p.join(timeout=(timeout + 1) * len(json.loads(sample["input_output"])["inputs"]) + 5)
p.join(timeout=(timeout + 1) * len(in_outs["inputs"]) + 5)
if p.is_alive():
p.kill()
if not result:
in_outs = json.loads(sample["input_output"])
# consider that all tests failed
result = [[-1 for i in range(len(in_outs["inputs"]))]]
metadata_list = [None]
Expand Down
22 changes: 14 additions & 8 deletions resources_servers/comp_coding/lcb_integration/testing_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import signal
import sys
import time
import traceback

# used for debugging to time steps
from datetime import datetime
Expand Down Expand Up @@ -306,6 +307,7 @@ def grade_call_based(code: str, all_inputs: list, all_outputs: list, fn_name: st
"error_message": "Runtime Error",
"inputs": truncatefn(gt_inp),
"expected": truncatefn(gt_out),
"traceback": traceback.format_exc(),
}

finally:
Expand Down Expand Up @@ -369,6 +371,7 @@ def grade_stdio(
"error_message": "Runtime Error",
"inputs": truncatefn(gt_inp),
"expected": truncatefn(gt_out),
"traceback": traceback.format_exc(),
}

finally:
Expand Down Expand Up @@ -431,7 +434,7 @@ def grade_stdio(
return all_results, {"execution time": total_execution_time}


def run_test(sample, test=None, debug=False, timeout=6):
def run_test(in_outs, test=None, debug=False, timeout=6):
"""
if test(generated_code) is not None it'll try to run the code.
otherwise it'll just return an input and output pair.
Expand All @@ -440,16 +443,17 @@ def run_test(sample, test=None, debug=False, timeout=6):

# Disable functionalities that can make destructive changes to the test.
# max memory is set to 4GB
reliability_guard(4 * 1024**3)
reliability_guard(maximum_memory_bytes=4 * 1024**3)

if debug:
print(f"start = {datetime.now().time()}")

try:
in_outs = json.loads(sample["input_output"])
except ValueError as e:
raise e
in_outs = None
# The in_outs is already loaded from the sample from the parent process
# try:
# in_outs = json.loads(sample["input_output"])
# except ValueError as e:
# raise e
# in_outs = None

if in_outs:
if in_outs.get("fn_name") is None:
Expand Down Expand Up @@ -537,7 +541,9 @@ def reliability_guard(maximum_memory_bytes=None):
if maximum_memory_bytes is not None:
import resource

_set_resource_limit(resource.RLIMIT_AS, maximum_memory_bytes)
# The resource limit on RLIMIT_AS has been disabled because setting it caused additional out-of-memory (OOM) issues with Ray.
# This happens since Ray and its subprocesses share the same virtual address space.
# _set_resource_limit(resource.RLIMIT_AS, maximum_memory_bytes)
_set_resource_limit(resource.RLIMIT_DATA, maximum_memory_bytes)

if not platform.uname().system == "Darwin":
Expand Down
2 changes: 2 additions & 0 deletions resources_servers/comp_coding/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from unittest.mock import MagicMock

import pytest
import ray
from app import (
CompCodingResourcesServer,
CompCodingResourcesServerConfig,
Expand All @@ -32,6 +33,7 @@
class TestApp:
@pytest.fixture(scope="module")
def comp_coding_resources_server_client(self) -> Generator[TestClient, None, None]:
ray.init(num_cpus=1)
server = CompCodingResourcesServer(
config=CompCodingResourcesServerConfig(
host="0.0.0.0",
Expand Down
57 changes: 57 additions & 0 deletions tests/unit_tests/test_server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
DictConfig,
HeadServer,
ServerClient,
initialize_ray,
)


Expand Down Expand Up @@ -154,3 +155,59 @@ async def test_HeadServer_global_config_dict_yaml(self, monkeypatch: MonkeyPatch
resp = await head_server.global_config_dict_yaml()

assert "a: 2\n" == resp

def _mock_ray_return_value(self, monkeypatch: MonkeyPatch, return_value: bool) -> MagicMock:
ray_is_initialized_mock = MagicMock()
ray_is_initialized_mock.return_value = return_value
monkeypatch.setattr(nemo_gym.server_utils.ray, "is_initialized", ray_is_initialized_mock)
return ray_is_initialized_mock

def _mock_ray_init(self, monkeypatch: MonkeyPatch) -> MagicMock:
ray_init_mock = MagicMock()
monkeypatch.setattr(nemo_gym.server_utils.ray, "init", ray_init_mock)
return ray_init_mock

def test_initialize_ray_already_initialized(self, monkeypatch: MonkeyPatch) -> None:
ray_is_initialized_mock = self._mock_ray_return_value(monkeypatch, True)

get_global_config_dict_mock = MagicMock()
monkeypatch.setattr(nemo_gym.server_utils, "get_global_config_dict", get_global_config_dict_mock)

initialize_ray()

ray_is_initialized_mock.assert_called_once()
get_global_config_dict_mock.assert_not_called()

def test_initialize_ray_with_address(self, monkeypatch: MonkeyPatch) -> None:
ray_is_initialized_mock = self._mock_ray_return_value(monkeypatch, False)

ray_init_mock = self._mock_ray_init(monkeypatch)

# Mock global config dict with ray_head_node_address
global_config_dict = DictConfig({"ray_head_node_address": "ray://test-address:10001"})
get_global_config_dict_mock = MagicMock()
get_global_config_dict_mock.return_value = global_config_dict
monkeypatch.setattr(nemo_gym.server_utils, "get_global_config_dict", get_global_config_dict_mock)

initialize_ray()

ray_is_initialized_mock.assert_called_once()
get_global_config_dict_mock.assert_called_once()
ray_init_mock.assert_called_once_with(address="ray://test-address:10001", ignore_reinit_error=True)

def test_initialize_ray_without_address(self, monkeypatch: MonkeyPatch) -> None:
ray_is_initialized_mock = self._mock_ray_return_value(monkeypatch, False)

ray_init_mock = self._mock_ray_init(monkeypatch)

# Mock global config dict without ray_head_node_address
global_config_dict = DictConfig({"k": "v"})
get_global_config_dict_mock = MagicMock()
get_global_config_dict_mock.return_value = global_config_dict
monkeypatch.setattr(nemo_gym.server_utils, "get_global_config_dict", get_global_config_dict_mock)

initialize_ray()

ray_is_initialized_mock.assert_called_once()
get_global_config_dict_mock.assert_called_once()
ray_init_mock.assert_called_once_with(ignore_reinit_error=True)