Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions nemo_gym/dataset_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import List
from typing import Any, Dict, List

from gradio import Blocks, Chatbot, ChatMessage, Dropdown
from gradio import JSON, Blocks, Chatbot, ChatMessage, Dropdown
from gradio.components.chatbot import MetadataDict
from openai.types.responses.response_input_param import (
EasyInputMessageParam,
Expand All @@ -28,6 +28,11 @@

from nemo_gym.base_resources_server import BaseVerifyResponse
from nemo_gym.server_utils import get_global_config_dict
from nemo_gym.train_data_utils import (
AvgMinMax,
DatasetMetrics,
compute_sample_metrics,
)


class DatasetViewerVerifyResponse(BaseVerifyResponse):
Expand Down Expand Up @@ -131,7 +136,6 @@ def convert_single_message(m: ResponseInputItemParam) -> List[ChatMessage]:

def rollout_to_messages(create_params: dict, response: dict) -> List[ChatMessage]:
messages = []

sampling_params = create_params.copy()
sampling_params.pop("input")
sampling_params.pop("tools", None)
Expand Down Expand Up @@ -202,14 +206,59 @@ class JsonlDatasetViewerConfig(BaseModel):
jsonl_fpath: str


def aggregate_other_metrics(data: List[DatasetViewerVerifyResponse]) -> Dict[str, Any]:
    """Summarize per-sample scalar fields across all loaded samples.

    Numeric fields (bools counted as 0/1) are reduced to an ``AvgMinMax``
    summary; string fields are reduced to unique/total counts. The bulky
    ``responses_create_params`` and ``response`` payloads are skipped.
    """
    numeric_by_key: Dict[str, list] = {}
    strings_by_key: Dict[str, list] = {}

    for sample in data:
        record = sample.model_dump() if hasattr(sample, "model_dump") else sample
        for key, value in record.items():
            if key in ("responses_create_params", "response"):
                continue
            if isinstance(value, bool):
                # Fold booleans into 0/1 so they aggregate like numbers.
                value = int(value)
            if isinstance(value, (int, float)):
                numeric_by_key.setdefault(key, []).append(value)
            elif isinstance(value, str):
                # Strings are kept so we can report unique-value counts.
                strings_by_key.setdefault(key, []).append(value)

    result: Dict[str, Any] = {}

    for key, values in numeric_by_key.items():
        if not values:
            continue
        stats = AvgMinMax(
            total=len(values),
            average=sum(values) / len(values),
            min=min(values),
            max=max(values),
        )
        result[key] = stats.model_dump(by_alias=True)

    for key, values in strings_by_key.items():
        result[key] = {"unique_count": len(set(values)), "total_count": len(values)}

    return result


def get_aggregate_metrics(data: List[DatasetViewerVerifyResponse], raw_lines: List[str]) -> Dict[str, Any]:
    """Build the aggregate-metrics dict shown at the top of the viewer.

    Dataset-level metrics are computed from the raw JSONL lines (skipping
    lines flagged as offending), then merged with the per-sample scalar
    aggregates derived from the parsed responses.
    """
    accumulated = DatasetMetrics()
    for raw_line in raw_lines:
        sample_metrics, is_offending = compute_sample_metrics(raw_line)
        if is_offending:
            continue
        accumulated.add(sample_metrics)

    merged = accumulated.aggregate().model_dump(by_alias=True)
    merged.update(**aggregate_other_metrics(data))
    return merged


def build_jsonl_dataset_viewer(config: JsonlDatasetViewerConfig) -> Blocks:
data = []
raw_lines = []
with open(config.jsonl_fpath) as f:
data = list(
tqdm(
map(DatasetViewerVerifyResponse.model_validate_json, f),
desc="Loading data",
)
)
for line in tqdm(f, desc="Loading data"):
raw_lines.append(line)
data.append(DatasetViewerVerifyResponse.model_validate_json(line))

choices = [(f"Sample {i + 1} - Responses ID {d.response.id}", i) for i, d in enumerate(data)]

Expand All @@ -225,6 +274,9 @@ def select_item(value: int):
}
"""
with Blocks(analytics_enabled=False, css=CSS) as demo:
aggregate_dicts = get_aggregate_metrics(data, raw_lines)
JSON(value=aggregate_dicts, label="Aggregate Metrics", open=False)

item_dropdown = Dropdown(choices=choices, value=0, label="Samples")
chatbot = Chatbot(
value=select_item(0),
Expand Down
133 changes: 70 additions & 63 deletions nemo_gym/train_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from itertools import count, repeat
from pathlib import Path
from shutil import copyfileobj
from typing import Dict, List, Literal, Optional, Self, Union
from typing import Dict, List, Literal, Optional, Self, Tuple, Union

from devtools import pprint
from omegaconf import DictConfig
Expand Down Expand Up @@ -128,6 +128,72 @@ def _aggregate(self: Self) -> Self:
)


def _single_value_avg_min_max(value: float) -> AvgMinMax:
    """Build an AvgMinMax summary representing a single observed value."""
    return AvgMinMax(total=1, average=value, min=value, max=value)


def compute_sample_metrics(sample_dict_str: str) -> Tuple[DatasetMetrics, bool]:
    """Parse one raw JSONL line and compute its per-sample dataset metrics.

    Args:
        sample_dict_str: A single raw JSONL line.

    Returns:
        A ``(metrics, is_offending)`` tuple. ``is_offending`` is True when
        the line is not valid JSON or fails ``BaseRunRequest`` validation;
        in that case the returned metrics are empty and should be excluded
        from aggregation.
    """
    try:
        sample_dict = json.loads(sample_dict_str)
    except json.JSONDecodeError:
        return DatasetMetrics(), True

    try:
        sample = BaseRunRequest.model_validate(sample_dict)
    except ValidationError:
        return DatasetMetrics(), True

    responses_create_params = sample.responses_create_params.model_dump(exclude_unset=True)
    inputs = responses_create_params.get("input")

    number_of_tools_metrics = AvgMinMax()
    if responses_create_params.get("tools") is not None:
        number_of_tools_metrics = _single_value_avg_min_max(len(responses_create_params["tools"]))

    # A bare string prompt counts as a single user turn.
    if isinstance(inputs, str):
        inputs = [{"role": "user", "content": inputs}]
    user_inputs = [i for i in inputs if i.get("role") == "user"] if inputs else []
    number_of_turns_metrics = AvgMinMax()
    if user_inputs:
        number_of_turns_metrics = _single_value_avg_min_max(len(user_inputs))

    temperature_metrics = AvgMinMax()
    if responses_create_params.get("temperature") is not None:
        temperature_metrics = _single_value_avg_min_max(responses_create_params["temperature"])

    # Rough proxy for prompt size: whitespace-separated tokens of the JSON dump.
    json_dumped_number_of_words_metrics = _single_value_avg_min_max(
        len(json.dumps(responses_create_params).split())
    )

    metrics = DatasetMetrics(
        number_of_examples=1,
        number_of_tools=number_of_tools_metrics,
        json_dumped_number_of_words=json_dumped_number_of_words_metrics,
        number_of_turns=number_of_turns_metrics,
        temperature=temperature_metrics,
    )
    return metrics, False


class DatasetValidatorState(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -283,72 +349,13 @@ def load_datasets(
def _validate_samples_and_aggregate_metrics_single_sample(
self, state: DatasetValidatorState, sample_idx: int, sample_dict_str: str
) -> None:
try:
sample_dict = json.loads(sample_dict_str)
except json.JSONDecodeError:
state.offending_example_idxs.append(sample_idx)
return

try:
sample = BaseRunRequest.model_validate(sample_dict)
except ValidationError:
metrics, is_offending = compute_sample_metrics(sample_dict_str)
if is_offending:
state.offending_example_idxs.append(sample_idx)
return

sample_dict = json.loads(sample_dict_str)
state.key_counts.update(sample_dict.keys())

responses_create_params = sample.responses_create_params
responses_create_params = responses_create_params.model_dump(exclude_unset=True)
inputs = responses_create_params["input"]

number_of_tools_metrics = AvgMinMax()
if responses_create_params.get("tools") is not None:
number_of_tools = len(responses_create_params["tools"])
number_of_tools_metrics = AvgMinMax(
total=1,
average=number_of_tools,
min=number_of_tools,
max=number_of_tools,
)

if isinstance(inputs, str):
inputs = [{"role": "user", "content": inputs}]
user_inputs = [i for i in inputs if i.get("role") == "user"]
number_of_turns_metrics = AvgMinMax()
if user_inputs:
number_of_turns = len(user_inputs)
number_of_turns_metrics = AvgMinMax(
total=1,
average=number_of_turns,
min=number_of_turns,
max=number_of_turns,
)

temperature_metrics = AvgMinMax()
if responses_create_params.get("temperature") is not None:
temperature = responses_create_params["temperature"]
temperature_metrics = AvgMinMax(
total=1,
average=temperature,
min=temperature,
max=temperature,
)

json_dumped_number_of_words = len(json.dumps(responses_create_params).split())
json_dumped_number_of_words_metrics = AvgMinMax(
total=1,
average=json_dumped_number_of_words,
min=json_dumped_number_of_words,
max=json_dumped_number_of_words,
)

metrics = DatasetMetrics(
number_of_examples=1,
number_of_tools=number_of_tools_metrics,
json_dumped_number_of_words=json_dumped_number_of_words_metrics,
number_of_turns=number_of_turns_metrics,
temperature=temperature_metrics,
)
state.metrics.add(metrics)

def _validate_samples_and_aggregate_metrics_single_dataset(
Expand Down
87 changes: 86 additions & 1 deletion tests/unit_tests/test_dataset_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from unittest.mock import mock_open, patch

from nemo_gym.dataset_viewer import JsonlDatasetViewerConfig, build_jsonl_dataset_viewer
from pydantic import BaseModel
from pytest import MonkeyPatch

from nemo_gym.dataset_viewer import (
JsonlDatasetViewerConfig,
build_jsonl_dataset_viewer,
get_aggregate_metrics,
)


class TestDatasetViewer:
Expand All @@ -31,3 +39,80 @@ def test_sanity(
mock_content = r"""{"reward": 0.0, "accuracy": false, "set_overlap": 0.0, "original_term_minefield_hit": false, "order_instruction_following_failure": false, "id": 44, "expected_synonym_values": [489, 504], "expected_synonyms": ["Awake", "Alert"], "minefield_label": "Alive", "minefield_label_value": 497, "responses_create_params": {"input": [{"content": "# Instructions", "role": "system"}, {"content": "How does the human body's response to danger highlight the instinct for survival?", "role": "user"}]}, "response": {"id": "resp_689038d64ad081929f6f36d2f2554431063ce0bbdad9d001", "created_at": 1754282198.0, "error": null, "incomplete_details": null, "instructions": null, "metadata": {}, "model": "gpt-4.1-2025-04-14", "object": "response", "output": [{"content": [{"annotations": [], "text": "fake chat message", "type": "output_text"}], "role": "assistant", "type": "message", "id": "fc_689038d5d1cc8192b91bbef0069ff82f063ce0bbdad9d001", "status": "completed"}, {"summary": [{"type": "summary_text", "text": "fake reasoning"}], "type": "reasoning", "id": "fc_689038d5d1cc8192b91bbef0069ff82f063ce0bbdad9d001", "status": "completed"}, {"arguments": "{\"synonym\":\"Survival\"}", "call_id": "call_pyKbpFtdag6LL6euwpAJ8UEw", "name": "get_synonym_value", "type": "function_call", "id": "fc_689038d5d1cc8192b91bbef0069ff82f063ce0bbdad9d001", "status": "completed"}, {"call_id": "call_pyKbpFtdag6LL6euwpAJ8UEw", "output": "{\"synonym_value\": 860}", "type": "function_call_output"}, {"arguments": "{\"synonym_values\":[860]}", "call_id": "call_N4Kr3NJJohoTaxSL5DkkG4W6", "name": "extract_synonym_values", "type": "function_call", "id": "fc_689038d6c71c8192b7ab0de7cc655d98063ce0bbdad9d001", "status": "completed"}], "parallel_tool_calls": false, "temperature": 1.0, "tool_choice": "auto", "tools": [{"name": "get_synonym_value", "parameters": {"properties": {"synonym": {"type": "string", "title": "Synonym", "description": "The synonym to get the value for."}}, "type": "object", "required": 
["synonym"], "additionalProperties": false}, "strict": true, "type": "function", "description": "Get the synonym value for a synonym.\nThis operation returns a value that conforms to the following JSON Schema: {\"properties\": {\"synonym_value\": {\"type\": \"integer\", \"title\": \"Synonym Value\", \"description\": \"The value for this synonym.\"}}, \"type\": \"object\", \"required\": [\"synonym_value\"]}\n"}, {"name": "extract_synonym_values", "parameters": {"properties": {"synonym_values": {"items": {"type": "integer"}, "type": "array", "title": "Synonym Values", "description": "The synonym values corresponding to the term for the user query."}}, "type": "object", "required": ["synonym_values"], "additionalProperties": false}, "strict": true, "type": "function", "description": "Extract the synonym values you retrieved for the term that is relevant to the user query.\nThis operation returns a value that conforms to the following JSON Schema: {\"properties\": {\"success\": {\"type\": \"boolean\", \"title\": \"Success\", \"description\": \"Success.\"}}, \"type\": \"object\", \"required\": [\"success\"]}\n"}], "top_p": 1.0, "background": false, "max_output_tokens": null, "max_tool_calls": null, "previous_response_id": null, "prompt": null, "reasoning": {"effort": null, "generate_summary": null, "summary": null}, "service_tier": "default", "status": "completed", "text": {"format": {"type": "text"}}, "top_logprobs": 0, "truncation": "disabled", "usage": {"input_tokens": 2864, "input_tokens_details": {"cached_tokens": 2798}, "output_tokens": 19, "output_tokens_details": {"reasoning_tokens": 0}, "total_tokens": 2883}, "user": null, "prompt_cache_key": null, "safety_identifier": null, "store": true}}"""
with patch("builtins.open", mock_open(read_data=mock_content)):
build_jsonl_dataset_viewer(config)

def test_get_aggregate_metrics(self, monkeypatch: MonkeyPatch):
    """Exercise get_aggregate_metrics' scalar aggregation: numeric fields
    become avg/min/max summaries (bools coerced to 0/1), string fields
    become unique/total counts, and the bulky response payload keys are
    excluded from the result."""

    class DummySample(BaseModel):
        # Minimal stand-in for DatasetViewerVerifyResponse: only the
        # fields the aggregator inspects, plus non-scalar fields that
        # must be ignored.
        responses_create_params: dict = {}
        response: dict = {}
        reward: float = 1.0
        accuracy: bool = True
        set_overlap: float = 0.5
        unrelated_list: list = []
        unrelated_dict: dict = {}

    class DummySampleWithStrings(DummySample):
        # Adds a string field to exercise the unique/total-count path.
        some_string: str

    samples = [
        DummySample(reward=1.0, accuracy=True, set_overlap=0.5),
        DummySample(reward=0.0, accuracy=False, set_overlap=0.0),
        DummySample(reward=0.5, accuracy=True, set_overlap=1.0),
    ]

    samples_with_strings = [
        DummySampleWithStrings(reward=1.0, accuracy=True, some_string="asdf"),
        DummySampleWithStrings(reward=0.0, accuracy=False, some_string="asdf"),
        DummySampleWithStrings(reward=0.5, accuracy=True, some_string="word1"),
        DummySampleWithStrings(reward=0.5, accuracy=True, some_string="word1"),
        DummySampleWithStrings(reward=0.5, accuracy=True, some_string="word2"),
    ]

    def mock_compute_sample_metrics(line: str):
        # Stub: returns the parsed line as "metrics" and never flags it
        # as offending.
        metrics = json.loads(line)
        return metrics, False

    class DummyAgg:
        # NOTE(review): DummyAgg is never referenced outside itself
        # (its aggregate() returns another DummyAgg) — looks like dead
        # scaffolding; confirm whether it was meant to stub
        # DatasetMetrics.
        def model_dump(self, by_alias=True):
            return {}

        def aggregate(self):
            return DummyAgg()

    # NOTE(review): this patches the name in train_data_utils, but
    # dataset_viewer does `from nemo_gym.train_data_utils import
    # compute_sample_metrics`, so get_aggregate_metrics still calls its
    # own module-level reference — the patch target should presumably be
    # "nemo_gym.dataset_viewer.compute_sample_metrics"; verify.
    monkeypatch.setattr(
        "nemo_gym.train_data_utils.compute_sample_metrics",
        mock_compute_sample_metrics,
    )

    # NOTE(review): get_aggregate_metrics annotates raw_lines as
    # List[str]; passing the bare string "{}\n" iterates it character by
    # character — presumably this should be ["{}\n"]; verify.
    result_1 = get_aggregate_metrics(samples, "{}\n")

    # Scalar numeric fields are aggregated.
    assert "reward" in result_1
    assert "accuracy" in result_1
    assert "set_overlap" in result_1

    # Non-scalar fields are ignored.
    assert "unrelated_str" not in result_1
    assert "unrelated_list" not in result_1
    assert "unrelated_dict" not in result_1

    # Bulky payload keys are explicitly excluded.
    assert "responses_create_params" not in result_1
    assert "response" not in result_1

    # Check computed values
    reward_stats = result_1["reward"]
    assert reward_stats["Total # non-null values"] == 3
    assert reward_stats["Average"] == (1.0 + 0.0 + 0.5) / 3
    assert reward_stats["Min"] == 0.0
    assert reward_stats["Max"] == 1.0

    # Check computed values with bools converted to int
    accuracy_stats = result_1["accuracy"]
    assert accuracy_stats["Total # non-null values"] == 3
    assert accuracy_stats["Average"] == (1 + 0 + 1) / 3
    assert accuracy_stats["Min"] == 0
    assert accuracy_stats["Max"] == 1

    # Check string counts
    result_2 = get_aggregate_metrics(samples_with_strings, "{}\n")

    assert "some_string" in result_2
    assert result_2["some_string"]["unique_count"] == 3
    assert result_2["some_string"]["total_count"] == 5