# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14+ import json
1415from unittest .mock import mock_open , patch
1516
16- from nemo_gym .dataset_viewer import JsonlDatasetViewerConfig , build_jsonl_dataset_viewer
17+ from pydantic import BaseModel
18+ from pytest import MonkeyPatch
19+
20+ from nemo_gym .dataset_viewer import (
21+ JsonlDatasetViewerConfig ,
22+ build_jsonl_dataset_viewer ,
23+ get_aggregate_metrics ,
24+ )
1725
1826
1927class TestDatasetViewer :
@@ -31,3 +39,80 @@ def test_sanity(
3139 mock_content = r"""{"reward": 0.0, "accuracy": false, "set_overlap": 0.0, "original_term_minefield_hit": false, "order_instruction_following_failure": false, "id": 44, "expected_synonym_values": [489, 504], "expected_synonyms": ["Awake", "Alert"], "minefield_label": "Alive", "minefield_label_value": 497, "responses_create_params": {"input": [{"content": "# Instructions", "role": "system"}, {"content": "How does the human body's response to danger highlight the instinct for survival?", "role": "user"}]}, "response": {"id": "resp_689038d64ad081929f6f36d2f2554431063ce0bbdad9d001", "created_at": 1754282198.0, "error": null, "incomplete_details": null, "instructions": null, "metadata": {}, "model": "gpt-4.1-2025-04-14", "object": "response", "output": [{"content": [{"annotations": [], "text": "fake chat message", "type": "output_text"}], "role": "assistant", "type": "message", "id": "fc_689038d5d1cc8192b91bbef0069ff82f063ce0bbdad9d001", "status": "completed"}, {"summary": [{"type": "summary_text", "text": "fake reasoning"}], "type": "reasoning", "id": "fc_689038d5d1cc8192b91bbef0069ff82f063ce0bbdad9d001", "status": "completed"}, {"arguments": "{\"synonym\":\"Survival\"}", "call_id": "call_pyKbpFtdag6LL6euwpAJ8UEw", "name": "get_synonym_value", "type": "function_call", "id": "fc_689038d5d1cc8192b91bbef0069ff82f063ce0bbdad9d001", "status": "completed"}, {"call_id": "call_pyKbpFtdag6LL6euwpAJ8UEw", "output": "{\"synonym_value\": 860}", "type": "function_call_output"}, {"arguments": "{\"synonym_values\":[860]}", "call_id": "call_N4Kr3NJJohoTaxSL5DkkG4W6", "name": "extract_synonym_values", "type": "function_call", "id": "fc_689038d6c71c8192b7ab0de7cc655d98063ce0bbdad9d001", "status": "completed"}], "parallel_tool_calls": false, "temperature": 1.0, "tool_choice": "auto", "tools": [{"name": "get_synonym_value", "parameters": {"properties": {"synonym": {"type": "string", "title": "Synonym", "description": "The synonym to get the value for."}}, "type": "object", "required": 
["synonym"], "additionalProperties": false}, "strict": true, "type": "function", "description": "Get the synonym value for a synonym.\nThis operation returns a value that conforms to the following JSON Schema: {\"properties\": {\"synonym_value\": {\"type\": \"integer\", \"title\": \"Synonym Value\", \"description\": \"The value for this synonym.\"}}, \"type\": \"object\", \"required\": [\"synonym_value\"]}\n"}, {"name": "extract_synonym_values", "parameters": {"properties": {"synonym_values": {"items": {"type": "integer"}, "type": "array", "title": "Synonym Values", "description": "The synonym values corresponding to the term for the user query."}}, "type": "object", "required": ["synonym_values"], "additionalProperties": false}, "strict": true, "type": "function", "description": "Extract the synonym values you retrieved for the term that is relevant to the user query.\nThis operation returns a value that conforms to the following JSON Schema: {\"properties\": {\"success\": {\"type\": \"boolean\", \"title\": \"Success\", \"description\": \"Success.\"}}, \"type\": \"object\", \"required\": [\"success\"]}\n"}], "top_p": 1.0, "background": false, "max_output_tokens": null, "max_tool_calls": null, "previous_response_id": null, "prompt": null, "reasoning": {"effort": null, "generate_summary": null, "summary": null}, "service_tier": "default", "status": "completed", "text": {"format": {"type": "text"}}, "top_logprobs": 0, "truncation": "disabled", "usage": {"input_tokens": 2864, "input_tokens_details": {"cached_tokens": 2798}, "output_tokens": 19, "output_tokens_details": {"reasoning_tokens": 0}, "total_tokens": 2883}, "user": null, "prompt_cache_key": null, "safety_identifier": null, "store": true}}"""
3240 with patch ("builtins.open" , mock_open (read_data = mock_content )):
3341 build_jsonl_dataset_viewer (config )
42+
43+ def test_get_aggregate_metrics (self , monkeypatch : MonkeyPatch ):
44+ class DummySample (BaseModel ):
45+ responses_create_params : dict = {}
46+ response : dict = {}
47+ reward : float = 1.0
48+ accuracy : bool = True
49+ set_overlap : float = 0.5
50+ unrelated_list : list = []
51+ unrelated_dict : dict = {}
52+
53+ class DummySampleWithStrings (DummySample ):
54+ some_string : str
55+
56+ samples = [
57+ DummySample (reward = 1.0 , accuracy = True , set_overlap = 0.5 ),
58+ DummySample (reward = 0.0 , accuracy = False , set_overlap = 0.0 ),
59+ DummySample (reward = 0.5 , accuracy = True , set_overlap = 1.0 ),
60+ ]
61+
62+ samples_with_strings = [
63+ DummySampleWithStrings (reward = 1.0 , accuracy = True , some_string = "asdf" ),
64+ DummySampleWithStrings (reward = 0.0 , accuracy = False , some_string = "asdf" ),
65+ DummySampleWithStrings (reward = 0.5 , accuracy = True , some_string = "word1" ),
66+ DummySampleWithStrings (reward = 0.5 , accuracy = True , some_string = "word1" ),
67+ DummySampleWithStrings (reward = 0.5 , accuracy = True , some_string = "word2" ),
68+ ]
69+
70+ def mock_compute_sample_metrics (line : str ):
71+ metrics = json .loads (line )
72+ return metrics , False
73+
74+ class DummyAgg :
75+ def model_dump (self , by_alias = True ):
76+ return {}
77+
78+ def aggregate (self ):
79+ return DummyAgg ()
80+
81+ monkeypatch .setattr (
82+ "nemo_gym.train_data_utils.compute_sample_metrics" ,
83+ mock_compute_sample_metrics ,
84+ )
85+
86+ result_1 = get_aggregate_metrics (samples , "{}\n " )
87+
88+ assert "reward" in result_1
89+ assert "accuracy" in result_1
90+ assert "set_overlap" in result_1
91+
92+ assert "unrelated_str" not in result_1
93+ assert "unrelated_list" not in result_1
94+ assert "unrelated_dict" not in result_1
95+
96+ assert "responses_create_params" not in result_1
97+ assert "response" not in result_1
98+
99+ # Check computed values
100+ reward_stats = result_1 ["reward" ]
101+ assert reward_stats ["Total # non-null values" ] == 3
102+ assert reward_stats ["Average" ] == (1.0 + 0.0 + 0.5 ) / 3
103+ assert reward_stats ["Min" ] == 0.0
104+ assert reward_stats ["Max" ] == 1.0
105+
106+ # Check computed values with bools converted to int
107+ accuracy_stats = result_1 ["accuracy" ]
108+ assert accuracy_stats ["Total # non-null values" ] == 3
109+ assert accuracy_stats ["Average" ] == (1 + 0 + 1 ) / 3
110+ assert accuracy_stats ["Min" ] == 0
111+ assert accuracy_stats ["Max" ] == 1
112+
113+ # Check string counts
114+ result_2 = get_aggregate_metrics (samples_with_strings , "{}\n " )
115+
116+ assert "some_string" in result_2
117+ assert result_2 ["some_string" ]["unique_count" ] == 3
118+ assert result_2 ["some_string" ]["total_count" ] == 5
0 commit comments