diff --git a/src/a2a/client/transports/http_helpers.py b/src/a2a/client/transports/http_helpers.py index 301782e3..eca386bd 100644 --- a/src/a2a/client/transports/http_helpers.py +++ b/src/a2a/client/transports/http_helpers.py @@ -6,7 +6,7 @@ import httpx -from httpx_sse import SSEError, aconnect_sse +from httpx_sse import EventSource, SSEError from a2a.client.client import ClientCallContext from a2a.client.errors import A2AClientError, A2AClientTimeoutError @@ -75,7 +75,7 @@ async def send_http_stream_request( ) -> AsyncGenerator[str]: """Sends a streaming HTTP request, yielding SSE data strings and handling exceptions.""" with handle_http_exceptions(status_error_handler): - async with aconnect_sse( + async with _SSEEventSource( httpx_client, method, url, **kwargs ) as event_source: try: @@ -98,3 +98,39 @@ async def send_http_stream_request( if not sse.data: continue yield sse.data + + +class _SSEEventSource: + """Class-based replacement for ``httpx_sse.aconnect_sse``. + + ``aconnect_sse`` is an ``@asynccontextmanager`` whose internal async + generator gets tracked by the event loop. When the enclosing async + generator is abandoned, the event loop's generator cleanup collides + with the cascading cleanup — see https://bugs.python.org/issue38559. + + Plain ``__aenter__``/``__aexit__`` coroutines avoid this entirely. + """ + + def __init__( + self, + client: httpx.AsyncClient, + method: str, + url: str, + **kwargs: Any, + ) -> None: + headers = httpx.Headers(kwargs.pop('headers', None)) + headers.setdefault('Accept', 'text/event-stream') + headers.setdefault('Cache-Control', 'no-store') + self._request = client.build_request( + method, url, headers=headers, **kwargs + ) + self._client = client + self._response: httpx.Response | None = None + + async def __aenter__(self) -> EventSource: + self._response = await self._client.send(self._request, stream=True) + return EventSource(self._response) + + async def __aexit__(self, *args: object) -> None: + if self._response is not None: + await self._response.aclose() diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index 5741aa00..1339bb8a 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -433,7 +433,7 @@ async def test_close(self, transport, mock_httpx_client): class TestStreamingErrors: @pytest.mark.asyncio - @patch('a2a.client.transports.http_helpers.aconnect_sse') + @patch('a2a.client.transports.http_helpers._SSEEventSource') async def test_send_message_streaming_sse_error( self, mock_aconnect_sse: AsyncMock, @@ -457,7 +457,7 @@ async def test_send_message_streaming_sse_error( pass @pytest.mark.asyncio - @patch('a2a.client.transports.http_helpers.aconnect_sse') + @patch('a2a.client.transports.http_helpers._SSEEventSource') async def test_send_message_streaming_request_error( self, mock_aconnect_sse: AsyncMock, @@ -483,7 +483,7 @@ async def test_send_message_streaming_request_error( pass @pytest.mark.asyncio - @patch('a2a.client.transports.http_helpers.aconnect_sse') + @patch('a2a.client.transports.http_helpers._SSEEventSource') async def test_send_message_streaming_timeout( self, mock_aconnect_sse: AsyncMock, @@ -560,7 +560,7 @@ async def test_extensions_added_to_request( ) @pytest.mark.asyncio - @patch('a2a.client.transports.http_helpers.aconnect_sse') + @patch('a2a.client.transports.http_helpers._SSEEventSource') async def test_send_message_streaming_server_error_propagates( self, mock_aconnect_sse: AsyncMock, diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 7648de57..e7912566 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -70,7 +70,7 @@ def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]): class TestRestTransport: @pytest.mark.asyncio - @patch('a2a.client.transports.http_helpers.aconnect_sse') + @patch('a2a.client.transports.http_helpers._SSEEventSource') async def test_send_message_streaming_timeout( self, mock_aconnect_sse: AsyncMock, @@ -280,7 +280,7 @@ async def test_send_message_with_default_extensions( ) @pytest.mark.asyncio - @patch('a2a.client.transports.http_helpers.aconnect_sse') + @patch('a2a.client.transports.http_helpers._SSEEventSource') async def test_send_message_streaming_with_new_extensions( self, mock_aconnect_sse: AsyncMock, @@ -329,7 +329,7 @@ async def test_send_message_streaming_with_new_extensions( ) @pytest.mark.asyncio - @patch('a2a.client.transports.http_helpers.aconnect_sse') + @patch('a2a.client.transports.http_helpers._SSEEventSource') async def test_send_message_streaming_server_error_propagates( self, mock_aconnect_sse: AsyncMock, @@ -693,7 +693,7 @@ async def test_rest_get_task_prepend_empty_tenant( ], ) @pytest.mark.asyncio - @patch('a2a.client.transports.http_helpers.aconnect_sse') + @patch('a2a.client.transports.http_helpers._SSEEventSource') async def test_rest_streaming_methods_prepend_tenant( # noqa: PLR0913 self, mock_aconnect_sse, diff --git a/tests/integration/test_stream_generator_cleanup.py b/tests/integration/test_stream_generator_cleanup.py new file mode 100644 index 00000000..184bf665 --- /dev/null +++ b/tests/integration/test_stream_generator_cleanup.py @@ -0,0 +1,135 @@ +"""Test that streaming SSE responses clean up without athrow() errors. + +Reproduces https://github.com/a2aproject/a2a-python/issues/911 — +``RuntimeError: athrow(): asynchronous generator is already running`` +during event-loop shutdown after consuming a streaming response. +""" + +import asyncio +import gc + +from typing import Any +from uuid import uuid4 + +import httpx +import pytest + +from starlette.applications import Starlette + +from a2a.client.base_client import BaseClient +from a2a.client.client import ClientConfig +from a2a.client.client_factory import ClientFactory +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentInterface, + Message, + Part, + Role, + SendMessageRequest, +) +from a2a.utils import TransportProtocol + + +class _MessageExecutor(AgentExecutor): + """Responds with a single Message event.""" + + async def execute(self, ctx: RequestContext, eq: EventQueue) -> None: + await eq.enqueue_event( + Message( + role=Role.ROLE_AGENT, + message_id=str(uuid4()), + parts=[Part(text='Hello')], + context_id=ctx.context_id, + task_id=ctx.task_id, + ) + ) + + async def cancel(self, ctx: RequestContext, eq: EventQueue) -> None: + pass + + +@pytest.fixture +def client(): + """Creates a JSON-RPC client backed by an in-process ASGI server.""" + card = AgentCard( + name='T', + description='T', + version='1', + capabilities=AgentCapabilities(streaming=True), + default_input_modes=['text/plain'], + default_output_modes=['text/plain'], + supported_interfaces=[ + AgentInterface( + protocol_binding=TransportProtocol.JSONRPC, + url='http://test', + ), + ], + ) + handler = DefaultRequestHandler( + agent_executor=_MessageExecutor(), + task_store=InMemoryTaskStore(), + queue_manager=InMemoryQueueManager(), + ) + app = Starlette( + routes=[ + *create_agent_card_routes(agent_card=card, card_url='/card'), + *create_jsonrpc_routes( + agent_card=card, + request_handler=handler, + extended_agent_card=card, + rpc_url='/', + ), + ] + ) + return ClientFactory( + config=ClientConfig( + httpx_client=httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url='http://test', + ) + ) + ).create(card) + + +@pytest.mark.asyncio +async def test_stream_message_no_athrow(client: BaseClient) -> None: + """Consuming a streamed Message must not leave broken async generators.""" + errors: list[dict[str, Any]] = [] + loop = asyncio.get_event_loop() + orig = loop.get_exception_handler() + loop.set_exception_handler(lambda _l, ctx: errors.append(ctx)) + + try: + msg = Message( + role=Role.ROLE_USER, + message_id=f'msg-{uuid4()}', + parts=[Part(text='hi')], + ) + events = [ + e + async for e in client.send_message( + request=SendMessageRequest(message=msg) + ) + ] + assert events + assert events[0][0].HasField('message') + + gc.collect() + await loop.shutdown_asyncgens() + + bad = [ + e + for e in errors + if 'asynchronous generator' in str(e.get('message', '')) + ] + assert not bad, '\n'.join(str(e.get('message', '')) for e in bad) + finally: + loop.set_exception_handler(orig) + await client.close()