|
1 | 1 | import asyncio |
2 | | - |
3 | 2 | from collections.abc import AsyncGenerator |
4 | 3 | from typing import Any, NamedTuple |
5 | 4 | from unittest.mock import ANY, AsyncMock, patch |
|
8 | 7 | import httpx |
9 | 8 | import pytest |
10 | 9 | import pytest_asyncio |
11 | | - |
12 | 10 | from cryptography.hazmat.primitives.asymmetric import ec |
13 | 11 | from google.protobuf.json_format import MessageToDict |
14 | 12 | from google.protobuf.timestamp_pb2 import Timestamp |
15 | 13 |
|
16 | 14 | from a2a.client import Client, ClientConfig |
17 | 15 | from a2a.client.base_client import BaseClient |
18 | 16 | from a2a.client.card_resolver import A2ACardResolver |
19 | | -from a2a.client.client_factory import ClientFactory |
20 | 17 | from a2a.client.client import ClientCallContext |
| 18 | +from a2a.client.client_factory import ClientFactory |
21 | 19 | from a2a.client.service_parameters import ( |
22 | 20 | ServiceParametersFactory, |
23 | 21 | with_a2a_extensions, |
24 | 22 | ) |
25 | 23 | from a2a.client.transports import JsonRpcTransport, RestTransport |
26 | 24 | from starlette.applications import Starlette |
| 25 | + |
| 26 | +# Compat v0.3 imports for dedicated tests |
| 27 | +from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc |
| 28 | +from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler |
27 | 29 | from a2a.server.apps import A2ARESTFastAPIApplication |
28 | 30 | from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes |
29 | 31 | from a2a.server.request_handlers import GrpcHandler, RequestHandler |
|
52 | 54 | TaskStatus, |
53 | 55 | TaskStatusUpdateEvent, |
54 | 56 | ) |
55 | | -from a2a.utils.constants import ( |
56 | | - TransportProtocol, |
57 | | -) |
| 57 | +from a2a.utils.constants import TransportProtocol |
58 | 58 | from a2a.utils.errors import ( |
59 | | - ExtendedAgentCardNotConfiguredError, |
60 | 59 | ContentTypeNotSupportedError, |
| 60 | + ExtendedAgentCardNotConfiguredError, |
61 | 61 | ExtensionSupportRequiredError, |
62 | 62 | InternalError, |
63 | 63 | InvalidAgentResponseError, |
|
75 | 75 | create_signature_verifier, |
76 | 76 | ) |
77 | 77 |
|
78 | | -# Compat v0.3 imports for dedicated tests |
79 | | -from a2a.compat.v0_3 import a2a_v0_3_pb2, a2a_v0_3_pb2_grpc |
80 | | -from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler |
81 | | - |
82 | | - |
83 | 78 | # --- Test Constants --- |
84 | 79 |
|
85 | 80 | TASK_FROM_STREAM = Task( |
@@ -368,9 +363,9 @@ def grpc_03_setup( |
368 | 363 | ) -> TransportSetup: |
369 | 364 | """Sets up the CompatGrpcTransport and in-process 0.3 server.""" |
370 | 365 | server_address, handler = grpc_03_server_and_handler |
371 | | - from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport |
372 | 366 | from a2a.client.base_client import BaseClient |
373 | 367 | from a2a.client.client import ClientConfig |
| 368 | + from a2a.compat.v0_3.grpc_transport import CompatGrpcTransport |
374 | 369 |
|
375 | 370 | channel = grpc.aio.insecure_channel(server_address) |
376 | 371 | transport = CompatGrpcTransport(channel=channel, agent_card=agent_card) |
@@ -926,6 +921,73 @@ async def test_client_handles_a2a_errors(transport_setups, error_cls) -> None: |
926 | 921 | await client.close() |
927 | 922 |
|
928 | 923 |
|
| 924 | +@pytest.mark.asyncio |
| 925 | +@pytest.mark.parametrize( |
| 926 | + 'error_cls', |
| 927 | + [ |
| 928 | + TaskNotFoundError, |
| 929 | + TaskNotCancelableError, |
| 930 | + PushNotificationNotSupportedError, |
| 931 | + UnsupportedOperationError, |
| 932 | + ContentTypeNotSupportedError, |
| 933 | + InvalidAgentResponseError, |
| 934 | + ExtendedAgentCardNotConfiguredError, |
| 935 | + ExtensionSupportRequiredError, |
| 936 | + VersionNotSupportedError, |
| 937 | + ], |
| 938 | +) |
| 939 | +@pytest.mark.parametrize( |
| 940 | + 'handler_attr, client_method, request_params', |
| 941 | + [ |
| 942 | + pytest.param( |
| 943 | + 'on_message_send_stream', |
| 944 | + 'send_message', |
| 945 | + SendMessageRequest( |
| 946 | + message=Message( |
| 947 | + role=Role.ROLE_USER, |
| 948 | + message_id='msg-integration-test', |
| 949 | + parts=[Part(text='Hello, integration test!')], |
| 950 | + ) |
| 951 | + ), |
| 952 | + id='stream', |
| 953 | + ), |
| 954 | + pytest.param( |
| 955 | + 'on_subscribe_to_task', |
| 956 | + 'subscribe', |
| 957 | + SubscribeToTaskRequest(id='some-id'), |
| 958 | + id='subscribe', |
| 959 | + ), |
| 960 | + ], |
| 961 | +) |
| 962 | +async def test_client_handles_a2a_errors_streaming( |
| 963 | + transport_setups, error_cls, handler_attr, client_method, request_params |
| 964 | +) -> None: |
| 965 | + """Integration test to verify error propagation from streaming handlers to client. |
| 966 | +
|
| 967 | + The handler raises an A2AError before yielding any events. All transports |
| 968 | + must propagate this as the exact error_cls, not wrapped in an ExceptionGroup |
| 969 | + or converted to a generic client error. |
| 970 | + """ |
| 971 | + client = transport_setups.client |
| 972 | + handler = transport_setups.handler |
| 973 | + |
| 974 | + async def mock_generator(*args, **kwargs): |
| 975 | + raise error_cls('Test error message') |
| 976 | + yield |
| 977 | + |
| 978 | + getattr(handler, handler_attr).side_effect = mock_generator |
| 979 | + |
| 980 | + with pytest.raises(error_cls) as exc_info: |
| 981 | + async for _ in getattr(client, client_method)(request=request_params): |
| 982 | + pass |
| 983 | + |
| 984 | + assert 'Test error message' in str(exc_info.value) |
| 985 | + |
| 986 | + getattr(handler, handler_attr).side_effect = None |
| 987 | + |
| 988 | + await client.close() |
| 989 | + |
| 990 | + |
929 | 991 | @pytest.mark.asyncio |
930 | 992 | @pytest.mark.parametrize( |
931 | 993 | 'request_kwargs, expected_error_code', |
|
0 commit comments