diff --git a/nemo_gym/server_utils.py b/nemo_gym/server_utils.py
index 7f9e2ac97..83285e6af 100644
--- a/nemo_gym/server_utils.py
+++ b/nemo_gym/server_utils.py
@@ -27,6 +27,7 @@
 from typing import Literal, Optional, Tuple, Type, Union, Unpack
 from uuid import uuid4
 
+import ray
 import requests
 import uvicorn
 import yappi
@@ -309,6 +310,22 @@ class UvicornLoggingConfig(BaseModel):
     uvicorn_logging_show_200_ok: bool = False
 
 
+def initialize_ray() -> None:
+    if ray.is_initialized():
+        print("Ray already initialized")
+        return
+
+    global_config_dict = get_global_config_dict()
+    ray_head_node_address = global_config_dict.get("ray_head_node_address")
+
+    if ray_head_node_address is not None:
+        print(f"Connecting to Ray cluster at specified address: {ray_head_node_address}")
+        ray.init(address=ray_head_node_address, ignore_reinit_error=True)
+    else:
+        print("Starting Ray cluster...")
+        ray.init(ignore_reinit_error=True)
+
+
 class SimpleServer(BaseServer):
     server_client: ServerClient
 
@@ -434,6 +451,8 @@ def set_ulimit(self, target_soft_limit: int = 65535):  # pragma: no cover
     def run_webserver(cls) -> None:  # pragma: no cover
         global_config_dict = get_global_config_dict()
 
+        initialize_ray()
+
         server_config = cls.load_config_from_global_config()
         server_client = ServerClient(
             head_server_config=ServerClient.load_head_server_config(),
diff --git a/pyproject.toml b/pyproject.toml
index 7e5a9a669..47ee42fca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -136,6 +136,11 @@ dependencies = [
     # Updated Mon Sep 22, 2025 with yappi==1.6.10
     # License: MIT https://github.com/sumerc/yappi/blob/1d3f7501701e1f050b6dcd6a86fd36aec08185c7/LICENSE
     "yappi",
+
+    # Ray: Used for distributed processing
+    # Updated Fri Oct 18, 2025 with ray[default]==2.46.0
+    # License: Apache 2.0 https://github.com/ray-project/ray/blob/master/LICENSE
+    "ray[default]",
 ]
 
 [dependency-groups]
diff --git a/resources_servers/comp_coding/app.py b/resources_servers/comp_coding/app.py
index 81010ab30..c8898fcec 100644
--- a/resources_servers/comp_coding/app.py
+++ b/resources_servers/comp_coding/app.py
@@ -16,7 +16,8 @@
 from time import time
 from typing import Any, Dict, List, Optional, Union
 
-from lcb_integration.compute_code_generation_metrics import check_correctness
+import ray
+from lcb_integration.compute_code_generation_metrics import check_correctness_remote
 from lcb_integration.extraction_utils import LMStyle, extract_code
 from pydantic import BaseModel
 
@@ -124,14 +125,17 @@ async def verify(self, body: CompCodingVerifyRequest) -> CompCodingVerifyRespons
 
             # We can directly measure here since we are inside the semaphore.
             start_time = time()
-            result, metadata = await loop.run_in_executor(
-                None,
-                check_correctness,
+
+            task_args = (
                 {"input_output": tests.model_dump_json()},  # sample
                 code,  # generation
                 self.config.unit_test_timeout_secs,  # timeout
                 self.config.debug,  # debug
             )
+
+            future = check_correctness_remote.remote(*task_args)
+            result, metadata = await loop.run_in_executor(None, ray.get, future)
+
             unit_tests_time_taken = time() - start_time
 
             return CompCodingVerifyResponse(
diff --git a/resources_servers/comp_coding/lcb_integration/compute_code_generation_metrics.py b/resources_servers/comp_coding/lcb_integration/compute_code_generation_metrics.py
index 1b80df7f4..d3bf13ecb 100644
--- a/resources_servers/comp_coding/lcb_integration/compute_code_generation_metrics.py
+++ b/resources_servers/comp_coding/lcb_integration/compute_code_generation_metrics.py
@@ -16,49 +16,61 @@
 # borrowed and extended from
 # https://github.com/Naman-ntc/codescratch/blob/main/evaluation/bigcode-evaluation-harness/lm_eval/tasks/custom_metrics/apps_custom_metrics/utils.py
 
-import os
-import sys
-
-
-sys.set_int_max_str_digits(50000)
-
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import json
 import multiprocessing
+import os
+import sys
 from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor, as_completed
 
 import numpy as np
+import ray
 from tqdm import tqdm
 
 from lcb_integration.pass_k_utils import compute_metrics_from_results
 from lcb_integration.testing_util import run_test
 
 
-def _temp_run(sample, generation, debug, result, metadata_list, timeout):
-    res, metadata = run_test(sample, test=generation, debug=debug, timeout=timeout)
+sys.set_int_max_str_digits(50000)
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+def _temp_run(in_outs, generation, debug, result, metadata_list, timeout):
+    res, metadata = run_test(in_outs, test=generation, debug=debug, timeout=timeout)
     result.append(res)
     metadata_list.append(metadata)
 
 
+# Using SPREAD scheduling so that Ray assigns tasks to as many distinct nodes as possible.
+@ray.remote(scheduling_strategy="SPREAD")
+def check_correctness_remote(sample, generation, timeout, debug=True):
+    """Ray wrapper of check_correctness for remote execution."""
+    return check_correctness(sample, generation, timeout, debug)
+
+
 def check_correctness(sample, generation, timeout, debug=True):
     """Check correctness of code generation with a global timeout.
     The global timeout is to catch some extreme/rare cases not handled by the timeouts inside `run_test`"""
+    # Parse the JSON once up front so it is not re-parsed multiple times below
+    try:
+        in_outs = json.loads(sample["input_output"])
+    except (ValueError, MemoryError):
+        return [-1], None
+
     manager = multiprocessing.Manager()
     result = manager.list()
     metadata_list = manager.list()
     p = multiprocessing.Process(
         target=_temp_run,
-        args=(sample, generation, debug, result, metadata_list, timeout),
+        args=(in_outs, generation, debug, result, metadata_list, timeout),
     )
     p.start()
-    p.join(timeout=(timeout + 1) * len(json.loads(sample["input_output"])["inputs"]) + 5)
+    p.join(timeout=(timeout + 1) * len(in_outs["inputs"]) + 5)
     if p.is_alive():
         p.kill()
     if not result:
-        in_outs = json.loads(sample["input_output"])
         # consider that all tests failed
         result = [[-1 for i in range(len(in_outs["inputs"]))]]
         metadata_list = [None]
diff --git a/resources_servers/comp_coding/lcb_integration/testing_util.py b/resources_servers/comp_coding/lcb_integration/testing_util.py
index 522d7a3d2..0bb73bac9 100644
--- a/resources_servers/comp_coding/lcb_integration/testing_util.py
+++ b/resources_servers/comp_coding/lcb_integration/testing_util.py
@@ -21,6 +21,7 @@
 import signal
 import sys
 import time
+import traceback
 
 # used for debugging to time steps
 from datetime import datetime
@@ -306,6 +307,7 @@ def grade_call_based(code: str, all_inputs: list, all_outputs: list, fn_name: st
                     "error_message": "Runtime Error",
                     "inputs": truncatefn(gt_inp),
                     "expected": truncatefn(gt_out),
+                    "traceback": traceback.format_exc(),
                 }
 
             finally:
@@ -369,6 +371,7 @@ def grade_stdio(
                 "error_message": "Runtime Error",
                 "inputs": truncatefn(gt_inp),
                 "expected": truncatefn(gt_out),
+                "traceback": traceback.format_exc(),
             }
 
         finally:
@@ -431,7 +434,7 @@ def grade_stdio(
     return all_results, {"execution time": total_execution_time}
 
 
-def run_test(sample, test=None, debug=False, timeout=6):
+def run_test(in_outs, test=None, debug=False, timeout=6):
     """
     if test(generated_code) is not None it'll try to run the code.
     otherwise it'll just return an input and output pair.
@@ -440,16 +443,17 @@ def run_test(sample, test=None, debug=False, timeout=6):
 
     # Disable functionalities that can make destructive changes to the test.
    # max memory is set to 4GB
-    reliability_guard(4 * 1024**3)
+    reliability_guard(maximum_memory_bytes=4 * 1024**3)
 
     if debug:
         print(f"start = {datetime.now().time()}")
 
-    try:
-        in_outs = json.loads(sample["input_output"])
-    except ValueError as e:
-        raise e
-        in_outs = None
+    # in_outs has already been parsed from the sample by the parent process
+    # try:
+    #     in_outs = json.loads(sample["input_output"])
+    # except ValueError as e:
+    #     raise e
+    #     in_outs = None
 
     if in_outs:
         if in_outs.get("fn_name") is None:
@@ -537,7 +541,9 @@ def reliability_guard(maximum_memory_bytes=None):
 
     if maximum_memory_bytes is not None:
         import resource
 
-        _set_resource_limit(resource.RLIMIT_AS, maximum_memory_bytes)
+        # The resource limit on RLIMIT_AS has been disabled because setting it caused additional out-of-memory (OOM) issues with Ray.
+        # This happens since Ray and its subprocesses share the same virtual address space.
+        # _set_resource_limit(resource.RLIMIT_AS, maximum_memory_bytes)
         _set_resource_limit(resource.RLIMIT_DATA, maximum_memory_bytes)
         if not platform.uname().system == "Darwin":
diff --git a/resources_servers/comp_coding/tests/test_app.py b/resources_servers/comp_coding/tests/test_app.py
index d5e021312..2579e56d5 100644
--- a/resources_servers/comp_coding/tests/test_app.py
+++ b/resources_servers/comp_coding/tests/test_app.py
@@ -16,6 +16,7 @@
 from unittest.mock import MagicMock
 
 import pytest
+import ray
 from app import (
     CompCodingResourcesServer,
     CompCodingResourcesServerConfig,
@@ -32,6 +33,7 @@ class TestApp:
 
     @pytest.fixture(scope="module")
     def comp_coding_resources_server_client(self) -> Generator[TestClient, None, None]:
+        ray.init(num_cpus=1)
         server = CompCodingResourcesServer(
             config=CompCodingResourcesServerConfig(
                 host="0.0.0.0",
diff --git a/tests/unit_tests/test_server_utils.py b/tests/unit_tests/test_server_utils.py
index 6a64c32da..5011cfa24 100644
--- a/tests/unit_tests/test_server_utils.py
+++ b/tests/unit_tests/test_server_utils.py
@@ -27,6 +27,7 @@
     DictConfig,
     HeadServer,
     ServerClient,
+    initialize_ray,
 )
 
 
@@ -154,3 +155,59 @@ async def test_HeadServer_global_config_dict_yaml(self, monkeypatch: MonkeyPatch
         resp = await head_server.global_config_dict_yaml()
 
         assert "a: 2\n" == resp
+
+    def _mock_ray_return_value(self, monkeypatch: MonkeyPatch, return_value: bool) -> MagicMock:
+        ray_is_initialized_mock = MagicMock()
+        ray_is_initialized_mock.return_value = return_value
+        monkeypatch.setattr(nemo_gym.server_utils.ray, "is_initialized", ray_is_initialized_mock)
+        return ray_is_initialized_mock
+
+    def _mock_ray_init(self, monkeypatch: MonkeyPatch) -> MagicMock:
+        ray_init_mock = MagicMock()
+        monkeypatch.setattr(nemo_gym.server_utils.ray, "init", ray_init_mock)
+        return ray_init_mock
+
+    def test_initialize_ray_already_initialized(self, monkeypatch: MonkeyPatch) -> None:
+        ray_is_initialized_mock = self._mock_ray_return_value(monkeypatch, True)
+
+        get_global_config_dict_mock = MagicMock()
+        monkeypatch.setattr(nemo_gym.server_utils, "get_global_config_dict", get_global_config_dict_mock)
+
+        initialize_ray()
+
+        ray_is_initialized_mock.assert_called_once()
+        get_global_config_dict_mock.assert_not_called()
+
+    def test_initialize_ray_with_address(self, monkeypatch: MonkeyPatch) -> None:
+        ray_is_initialized_mock = self._mock_ray_return_value(monkeypatch, False)
+
+        ray_init_mock = self._mock_ray_init(monkeypatch)
+
+        # Mock global config dict with ray_head_node_address
+        global_config_dict = DictConfig({"ray_head_node_address": "ray://test-address:10001"})
+        get_global_config_dict_mock = MagicMock()
+        get_global_config_dict_mock.return_value = global_config_dict
+        monkeypatch.setattr(nemo_gym.server_utils, "get_global_config_dict", get_global_config_dict_mock)
+
+        initialize_ray()
+
+        ray_is_initialized_mock.assert_called_once()
+        get_global_config_dict_mock.assert_called_once()
+        ray_init_mock.assert_called_once_with(address="ray://test-address:10001", ignore_reinit_error=True)
+
+    def test_initialize_ray_without_address(self, monkeypatch: MonkeyPatch) -> None:
+        ray_is_initialized_mock = self._mock_ray_return_value(monkeypatch, False)
+
+        ray_init_mock = self._mock_ray_init(monkeypatch)
+
+        # Mock global config dict without ray_head_node_address
+        global_config_dict = DictConfig({"k": "v"})
+        get_global_config_dict_mock = MagicMock()
+        get_global_config_dict_mock.return_value = global_config_dict
+        monkeypatch.setattr(nemo_gym.server_utils, "get_global_config_dict", get_global_config_dict_mock)
+
+        initialize_ray()
+
+        ray_is_initialized_mock.assert_called_once()
+        get_global_config_dict_mock.assert_called_once()
+        ray_init_mock.assert_called_once_with(ignore_reinit_error=True)
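
For reviewers who want to see the app.py dispatch pattern in isolation, the following is a minimal, self-contained sketch; it is illustrative only and not part of the diff, and `slow_check` is a hypothetical stand-in for check_correctness.

import asyncio
import time

import ray


# SPREAD scheduling asks Ray to place tasks on as many distinct nodes as possible,
# mirroring the decorator on check_correctness_remote above.
@ray.remote(scheduling_strategy="SPREAD")
def slow_check(payload: str) -> str:
    time.sleep(1)  # stand-in for running the unit tests
    return f"checked: {payload}"


async def main() -> None:
    ray.init(ignore_reinit_error=True)  # no-op if a cluster is already attached
    loop = asyncio.get_running_loop()
    future = slow_check.remote("sample")
    # ray.get blocks, so it is handed to the default thread-pool executor to keep
    # the event loop free; this is the same pattern verify() uses above.
    result = await loop.run_in_executor(None, ray.get, future)
    print(result)


if __name__ == "__main__":
    asyncio.run(main())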