Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 68 additions & 66 deletions README.md

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions resources_servers/gpqa_diamond/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,18 @@ ng_collect_rollouts \
+limit=3
```

`ng_collect_rollouts` also writes sidecar files next to `output_jsonl_fpath`, matching
the same pattern as `test_rollouts*`:
`ng_collect_rollouts` also writes sidecar files next to `output_jsonl_fpath`:

- `*_materialized_inputs.jsonl`
- `*_reward_profiling.jsonl`
- `*_agent_metrics.json`
- `*_aggregate_metrics.json`

`gpqa_diamond` additionally reports subject-area aggregate metrics based on
`metadata.subset_for_metrics`, for example:

- `subset/Organic Chemistry/pass@1/accuracy`
- `subset/Organic Chemistry/majority@1/accuracy`
- `subset/Organic Chemistry/num_tasks`

## Licensing

Expand Down
57 changes: 54 additions & 3 deletions resources_servers/gpqa_diamond/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# limitations under the License.

import re
from typing import Optional
from collections import defaultdict
from typing import Any, Optional

from nemo_gym.reward_profile import compute_pass_majority_metrics
from resources_servers.mcqa.app import (
MCQAResourcesServer,
MCQAVerifyRequest,
Expand All @@ -37,16 +39,63 @@ def extract_letter(text: str) -> Optional[str]:
if letter_match:
return letter_match[-1].strip()

answer_match = re.findall(r"(?i)[\*\_]{0,2}Answer[\*\_]{0,2}\s*:[\s\*\_]{0,2}\s*([A-Z])(?![a-zA-Z0-9])", text)
answer_match = re.findall(
r"(?i)[\*\_]{0,2}Answer[\*\_]{0,2}\s*:[\s\*\_]{0,2}\s*([A-Z])(?![a-zA-Z0-9])",
text,
)
if answer_match:
return answer_match[-1].strip().upper()

return None


def _get_subset_label(task_rollouts: list[dict[str, Any]]) -> str:
metadata = (
task_rollouts[0].get("metadata") if task_rollouts else None
) or {}
subset = metadata.get("subset_for_metrics")
if isinstance(subset, str) and subset.strip():
return " ".join(subset.split()).replace("/", "_")
return "Unknown"


class GPQADiamondResourcesServer(MCQAResourcesServer):
"""GPQA-Diamond verifier with GPQA-specific answer extraction."""

def compute_metrics(self, tasks):
    """Extend the base MCQA metrics with per-subject subset breakdowns.

    Tasks are grouped by the label from ``_get_subset_label``; each group
    contributes ``subset/<label>/num_tasks`` plus the pass/majority
    accuracy metrics computed over just that group.
    """
    metrics = super().compute_metrics(tasks)

    grouped: dict[str, list[list[dict[str, Any]]]] = defaultdict(list)
    for rollouts in tasks:
        grouped[_get_subset_label(rollouts)].append(rollouts)

    # Sorted iteration keeps metric-key ordering deterministic across runs.
    for label in sorted(grouped):
        group = grouped[label]
        per_group = compute_pass_majority_metrics(
            group,
            score_fn=lambda r: {"accuracy": r["reward"]},
            answer_key="extracted_answer",
        )
        prefix = f"subset/{label}"
        metrics[f"{prefix}/num_tasks"] = len(group)
        for name, value in per_group.items():
            metrics[f"{prefix}/{name}"] = value

    return metrics

def get_key_metrics(self, agent_metrics):
    """Surface per-subset accuracy metrics alongside the base key metrics."""
    selected = super().get_key_metrics(agent_metrics)
    # str.endswith accepts a tuple of suffixes: one call covers both metrics.
    accuracy_suffixes = ("/pass@1/accuracy", "/majority@1/accuracy")
    for name in sorted(agent_metrics):
        if name.startswith("subset/") and name.endswith(accuracy_suffixes):
            selected[name] = agent_metrics[name]
    return selected

async def verify(self, body: MCQAVerifyRequest) -> MCQAVerifyResponse:
text = body.response.output_text.strip()
options, expected_answer = _extract_options_and_expected(body)
Expand All @@ -56,7 +105,9 @@ async def verify(self, body: MCQAVerifyRequest) -> MCQAVerifyResponse:

if body.template_metadata and "output_regex" in body.template_metadata:
regex_pattern = body.template_metadata["output_regex"]
pred = _parse_answer_with_custom_regex(text, regex_pattern, allowed_letters, options)
pred = _parse_answer_with_custom_regex(
text, regex_pattern, allowed_letters, options
)

if pred is None:
pred = extract_letter(text)
Expand Down
197 changes: 184 additions & 13 deletions resources_servers/gpqa_diamond/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,32 @@
# limitations under the License.
from unittest.mock import MagicMock

from app import GPQADiamondResourcesServer
import pytest

from nemo_gym.openai_utils import NeMoGymResponse
from nemo_gym.server_utils import ServerClient
from resources_servers.mcqa.app import MCQAResourcesServerConfig, MCQAVerifyRequest
from resources_servers.gpqa_diamond.app import GPQADiamondResourcesServer
from resources_servers.mcqa.app import (
MCQAResourcesServerConfig,
MCQAVerifyRequest,
)


class TestApp:
def test_sanity(self) -> None:
    """Smoke test: the server constructs from a minimal config."""
    GPQADiamondResourcesServer(
        config=MCQAResourcesServerConfig(
            host="0.0.0.0", port=8080, entrypoint="", name=""
        ),
        server_client=MagicMock(spec=ServerClient),
    )

async def test_verify_gpqa_diamond_template_metadata_priority(self) -> None:
server = GPQADiamondResourcesServer(
config=MCQAResourcesServerConfig(host="0.0.0.0", port=8080, entrypoint="", name=""),
config=MCQAResourcesServerConfig(
host="0.0.0.0", port=8080, entrypoint="", name=""
),
server_client=MagicMock(spec=ServerClient),
)

Expand Down Expand Up @@ -59,7 +70,12 @@ async def test_verify_gpqa_diamond_template_metadata_priority(self) -> None:

verify_request = MCQAVerifyRequest(
responses_create_params={
"input": [{"role": "user", "content": "Question?\nA: optA\nB: optB\nC: optC\nD: optD"}]
"input": [
{
"role": "user",
"content": "Question?\nA: optA\nB: optB\nC: optC\nD: optD",
}
]
},
response=regex_response,
options=[{"A": "optA"}, {"B": "optB"}, {"C": "optC"}, {"D": "optD"}],
Expand All @@ -74,7 +90,9 @@ async def test_verify_gpqa_diamond_template_metadata_priority(self) -> None:

async def test_verify_gpqa_diamond_format(self) -> None:
server = GPQADiamondResourcesServer(
config=MCQAResourcesServerConfig(host="0.0.0.0", port=8080, entrypoint="", name=""),
config=MCQAResourcesServerConfig(
host="0.0.0.0", port=8080, entrypoint="", name=""
),
server_client=MagicMock(spec=ServerClient),
)

Expand All @@ -86,7 +104,13 @@ async def test_verify_gpqa_diamond_format(self) -> None:
output=[
{
"id": "msg_answer",
"content": [{"annotations": [], "text": "Reasoning...\nAnswer: C", "type": "output_text"}],
"content": [
{
"annotations": [],
"text": "Reasoning...\nAnswer: C",
"type": "output_text",
}
],
"role": "assistant",
"status": "completed",
"type": "message",
Expand All @@ -103,7 +127,8 @@ async def test_verify_gpqa_diamond_format(self) -> None:
{
"role": "user",
"content": (
"The last line should be of the format 'Answer: LETTER'. "
"The last line should be of the format "
"'Answer: LETTER'. "
"Question?\nA: optA\nB: optB\nC: optC\nD: optD"
),
}
Expand All @@ -126,7 +151,13 @@ async def test_verify_gpqa_diamond_format(self) -> None:
output=[
{
"id": "msg_boxed",
"content": [{"annotations": [], "text": "Final: \\boxed{C}", "type": "output_text"}],
"content": [
{
"annotations": [],
"text": "Final: \\boxed{C}",
"type": "output_text",
}
],
"role": "assistant",
"status": "completed",
"type": "message",
Expand All @@ -137,7 +168,9 @@ async def test_verify_gpqa_diamond_format(self) -> None:
tools=[],
)
verify_request_boxed = MCQAVerifyRequest(
responses_create_params=verify_request.responses_create_params.model_dump(exclude_none=True),
responses_create_params=verify_request.responses_create_params.model_dump(
exclude_none=True
),
response=boxed_response,
options=verify_request.options,
expected_answer=verify_request.expected_answer,
Expand All @@ -148,9 +181,72 @@ async def test_verify_gpqa_diamond_format(self) -> None:
assert result_boxed.reward == 1.0
assert result_boxed.extracted_answer == "C"

async def test_verify_preserves_subset_metadata_for_aggregation(self) -> None:
    """A verified rollout keeps its subset metadata, which then feeds the
    per-subset aggregate metrics in ``compute_metrics``."""
    server = GPQADiamondResourcesServer(
        config=MCQAResourcesServerConfig(
            host="0.0.0.0", port=8080, entrypoint="", name=""
        ),
        server_client=MagicMock(spec=ServerClient),
    )

    # Single assistant message answering "C" in the expected format.
    assistant_message = {
        "id": "msg_subset",
        "content": [
            {"annotations": [], "text": "Answer: C", "type": "output_text"}
        ],
        "role": "assistant",
        "status": "completed",
        "type": "message",
    }
    model_response = NeMoGymResponse(
        id="resp_subset",
        created_at=0.0,
        model="dummy",
        object="response",
        output=[assistant_message],
        parallel_tool_calls=True,
        tool_choice="auto",
        tools=[],
    )

    request = MCQAVerifyRequest(
        responses_create_params={
            "input": [
                {
                    "role": "user",
                    "content": "Question?\nA: optA\nB: optB\nC: optC\nD: optD",
                }
            ]
        },
        response=model_response,
        options=[{"A": "optA"}, {"B": "optB"}, {"C": "optC"}, {"D": "optD"}],
        expected_answer="C",
        grading_mode="strict_single_letter_boxed",
        metadata={"subset_for_metrics": "Organic Chemistry"},
        template_metadata={"output_regex": r"Answer:\s*([A-Za-z])"},
    )

    result = await server.verify(request)

    assert result.reward == 1.0
    assert result.metadata == {"subset_for_metrics": "Organic Chemistry"}

    # The preserved metadata drives the subset aggregation downstream.
    metrics = server.compute_metrics([[result.model_dump()]])
    assert metrics["subset/Organic Chemistry/num_tasks"] == 1
    assert metrics["subset/Organic Chemistry/pass@1/accuracy"] == pytest.approx(
        100.0
    )

async def test_verify_gpqa_diamond_rejects_invalid_letter(self) -> None:
server = GPQADiamondResourcesServer(
config=MCQAResourcesServerConfig(host="0.0.0.0", port=8080, entrypoint="", name=""),
config=MCQAResourcesServerConfig(
host="0.0.0.0", port=8080, entrypoint="", name=""
),
server_client=MagicMock(spec=ServerClient),
)

Expand All @@ -175,7 +271,12 @@ async def test_verify_gpqa_diamond_rejects_invalid_letter(self) -> None:

verify_request = MCQAVerifyRequest(
responses_create_params={
"input": [{"role": "user", "content": "Question?\nA: optA\nB: optB\nC: optC\nD: optD"}]
"input": [
{
"role": "user",
"content": "Question?\nA: optA\nB: optB\nC: optC\nD: optD",
}
]
},
response=invalid_response,
options=[{"A": "optA"}, {"B": "optB"}, {"C": "optC"}, {"D": "optD"}],
Expand All @@ -186,3 +287,73 @@ async def test_verify_gpqa_diamond_rejects_invalid_letter(self) -> None:

assert result.reward == 0.0
assert result.extracted_answer is None

def test_compute_metrics_breaks_down_by_subject(self) -> None:
    """Per-subject subset metrics are reported next to the overall ones."""
    server = GPQADiamondResourcesServer(
        config=MCQAResourcesServerConfig(
            host="0.0.0.0", port=8080, entrypoint="", name=""
        ),
        server_client=MagicMock(spec=ServerClient),
    )

    def rollout(reward, answer, subject):
        # Minimal rollout dict shape consumed by compute_metrics.
        return {
            "reward": reward,
            "extracted_answer": answer,
            "metadata": {"subset_for_metrics": subject},
        }

    tasks = [
        [rollout(1.0, "A", "Organic Chemistry")],
        [rollout(0.0, "B", "Organic Chemistry")],
        [rollout(1.0, "C", "Quantum Mechanics")],
    ]

    metrics = server.compute_metrics(tasks)

    assert metrics["subset/Organic Chemistry/num_tasks"] == 2
    assert metrics["subset/Quantum Mechanics/num_tasks"] == 1
    assert metrics["subset/Organic Chemistry/pass@1/accuracy"] == pytest.approx(
        50.0
    )
    assert metrics["subset/Quantum Mechanics/pass@1/accuracy"] == pytest.approx(
        100.0
    )
    # The overall metric still reflects all three tasks.
    assert metrics["pass@1/accuracy"] == pytest.approx((2 / 3) * 100)

    key_metrics = server.get_key_metrics(metrics)
    assert key_metrics["subset/Organic Chemistry/pass@1/accuracy"] == pytest.approx(
        50.0
    )
    assert key_metrics["subset/Quantum Mechanics/pass@1/accuracy"] == pytest.approx(
        100.0
    )

def test_compute_metrics_uses_unknown_subject_fallback(self) -> None:
    """Rollouts without usable subset metadata aggregate under "Unknown"."""
    server = GPQADiamondResourcesServer(
        config=MCQAResourcesServerConfig(
            host="0.0.0.0", port=8080, entrypoint="", name=""
        ),
        server_client=MagicMock(spec=ServerClient),
    )

    # One task with empty metadata, one with no metadata key at all.
    tasks = [
        [{"reward": 1.0, "extracted_answer": "A", "metadata": {}}],
        [{"reward": 0.0, "extracted_answer": None}],
    ]
    metrics = server.compute_metrics(tasks)

    assert metrics["subset/Unknown/num_tasks"] == 2
    assert metrics["subset/Unknown/pass@1/accuracy"] == pytest.approx(50.0)
Loading