Skip to content

Commit 405be3f

Browse files
authored
fix: fix REST error handling (#893)
Do one iteration to catch exceptions occurred beforehand to return an error instead of sending headers for SSE. Error handling during the execution is not defined in the spec: a2aproject/A2A#1262.
1 parent 734d062 commit 405be3f

File tree

2 files changed

+92
-19
lines changed

2 files changed

+92
-19
lines changed

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,26 @@ async def _handle_streaming_request(
149149

150150
call_context = self._build_call_context(request)
151151

152-
async def event_generator(
153-
stream: AsyncIterable[Any],
154-
) -> AsyncIterator[str]:
152+
# Eagerly fetch the first item from the stream so that errors raised
153+
# before any event is yielded (e.g. validation, parsing, or handler
154+
# failures) propagate here and are caught by
155+
# @rest_stream_error_handler, which returns a JSONResponse with
156+
# the correct HTTP status code instead of starting an SSE stream.
157+
# Without this, the error would be raised after SSE headers are
158+
# already sent, and the client would see a broken stream instead
159+
# of a proper error response.
160+
stream = aiter(method(request, call_context))
161+
try:
162+
first_item = await anext(stream)
163+
except StopAsyncIteration:
164+
return EventSourceResponse(iter([]))
165+
166+
async def event_generator() -> AsyncIterator[str]:
167+
yield json.dumps(first_item)
155168
async for item in stream:
156169
yield json.dumps(item)
157170

158-
return EventSourceResponse(
159-
event_generator(method(request, call_context))
160-
)
171+
return EventSourceResponse(event_generator())
161172

162173
async def handle_get_agent_card(
163174
self, request: Request, call_context: ServerCallContext | None = None

tests/integration/test_client_server_integration.py

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
32
from collections.abc import AsyncGenerator
43
from typing import Any, NamedTuple
54
from unittest.mock import ANY, AsyncMock, patch
@@ -8,22 +7,25 @@
87
import httpx
98
import pytest
109
import pytest_asyncio
11-
1210
from cryptography.hazmat.primitives.asymmetric import ec
1311
from google.protobuf.json_format import MessageToDict
1412
from google.protobuf.timestamp_pb2 import Timestamp
1513

1614
from a2a.client import Client, ClientConfig
1715
from a2a.client.base_client import BaseClient
1816
from a2a.client.card_resolver import A2ACardResolver
19-
from a2a.client.client_factory import ClientFactory
2017
from a2a.client.client import ClientCallContext
18+
from a2a.client.client_factory import ClientFactory
2119
from a2a.client.service_parameters import (
2220
ServiceParametersFactory,
2321
with_a2a_extensions,
2422
)
2523
from a2a.client.transports import JsonRpcTransport, RestTransport
2624
from starlette.applications import Starlette
25+
26+
# Compat v0.3 imports for dedicated tests
27+
from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc
28+
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
2729
from a2a.server.apps import A2ARESTFastAPIApplication
2830
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
2931
from a2a.server.request_handlers import GrpcHandler, RequestHandler
@@ -52,12 +54,10 @@
5254
TaskStatus,
5355
TaskStatusUpdateEvent,
5456
)
55-
from a2a.utils.constants import (
56-
TransportProtocol,
57-
)
57+
from a2a.utils.constants import TransportProtocol
5858
from a2a.utils.errors import (
59-
ExtendedAgentCardNotConfiguredError,
6059
ContentTypeNotSupportedError,
60+
ExtendedAgentCardNotConfiguredError,
6161
ExtensionSupportRequiredError,
6262
InternalError,
6363
InvalidAgentResponseError,
@@ -75,11 +75,6 @@
7575
create_signature_verifier,
7676
)
7777

78-
# Compat v0.3 imports for dedicated tests
79-
from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc
80-
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
81-
82-
8378
# --- Test Constants ---
8479

8580
TASK_FROM_STREAM = Task(
@@ -368,9 +363,9 @@ def grpc_03_setup(
368363
) -> TransportSetup:
369364
"""Sets up the CompatGrpcTransport and in-process 0.3 server."""
370365
server_address, handler = grpc_03_server_and_handler
371-
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
372366
from a2a.client.base_client import BaseClient
373367
from a2a.client.client import ClientConfig
368+
from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport
374369

375370
channel = grpc.aio.insecure_channel(server_address)
376371
transport = CompatGrpcTransport(channel=channel, agent_card=agent_card)
@@ -926,6 +921,73 @@ async def test_client_handles_a2a_errors(transport_setups, error_cls) -> None:
926921
await client.close()
927922

928923

924+
@pytest.mark.asyncio
925+
@pytest.mark.parametrize(
926+
'error_cls',
927+
[
928+
TaskNotFoundError,
929+
TaskNotCancelableError,
930+
PushNotificationNotSupportedError,
931+
UnsupportedOperationError,
932+
ContentTypeNotSupportedError,
933+
InvalidAgentResponseError,
934+
ExtendedAgentCardNotConfiguredError,
935+
ExtensionSupportRequiredError,
936+
VersionNotSupportedError,
937+
],
938+
)
939+
@pytest.mark.parametrize(
940+
'handler_attr, client_method, request_params',
941+
[
942+
pytest.param(
943+
'on_message_send_stream',
944+
'send_message',
945+
SendMessageRequest(
946+
message=Message(
947+
role=Role.ROLE_USER,
948+
message_id='msg-integration-test',
949+
parts=[Part(text='Hello, integration test!')],
950+
)
951+
),
952+
id='stream',
953+
),
954+
pytest.param(
955+
'on_subscribe_to_task',
956+
'subscribe',
957+
SubscribeToTaskRequest(id='some-id'),
958+
id='subscribe',
959+
),
960+
],
961+
)
962+
async def test_client_handles_a2a_errors_streaming(
963+
transport_setups, error_cls, handler_attr, client_method, request_params
964+
) -> None:
965+
"""Integration test to verify error propagation from streaming handlers to client.
966+
967+
The handler raises an A2AError before yielding any events. All transports
968+
must propagate this as the exact error_cls, not wrapped in an ExceptionGroup
969+
or converted to a generic client error.
970+
"""
971+
client = transport_setups.client
972+
handler = transport_setups.handler
973+
974+
async def mock_generator(*args, **kwargs):
975+
raise error_cls('Test error message')
976+
yield
977+
978+
getattr(handler, handler_attr).side_effect = mock_generator
979+
980+
with pytest.raises(error_cls) as exc_info:
981+
async for _ in getattr(client, client_method)(request=request_params):
982+
pass
983+
984+
assert 'Test error message' in str(exc_info.value)
985+
986+
getattr(handler, handler_attr).side_effect = None
987+
988+
await client.close()
989+
990+
929991
@pytest.mark.asyncio
930992
@pytest.mark.parametrize(
931993
'request_kwargs, expected_error_code',

0 commit comments

Comments
 (0)