Commit bb1a3ba

Dataset viewer simple aggregations (#9)

Migrated over from GitLab:
- Display aggregate metrics
- Aggregate generic keys using multineedle
- Display other dynamic aggregations
- Count string totals and unique values
- Remove TrainDataProcessor dependency, add test
- Remove duplicate file read, fix argument type hints

Signed-off-by: Frankie Siino <fsiino@nvidia.com>

1 parent 0cb45db commit bb1a3ba

3 files changed: 217 additions & 73 deletions
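
In outline, the viewer now walks every sample, pools numeric values under each top-level key (coercing bools to 0/1), reduces each pool to total/average/min/max, and reports unique and total counts for string keys. A minimal standard-library sketch of that reduction (illustrative only: the real code below uses the repo's AvgMinMax and DatasetMetrics models, so the names and output shape here are assumptions):

    from typing import Any, Dict, List

    def summarize(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Illustrative stand-in for the per-key aggregation added in this commit.
        numeric: Dict[str, list] = {}
        strings: Dict[str, list] = {}
        for row in rows:
            for key, value in row.items():
                if isinstance(value, bool):
                    value = int(value)  # count bools as 0/1
                if isinstance(value, (int, float)):
                    numeric.setdefault(key, []).append(value)
                elif isinstance(value, str):
                    strings.setdefault(key, []).append(value)
        out: Dict[str, Any] = {
            k: {"total": len(v), "average": sum(v) / len(v), "min": min(v), "max": max(v)}
            for k, v in numeric.items()
        }
        for k, v in strings.items():
            out[k] = {"unique_count": len(set(v)), "total_count": len(v)}
        return out

    # e.g. summarize([{"reward": 1.0, "ok": True}, {"reward": 0.0, "ok": False}])
    # -> {"reward": {"total": 2, "average": 0.5, ...}, "ok": {"total": 2, "average": 0.5, ...}}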

nemo_gym/dataset_viewer.py

Lines changed: 61 additions & 9 deletions
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import List
+from typing import Any, Dict, List

-from gradio import Blocks, Chatbot, ChatMessage, Dropdown
+from gradio import JSON, Blocks, Chatbot, ChatMessage, Dropdown
 from gradio.components.chatbot import MetadataDict
 from openai.types.responses.response_input_param import (
     EasyInputMessageParam,
@@ -28,6 +28,11 @@

 from nemo_gym.base_resources_server import BaseVerifyResponse
 from nemo_gym.server_utils import get_global_config_dict
+from nemo_gym.train_data_utils import (
+    AvgMinMax,
+    DatasetMetrics,
+    compute_sample_metrics,
+)


 class DatasetViewerVerifyResponse(BaseVerifyResponse):
@@ -131,7 +136,6 @@ def convert_single_message(m: ResponseInputItemParam) -> List[ChatMessage]:

 def rollout_to_messages(create_params: dict, response: dict) -> List[ChatMessage]:
     messages = []
-
     sampling_params = create_params.copy()
     sampling_params.pop("input")
     sampling_params.pop("tools", None)
@@ -202,14 +206,59 @@ class JsonlDatasetViewerConfig(BaseModel):
     jsonl_fpath: str


+def aggregate_other_metrics(data: List[DatasetViewerVerifyResponse]) -> Dict[str, Any]:
+    metric_values = {}
+    string_values = {}
+    for d in data:
+        d = d.model_dump() if hasattr(d, "model_dump") else d
+        for k, v in d.items():
+            if k in ("responses_create_params", "response"):
+                continue
+            if isinstance(v, bool):
+                v = int(v)
+            if isinstance(v, (int, float)):
+                metric_values.setdefault(k, []).append(v)
+            # get unique count for strings
+            elif isinstance(v, str):
+                string_values.setdefault(k, []).append(v)
+
+    result = {}
+    for k, v in metric_values.items():
+        if v:
+            obj = AvgMinMax(
+                total=len(v),
+                average=sum(v) / len(v),
+                min=min(v),
+                max=max(v),
+            )
+            result[k] = obj.model_dump(by_alias=True)
+
+    for k, v in string_values.items():
+        result[k] = {"unique_count": len(set(v)), "total_count": len(v)}
+
+    return result
+
+
+def get_aggregate_metrics(data: List[DatasetViewerVerifyResponse], raw_lines: List[str]) -> Dict[str, Any]:
+    dataset_metrics = DatasetMetrics()
+    for line in raw_lines:
+        metrics, is_offending = compute_sample_metrics(line)
+        if not is_offending:
+            dataset_metrics.add(metrics)
+
+    aggregate_metrics = dataset_metrics.aggregate()
+    aggregate_metrics_dict = aggregate_metrics.model_dump(by_alias=True)
+    aggregate_metrics_dict.update(**aggregate_other_metrics(data))
+    return aggregate_metrics_dict
+
+
 def build_jsonl_dataset_viewer(config: JsonlDatasetViewerConfig) -> Blocks:
+    data = []
+    raw_lines = []
     with open(config.jsonl_fpath) as f:
-        data = list(
-            tqdm(
-                map(DatasetViewerVerifyResponse.model_validate_json, f),
-                desc="Loading data",
-            )
-        )
+        for line in tqdm(f, desc="Loading data"):
+            raw_lines.append(line)
+            data.append(DatasetViewerVerifyResponse.model_validate_json(line))

     choices = [(f"Sample {i + 1} - Responses ID {d.response.id}", i) for i, d in enumerate(data)]

@@ -225,6 +274,9 @@ def select_item(value: int):
         }
     """
     with Blocks(analytics_enabled=False, css=CSS) as demo:
+        aggregate_dicts = get_aggregate_metrics(data, raw_lines)
+        JSON(value=aggregate_dicts, label="Aggregate Metrics", open=False)
+
         item_dropdown = Dropdown(choices=choices, value=0, label="Samples")
         chatbot = Chatbot(
             value=select_item(0),
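
As a usage sketch of aggregate_other_metrics (hedged: plain dicts stand in for DatasetViewerVerifyResponse objects, which the hasattr(d, "model_dump") guard permits, and the alias keys shown assume AvgMinMax's by_alias dump matches what the new unit test asserts):

    from nemo_gym.dataset_viewer import aggregate_other_metrics

    rows = [
        {"reward": 1.0, "accuracy": True, "label": "alive"},
        {"reward": 0.0, "accuracy": False, "label": "alive"},
    ]
    stats = aggregate_other_metrics(rows)
    # Numeric keys (bools coerced to 0/1) are summarized via AvgMinMax, e.g.:
    #   stats["reward"] == {"Total # non-null values": 2, "Average": 0.5, "Min": 0.0, "Max": 1.0}
    # String keys get uniqueness counts:
    #   stats["label"] == {"unique_count": 1, "total_count": 2}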

nemo_gym/train_data_utils.py

Lines changed: 70 additions & 63 deletions
@@ -17,7 +17,7 @@
 from itertools import count, repeat
 from pathlib import Path
 from shutil import copyfileobj
-from typing import Dict, List, Literal, Optional, Self, Union
+from typing import Dict, List, Literal, Optional, Self, Tuple, Union

 from devtools import pprint
 from omegaconf import DictConfig
@@ -128,6 +128,72 @@ def _aggregate(self: Self) -> Self:
         )


+def compute_sample_metrics(sample_dict_str: str) -> Tuple[DatasetMetrics, bool]:
+    try:
+        sample_dict = json.loads(sample_dict_str)
+    except json.JSONDecodeError:
+        return DatasetMetrics(), True
+
+    try:
+        sample = BaseRunRequest.model_validate(sample_dict)
+    except ValidationError:
+        return DatasetMetrics(), True
+
+    responses_create_params = sample.responses_create_params
+    responses_create_params = responses_create_params.model_dump(exclude_unset=True)
+    inputs = responses_create_params.get("input")
+
+    number_of_tools_metrics = AvgMinMax()
+    if responses_create_params.get("tools") is not None:
+        number_of_tools = len(responses_create_params["tools"])
+        number_of_tools_metrics = AvgMinMax(
+            total=1,
+            average=number_of_tools,
+            min=number_of_tools,
+            max=number_of_tools,
+        )
+
+    if isinstance(inputs, str):
+        inputs = [{"role": "user", "content": inputs}]
+    user_inputs = [i for i in inputs if i.get("role") == "user"] if inputs else []
+    number_of_turns_metrics = AvgMinMax()
+    if user_inputs:
+        number_of_turns = len(user_inputs)
+        number_of_turns_metrics = AvgMinMax(
+            total=1,
+            average=number_of_turns,
+            min=number_of_turns,
+            max=number_of_turns,
+        )
+
+    temperature_metrics = AvgMinMax()
+    if responses_create_params.get("temperature") is not None:
+        temperature = responses_create_params["temperature"]
+        temperature_metrics = AvgMinMax(
+            total=1,
+            average=temperature,
+            min=temperature,
+            max=temperature,
+        )
+
+    json_dumped_number_of_words = len(json.dumps(responses_create_params).split())
+    json_dumped_number_of_words_metrics = AvgMinMax(
+        total=1,
+        average=json_dumped_number_of_words,
+        min=json_dumped_number_of_words,
+        max=json_dumped_number_of_words,
+    )
+
+    metrics = DatasetMetrics(
+        number_of_examples=1,
+        number_of_tools=number_of_tools_metrics,
+        json_dumped_number_of_words=json_dumped_number_of_words_metrics,
+        number_of_turns=number_of_turns_metrics,
+        temperature=temperature_metrics,
+    )
+    return metrics, False
+
+
 class DatasetValidatorState(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -283,72 +349,13 @@ def load_datasets(
     def _validate_samples_and_aggregate_metrics_single_sample(
         self, state: DatasetValidatorState, sample_idx: int, sample_dict_str: str
     ) -> None:
-        try:
-            sample_dict = json.loads(sample_dict_str)
-        except json.JSONDecodeError:
-            state.offending_example_idxs.append(sample_idx)
-            return
-
-        try:
-            sample = BaseRunRequest.model_validate(sample_dict)
-        except ValidationError:
+        metrics, is_offending = compute_sample_metrics(sample_dict_str)
+        if is_offending:
             state.offending_example_idxs.append(sample_idx)
             return

+        sample_dict = json.loads(sample_dict_str)
         state.key_counts.update(sample_dict.keys())
-
-        responses_create_params = sample.responses_create_params
-        responses_create_params = responses_create_params.model_dump(exclude_unset=True)
-        inputs = responses_create_params["input"]
-
-        number_of_tools_metrics = AvgMinMax()
-        if responses_create_params.get("tools") is not None:
-            number_of_tools = len(responses_create_params["tools"])
-            number_of_tools_metrics = AvgMinMax(
-                total=1,
-                average=number_of_tools,
-                min=number_of_tools,
-                max=number_of_tools,
-            )
-
-        if isinstance(inputs, str):
-            inputs = [{"role": "user", "content": inputs}]
-        user_inputs = [i for i in inputs if i.get("role") == "user"]
-        number_of_turns_metrics = AvgMinMax()
-        if user_inputs:
-            number_of_turns = len(user_inputs)
-            number_of_turns_metrics = AvgMinMax(
-                total=1,
-                average=number_of_turns,
-                min=number_of_turns,
-                max=number_of_turns,
-            )
-
-        temperature_metrics = AvgMinMax()
-        if responses_create_params.get("temperature") is not None:
-            temperature = responses_create_params["temperature"]
-            temperature_metrics = AvgMinMax(
-                total=1,
-                average=temperature,
-                min=temperature,
-                max=temperature,
-            )
-
-        json_dumped_number_of_words = len(json.dumps(responses_create_params).split())
-        json_dumped_number_of_words_metrics = AvgMinMax(
-            total=1,
-            average=json_dumped_number_of_words,
-            min=json_dumped_number_of_words,
-            max=json_dumped_number_of_words,
-        )
-
-        metrics = DatasetMetrics(
-            number_of_examples=1,
-            number_of_tools=number_of_tools_metrics,
-            json_dumped_number_of_words=json_dumped_number_of_words_metrics,
-            number_of_turns=number_of_turns_metrics,
-            temperature=temperature_metrics,
-        )
         state.metrics.add(metrics)

     def _validate_samples_and_aggregate_metrics_single_dataset(
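
A hedged example of the extracted helper's contract, the (metrics, is_offending) tuple. Whether this minimal payload passes BaseRunRequest validation depends on the repo's models, so treat the happy path here as an assumption:

    import json
    from nemo_gym.train_data_utils import DatasetMetrics, compute_sample_metrics

    # Malformed JSON (or a failed BaseRunRequest validation) is flagged rather than raised.
    _, is_offending = compute_sample_metrics("not valid json")
    assert is_offending

    running = DatasetMetrics()
    line = json.dumps({"responses_create_params": {"input": "What is 2 + 2?", "temperature": 0.7}})
    metrics, is_offending = compute_sample_metrics(line)
    if not is_offending:
        running.add(metrics)  # fold per-line metrics into the running aggregate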

tests/unit_tests/test_dataset_viewer.py

Lines changed: 86 additions & 1 deletion
@@ -11,9 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from unittest.mock import mock_open, patch

-from nemo_gym.dataset_viewer import JsonlDatasetViewerConfig, build_jsonl_dataset_viewer
+from pydantic import BaseModel
+from pytest import MonkeyPatch
+
+from nemo_gym.dataset_viewer import (
+    JsonlDatasetViewerConfig,
+    build_jsonl_dataset_viewer,
+    get_aggregate_metrics,
+)


 class TestDatasetViewer:
@@ -31,3 +39,80 @@ def test_sanity(
         mock_content = r"""{"reward": 0.0, "accuracy": false, "set_overlap": 0.0, "original_term_minefield_hit": false, "order_instruction_following_failure": false, "id": 44, "expected_synonym_values": [489, 504], "expected_synonyms": ["Awake", "Alert"], "minefield_label": "Alive", "minefield_label_value": 497, "responses_create_params": {"input": [{"content": "# Instructions", "role": "system"}, {"content": "How does the human body's response to danger highlight the instinct for survival?", "role": "user"}]}, "response": {"id": "resp_689038d64ad081929f6f36d2f2554431063ce0bbdad9d001", "created_at": 1754282198.0, "error": null, "incomplete_details": null, "instructions": null, "metadata": {}, "model": "gpt-4.1-2025-04-14", "object": "response", "output": [{"content": [{"annotations": [], "text": "fake chat message", "type": "output_text"}], "role": "assistant", "type": "message", "id": "fc_689038d5d1cc8192b91bbef0069ff82f063ce0bbdad9d001", "status": "completed"}, {"summary": [{"type": "summary_text", "text": "fake reasoning"}], "type": "reasoning", "id": "fc_689038d5d1cc8192b91bbef0069ff82f063ce0bbdad9d001", "status": "completed"}, {"arguments": "{\"synonym\":\"Survival\"}", "call_id": "call_pyKbpFtdag6LL6euwpAJ8UEw", "name": "get_synonym_value", "type": "function_call", "id": "fc_689038d5d1cc8192b91bbef0069ff82f063ce0bbdad9d001", "status": "completed"}, {"call_id": "call_pyKbpFtdag6LL6euwpAJ8UEw", "output": "{\"synonym_value\": 860}", "type": "function_call_output"}, {"arguments": "{\"synonym_values\":[860]}", "call_id": "call_N4Kr3NJJohoTaxSL5DkkG4W6", "name": "extract_synonym_values", "type": "function_call", "id": "fc_689038d6c71c8192b7ab0de7cc655d98063ce0bbdad9d001", "status": "completed"}], "parallel_tool_calls": false, "temperature": 1.0, "tool_choice": "auto", "tools": [{"name": "get_synonym_value", "parameters": {"properties": {"synonym": {"type": "string", "title": "Synonym", "description": "The synonym to get the value for."}}, "type": "object", "required": ["synonym"], "additionalProperties": false}, "strict": true, "type": "function", "description": "Get the synonym value for a synonym.\nThis operation returns a value that conforms to the following JSON Schema: {\"properties\": {\"synonym_value\": {\"type\": \"integer\", \"title\": \"Synonym Value\", \"description\": \"The value for this synonym.\"}}, \"type\": \"object\", \"required\": [\"synonym_value\"]}\n"}, {"name": "extract_synonym_values", "parameters": {"properties": {"synonym_values": {"items": {"type": "integer"}, "type": "array", "title": "Synonym Values", "description": "The synonym values corresponding to the term for the user query."}}, "type": "object", "required": ["synonym_values"], "additionalProperties": false}, "strict": true, "type": "function", "description": "Extract the synonym values you retrieved for the term that is relevant to the user query.\nThis operation returns a value that conforms to the following JSON Schema: {\"properties\": {\"success\": {\"type\": \"boolean\", \"title\": \"Success\", \"description\": \"Success.\"}}, \"type\": \"object\", \"required\": [\"success\"]}\n"}], "top_p": 1.0, "background": false, "max_output_tokens": null, "max_tool_calls": null, "previous_response_id": null, "prompt": null, "reasoning": {"effort": null, "generate_summary": null, "summary": null}, "service_tier": "default", "status": "completed", "text": {"format": {"type": "text"}}, "top_logprobs": 0, "truncation": "disabled", "usage": {"input_tokens": 2864, "input_tokens_details": {"cached_tokens": 2798}, "output_tokens": 19, "output_tokens_details": {"reasoning_tokens": 0}, "total_tokens": 2883}, "user": null, "prompt_cache_key": null, "safety_identifier": null, "store": true}}"""
         with patch("builtins.open", mock_open(read_data=mock_content)):
             build_jsonl_dataset_viewer(config)
+
+    def test_get_aggregate_metrics(self, monkeypatch: MonkeyPatch):
+        class DummySample(BaseModel):
+            responses_create_params: dict = {}
+            response: dict = {}
+            reward: float = 1.0
+            accuracy: bool = True
+            set_overlap: float = 0.5
+            unrelated_list: list = []
+            unrelated_dict: dict = {}
+
+        class DummySampleWithStrings(DummySample):
+            some_string: str
+
+        samples = [
+            DummySample(reward=1.0, accuracy=True, set_overlap=0.5),
+            DummySample(reward=0.0, accuracy=False, set_overlap=0.0),
+            DummySample(reward=0.5, accuracy=True, set_overlap=1.0),
+        ]
+
+        samples_with_strings = [
+            DummySampleWithStrings(reward=1.0, accuracy=True, some_string="asdf"),
+            DummySampleWithStrings(reward=0.0, accuracy=False, some_string="asdf"),
+            DummySampleWithStrings(reward=0.5, accuracy=True, some_string="word1"),
+            DummySampleWithStrings(reward=0.5, accuracy=True, some_string="word1"),
+            DummySampleWithStrings(reward=0.5, accuracy=True, some_string="word2"),
+        ]
+
+        def mock_compute_sample_metrics(line: str):
+            metrics = json.loads(line)
+            return metrics, False
+
+        class DummyAgg:
+            def model_dump(self, by_alias=True):
+                return {}
+
+            def aggregate(self):
+                return DummyAgg()
+
+        monkeypatch.setattr(
+            "nemo_gym.train_data_utils.compute_sample_metrics",
+            mock_compute_sample_metrics,
+        )
+
+        result_1 = get_aggregate_metrics(samples, "{}\n")
+
+        assert "reward" in result_1
+        assert "accuracy" in result_1
+        assert "set_overlap" in result_1
+
+        assert "unrelated_str" not in result_1
+        assert "unrelated_list" not in result_1
+        assert "unrelated_dict" not in result_1
+
+        assert "responses_create_params" not in result_1
+        assert "response" not in result_1
+
+        # Check computed values
+        reward_stats = result_1["reward"]
+        assert reward_stats["Total # non-null values"] == 3
+        assert reward_stats["Average"] == (1.0 + 0.0 + 0.5) / 3
+        assert reward_stats["Min"] == 0.0
+        assert reward_stats["Max"] == 1.0
+
+        # Check computed values with bools converted to int
+        accuracy_stats = result_1["accuracy"]
+        assert accuracy_stats["Total # non-null values"] == 3
+        assert accuracy_stats["Average"] == (1 + 0 + 1) / 3
+        assert accuracy_stats["Min"] == 0
+        assert accuracy_stats["Max"] == 1
+
+        # Check string counts
+        result_2 = get_aggregate_metrics(samples_with_strings, "{}\n")
+
+        assert "some_string" in result_2
+        assert result_2["some_string"]["unique_count"] == 3
+        assert result_2["some_string"]["total_count"] == 5
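
One detail both the aggregator and this test lean on: Python's bool is a subclass of int, so True and False would already satisfy the numeric isinstance check; the explicit bool branch in aggregate_other_metrics just normalizes them before pooling. A standalone illustration:

    assert isinstance(True, int)  # bool subclasses int in Python
    assert int(True) == 1 and int(False) == 0
    values = [int(v) for v in (True, False, True)]
    assert sum(values) / len(values) == (1 + 0 + 1) / 3  # matches the accuracy assertions above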
