Skip to content

Commit 255492d

Browse files
committed
refactor(multimodal): unify structured response parsing and error handling
- Add parse_structured_chat_response utility for streaming/non-streaming responses - Return GraderError instead of score=0 on exceptions in multimodal graders - Update tests to verify GraderError behavior - Move exception handling to aevaluate level for cleaner code
1 parent 2ed52ba commit 255492d

7 files changed

Lines changed: 75 additions & 63 deletions

File tree

openjudge/graders/multimodal/image_coherence.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from openjudge.models.base_chat_model import BaseChatModel
2424
from openjudge.models.schema.oai.message import ChatMessage
2525
from openjudge.models.schema.prompt_template import LanguageEnum, PromptTemplate
26+
from openjudge.utils.utils import parse_structured_chat_response
2627

2728
# pylint: disable=line-too-long
2829

@@ -228,24 +229,15 @@ async def _aevaluate_single_image(
228229
data_url = f"data:image/{image_format};base64,{image.base64}"
229230
content.append({"type": "image_url", "image_url": {"url": data_url}})
230231

231-
# Call model without structured output
232232
chat_response = await self.model.achat(
233233
messages=[{"role": "user", "content": content}],
234234
structured_model=GraderScoreCallback,
235235
)
236236

237-
# Handle both streaming and non-streaming responses
238-
if hasattr(chat_response, "__aiter__"):
239-
parsed = {}
240-
async for chunk in chat_response:
241-
if chunk.parsed:
242-
parsed.update(chunk.parsed)
243-
# Default to 5.0 (neutral score on 0-10 scale) for missing fields
244-
score = parsed.get("score", 5.0)
245-
reason = parsed.get("reason", "")
246-
else:
247-
score = chat_response.parsed["score"]
248-
reason = chat_response.parsed["reason"]
237+
# Default to 5.0 (neutral score on 0-10 scale) for missing fields
238+
parsed = await parse_structured_chat_response(chat_response)
239+
score = parsed.get("score", 5.0)
240+
reason = parsed.get("reason", "")
249241
return score, reason
250242

251243
async def _acompute(

openjudge/graders/multimodal/image_helpfulness.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from openjudge.models.base_chat_model import BaseChatModel
2525
from openjudge.models.schema.oai.message import ChatMessage
2626
from openjudge.models.schema.prompt_template import LanguageEnum, PromptTemplate
27+
from openjudge.utils.utils import parse_structured_chat_response
2728

2829
# pylint: disable=line-too-long
2930

@@ -229,18 +230,10 @@ async def _aevaluate_single_image(
229230
structured_model=GraderScoreCallback,
230231
)
231232

232-
# Handle both streaming and non-streaming responses
233-
if hasattr(chat_response, "__aiter__"):
234-
parsed = {}
235-
async for chunk in chat_response:
236-
if chunk.parsed:
237-
parsed.update(chunk.parsed)
238-
# Default to 5.0 (neutral score on 0-10 scale) for missing fields
239-
score = parsed.get("score", 5.0)
240-
reason = parsed.get("reason", "")
241-
else:
242-
score = chat_response.parsed["score"]
243-
reason = chat_response.parsed["reason"]
233+
# Default to 5.0 (neutral score on 0-10 scale) for missing fields
234+
parsed = await parse_structured_chat_response(chat_response)
235+
score = parsed.get("score", 5.0)
236+
reason = parsed.get("reason", "")
244237
return score, reason
245238

246239
async def _acompute(

openjudge/graders/multimodal/text_to_image.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from openjudge.models.openai_chat_model import OpenAIChatModel
2121
from openjudge.models.schema.oai.message import ChatMessage
2222
from openjudge.models.schema.prompt_template import LanguageEnum, PromptTemplate
23+
from openjudge.utils.utils import parse_structured_chat_response
2324

2425
# pylint: disable=line-too-long
2526

@@ -265,20 +266,11 @@ async def _aevaluate_semantic_consistency(
265266
structured_model=GraderScoreCallback,
266267
)
267268

268-
# Handle both streaming and non-streaming responses
269-
if hasattr(chat_response, "__aiter__"):
270-
parsed = {}
271-
async for chunk in chat_response:
272-
if chunk.parsed:
273-
parsed.update(chunk.parsed)
274-
# Default to 5.0 (neutral score on 0-10 scale) for missing fields
275-
score = parsed.get("score", 5.0)
276-
score = score if isinstance(score, list) else [score]
277-
reason = parsed.get("reason", "")
278-
else:
279-
score = chat_response.parsed["score"]
280-
score = score if isinstance(score, list) else [score]
281-
reason = chat_response.parsed["reason"]
269+
# Default to 5.0 (neutral score on 0-10 scale) for missing fields
270+
parsed = await parse_structured_chat_response(chat_response)
271+
score = parsed.get("score", 5.0)
272+
score = score if isinstance(score, list) else [score]
273+
reason = parsed.get("reason", "")
282274
return score, reason
283275

284276
async def _aevaluate_perceptual_quality(
@@ -295,20 +287,11 @@ async def _aevaluate_perceptual_quality(
295287
structured_model=GraderScoreCallback,
296288
)
297289

298-
# Handle both streaming and non-streaming responses
299-
if hasattr(chat_response, "__aiter__"):
300-
parsed = {}
301-
async for chunk in chat_response:
302-
if chunk.parsed:
303-
parsed.update(chunk.parsed)
304-
# Default to 5.0 (neutral score on 0-10 scale) for missing fields
305-
score = parsed.get("score", [5.0, 5.0])
306-
reason = parsed.get("reason", "")
307-
else:
308-
score = chat_response.parsed["score"]
309-
reason = chat_response.parsed["reason"]
310-
290+
# Default to [5.0, 5.0] (neutral scores on 0-10 scale) for missing fields
291+
parsed = await parse_structured_chat_response(chat_response)
292+
score = parsed.get("score", [5.0, 5.0])
311293
score = score[:2] if isinstance(score, list) else [score, score]
294+
reason = parsed.get("reason", "")
312295
return score, reason
313296

314297
async def _a_compute(

openjudge/utils/utils.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
import json
9-
from typing import Any, Dict, Type
9+
from typing import Any, Dict, Optional, Type
1010

1111
from json_repair import repair_json
1212
from loguru import logger
@@ -203,3 +203,42 @@ def trim_and_load_json(response: str, metric: Any = None) -> Dict[str, Any]:
203203
metric_name = getattr(metric, "name", "unknown_metric")
204204
logger.error(f"{metric_name}: {error_msg}")
205205
raise ValueError(error_msg) from e
206+
207+
208+
async def parse_structured_chat_response(
    chat_response: Any,
    default: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Extract the structured payload from a chat response.

    Works with both streaming responses (async iterators of chunks) and
    plain responses. For a stream, the final chunk carrying a parsed
    payload is taken as the complete result; for a plain response the
    ``parsed`` attribute is returned directly.

    Args:
        chat_response: Result of ``model.achat(...)`` invoked with a
            ``structured_model``; either an async iterator of chunks or a
            single response object, each exposing a ``parsed`` attribute.
        default: Fallback dict returned when no parsed payload is
            available. Treated as an empty dict when omitted.

    Returns:
        Dict[str, Any]: The parsed structured fields, e.g. ``score`` and
        ``reason``.

    Example:
        >>> response = await model.achat(messages, structured_model=GraderScoreCallback)
        >>> parsed = await parse_structured_chat_response(response)
        >>> score = parsed.get("score", 5.0)
        >>> reason = parsed.get("reason", "")
    """
    fallback: Dict[str, Any] = {} if default is None else default

    if not hasattr(chat_response, "__aiter__"):
        # Non-streaming: the response object carries the parse directly.
        return chat_response.parsed if chat_response.parsed else fallback

    # Streaming: remember the most recent chunk that produced a parse —
    # the last such chunk holds the complete structured result.
    latest: Optional[Dict[str, Any]] = None
    async for chunk in chat_response:
        if chunk.parsed:
            latest = chunk.parsed
    return latest if latest is not None else fallback

tests/graders/multimodal/test_image_coherence.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def __init__(self):
9292
@pytest.mark.asyncio
9393
async def test_error_handling(self):
9494
"""Test graceful error handling"""
95+
from openjudge.graders.base_grader import GraderError
96+
9597
# Create mock model that raises exception
9698
mock_model = AsyncMock()
9799
mock_model.achat = AsyncMock(side_effect=Exception("API Error"))
@@ -105,9 +107,9 @@ async def test_error_handling(self):
105107
response=["Text before", mock_image, "Text after"],
106108
)
107109

108-
# Assertions
109-
assert result.score == 0.0
110-
assert "Evaluation error: API Error" in result.reason
110+
# Assertions - grader returns GraderError on exception
111+
assert isinstance(result, GraderError)
112+
assert "Evaluation error: API Error" in result.error
111113

112114

113115
# ==================== QUALITY TESTS ====================

tests/graders/multimodal/test_image_helpfulness.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ def __init__(self):
9292
@pytest.mark.asyncio
9393
async def test_error_handling(self):
9494
"""Test graceful error handling"""
95+
from openjudge.graders.base_grader import GraderError
96+
9597
# Create mock model that raises exception
9698
mock_model = AsyncMock()
9799
mock_model.achat = AsyncMock(side_effect=Exception("API Error"))
@@ -105,9 +107,9 @@ async def test_error_handling(self):
105107
response=["Text before", mock_image, "Text after"],
106108
)
107109

108-
# Assertions
109-
assert result.score == 0.0
110-
assert "Evaluation error: API Error" in result.reason
110+
# Assertions - grader returns GraderError on exception
111+
assert isinstance(result, GraderError)
112+
assert "Evaluation error: API Error" in result.error
111113

112114

113115
# ==================== QUALITY TESTS ====================

tests/graders/multimodal/test_text_to_image.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def __init__(self, score, reason):
9595
@pytest.mark.asyncio
9696
async def test_error_handling(self):
9797
"""Test graceful error handling"""
98+
from openjudge.graders.base_grader import GraderError
99+
98100
# Create mock model that raises exception
99101
mock_model = AsyncMock(spec=BaseChatModel)
100102
mock_model.achat = AsyncMock(side_effect=Exception("API Error"))
@@ -109,10 +111,9 @@ async def test_error_handling(self):
109111
response=mock_image,
110112
)
111113

112-
# Assertions
113-
# TextToImageGrader returns 0.5 (default) on error, not 0.0
114-
assert result.score == 0.5
115-
assert "error" in result.reason.lower()
114+
# Assertions - grader returns GraderError on exception
115+
assert isinstance(result, GraderError)
116+
assert "Evaluation error: API Error" in result.error
116117

117118

118119
# ==================== QUALITY TESTS ====================

0 commit comments

Comments
 (0)