Skip to content

Commit ca7edc3

Browse files
fix: fix athrow() RuntimeError on streaming responses (#912)
When the server sends a `Message` event in the SSE stream, `_process_stream` does an early return abandoning the generator chain while the SSE connection is still open. `send_http_stream_request` yields inside `async with aconnect_sse(...)`. `aconnect_sse` (from httpx-sse) is an `@asynccontextmanager` . During event loop shutdown, `shutdown_asyncgens` tries to finalize all tracked generators independently - two `athrow()` calls hit the same chain, producing the `RuntimeError`. Replace `aconnect_sse` with `_SSEEventSource` - a class-based async context manager that calls `httpx_client.send(..., stream=True)` directly and `response.aclose()` in `__aexit__`. Added test fails without a fix: https://github.com/a2aproject/a2a-python/actions/runs/23648762100/job/68887648853. Fixes #911 --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 9cade9b commit ca7edc3

File tree

4 files changed

+181
-10
lines changed

4 files changed

+181
-10
lines changed

src/a2a/client/transports/http_helpers.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import httpx
88

9-
from httpx_sse import SSEError, aconnect_sse
9+
from httpx_sse import EventSource, SSEError
1010

1111
from a2a.client.client import ClientCallContext
1212
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
@@ -75,7 +75,7 @@ async def send_http_stream_request(
7575
) -> AsyncGenerator[str]:
7676
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
7777
with handle_http_exceptions(status_error_handler):
78-
async with aconnect_sse(
78+
async with _SSEEventSource(
7979
httpx_client, method, url, **kwargs
8080
) as event_source:
8181
try:
@@ -98,3 +98,39 @@ async def send_http_stream_request(
9898
if not sse.data:
9999
continue
100100
yield sse.data
101+
102+
103+
class _SSEEventSource:
104+
"""Class-based replacement for ``httpx_sse.aconnect_sse``.
105+
106+
``aconnect_sse`` is an ``@asynccontextmanager`` whose internal async
107+
generator gets tracked by the event loop. When the enclosing async
108+
generator is abandoned, the event loop's generator cleanup collides
109+
with the cascading cleanup — see https://bugs.python.org/issue38559.
110+
111+
Plain ``__aenter__``/``__aexit__`` coroutines avoid this entirely.
112+
"""
113+
114+
def __init__(
115+
self,
116+
client: httpx.AsyncClient,
117+
method: str,
118+
url: str,
119+
**kwargs: Any,
120+
) -> None:
121+
headers = httpx.Headers(kwargs.pop('headers', None))
122+
headers.setdefault('Accept', 'text/event-stream')
123+
headers.setdefault('Cache-Control', 'no-store')
124+
self._request = client.build_request(
125+
method, url, headers=headers, **kwargs
126+
)
127+
self._client = client
128+
self._response: httpx.Response | None = None
129+
130+
async def __aenter__(self) -> EventSource:
131+
self._response = await self._client.send(self._request, stream=True)
132+
return EventSource(self._response)
133+
134+
async def __aexit__(self, *args: object) -> None:
135+
if self._response is not None:
136+
await self._response.aclose()

tests/client/transports/test_jsonrpc_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ async def test_close(self, transport, mock_httpx_client):
433433

434434
class TestStreamingErrors:
435435
@pytest.mark.asyncio
436-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
436+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
437437
async def test_send_message_streaming_sse_error(
438438
self,
439439
mock_aconnect_sse: AsyncMock,
@@ -457,7 +457,7 @@ async def test_send_message_streaming_sse_error(
457457
pass
458458

459459
@pytest.mark.asyncio
460-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
460+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
461461
async def test_send_message_streaming_request_error(
462462
self,
463463
mock_aconnect_sse: AsyncMock,
@@ -483,7 +483,7 @@ async def test_send_message_streaming_request_error(
483483
pass
484484

485485
@pytest.mark.asyncio
486-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
486+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
487487
async def test_send_message_streaming_timeout(
488488
self,
489489
mock_aconnect_sse: AsyncMock,
@@ -560,7 +560,7 @@ async def test_extensions_added_to_request(
560560
)
561561

562562
@pytest.mark.asyncio
563-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
563+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
564564
async def test_send_message_streaming_server_error_propagates(
565565
self,
566566
mock_aconnect_sse: AsyncMock,

tests/client/transports/test_rest_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):
7070

7171
class TestRestTransport:
7272
@pytest.mark.asyncio
73-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
73+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
7474
async def test_send_message_streaming_timeout(
7575
self,
7676
mock_aconnect_sse: AsyncMock,
@@ -280,7 +280,7 @@ async def test_send_message_with_default_extensions(
280280
)
281281

282282
@pytest.mark.asyncio
283-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
283+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
284284
async def test_send_message_streaming_with_new_extensions(
285285
self,
286286
mock_aconnect_sse: AsyncMock,
@@ -329,7 +329,7 @@ async def test_send_message_streaming_with_new_extensions(
329329
)
330330

331331
@pytest.mark.asyncio
332-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
332+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
333333
async def test_send_message_streaming_server_error_propagates(
334334
self,
335335
mock_aconnect_sse: AsyncMock,
@@ -693,7 +693,7 @@ async def test_rest_get_task_prepend_empty_tenant(
693693
],
694694
)
695695
@pytest.mark.asyncio
696-
@patch('a2a.client.transports.http_helpers.aconnect_sse')
696+
@patch('a2a.client.transports.http_helpers._SSEEventSource')
697697
async def test_rest_streaming_methods_prepend_tenant( # noqa: PLR0913
698698
self,
699699
mock_aconnect_sse,
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""Test that streaming SSE responses clean up without athrow() errors.
2+
3+
Reproduces https://github.com/a2aproject/a2a-python/issues/911 —
4+
``RuntimeError: athrow(): asynchronous generator is already running``
5+
during event-loop shutdown after consuming a streaming response.
6+
"""
7+
8+
import asyncio
9+
import gc
10+
11+
from typing import Any
12+
from uuid import uuid4
13+
14+
import httpx
15+
import pytest
16+
17+
from starlette.applications import Starlette
18+
19+
from a2a.client.base_client import BaseClient
20+
from a2a.client.client import ClientConfig
21+
from a2a.client.client_factory import ClientFactory
22+
from a2a.server.agent_execution import AgentExecutor, RequestContext
23+
from a2a.server.events import EventQueue
24+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
25+
from a2a.server.request_handlers import DefaultRequestHandler
26+
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
27+
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
28+
from a2a.types import (
29+
AgentCapabilities,
30+
AgentCard,
31+
AgentInterface,
32+
Message,
33+
Part,
34+
Role,
35+
SendMessageRequest,
36+
)
37+
from a2a.utils import TransportProtocol
38+
39+
40+
class _MessageExecutor(AgentExecutor):
41+
"""Responds with a single Message event."""
42+
43+
async def execute(self, ctx: RequestContext, eq: EventQueue) -> None:
44+
await eq.enqueue_event(
45+
Message(
46+
role=Role.ROLE_AGENT,
47+
message_id=str(uuid4()),
48+
parts=[Part(text='Hello')],
49+
context_id=ctx.context_id,
50+
task_id=ctx.task_id,
51+
)
52+
)
53+
54+
async def cancel(self, ctx: RequestContext, eq: EventQueue) -> None:
55+
pass
56+
57+
58+
@pytest.fixture
59+
def client():
60+
"""Creates a JSON-RPC client backed by an in-process ASGI server."""
61+
card = AgentCard(
62+
name='T',
63+
description='T',
64+
version='1',
65+
capabilities=AgentCapabilities(streaming=True),
66+
default_input_modes=['text/plain'],
67+
default_output_modes=['text/plain'],
68+
supported_interfaces=[
69+
AgentInterface(
70+
protocol_binding=TransportProtocol.JSONRPC,
71+
url='http://test',
72+
),
73+
],
74+
)
75+
handler = DefaultRequestHandler(
76+
agent_executor=_MessageExecutor(),
77+
task_store=InMemoryTaskStore(),
78+
queue_manager=InMemoryQueueManager(),
79+
)
80+
app = Starlette(
81+
routes=[
82+
*create_agent_card_routes(agent_card=card, card_url='/card'),
83+
*create_jsonrpc_routes(
84+
agent_card=card,
85+
request_handler=handler,
86+
extended_agent_card=card,
87+
rpc_url='/',
88+
),
89+
]
90+
)
91+
return ClientFactory(
92+
config=ClientConfig(
93+
httpx_client=httpx.AsyncClient(
94+
transport=httpx.ASGITransport(app=app),
95+
base_url='http://test',
96+
)
97+
)
98+
).create(card)
99+
100+
101+
@pytest.mark.asyncio
102+
async def test_stream_message_no_athrow(client: BaseClient) -> None:
103+
"""Consuming a streamed Message must not leave broken async generators."""
104+
errors: list[dict[str, Any]] = []
105+
loop = asyncio.get_event_loop()
106+
orig = loop.get_exception_handler()
107+
loop.set_exception_handler(lambda _l, ctx: errors.append(ctx))
108+
109+
try:
110+
msg = Message(
111+
role=Role.ROLE_USER,
112+
message_id=f'msg-{uuid4()}',
113+
parts=[Part(text='hi')],
114+
)
115+
events = [
116+
e
117+
async for e in client.send_message(
118+
request=SendMessageRequest(message=msg)
119+
)
120+
]
121+
assert events
122+
assert events[0][0].HasField('message')
123+
124+
gc.collect()
125+
await loop.shutdown_asyncgens()
126+
127+
bad = [
128+
e
129+
for e in errors
130+
if 'asynchronous generator' in str(e.get('message', ''))
131+
]
132+
assert not bad, '\n'.join(str(e.get('message', '')) for e in bad)
133+
finally:
134+
loop.set_exception_handler(orig)
135+
await client.close()

0 commit comments

Comments
 (0)