Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions src/a2a/client/transports/http_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import httpx

from httpx_sse import SSEError, aconnect_sse
from httpx_sse import EventSource, SSEError

from a2a.client.client import ClientCallContext
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
Expand Down Expand Up @@ -75,7 +75,7 @@ async def send_http_stream_request(
) -> AsyncGenerator[str]:
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
with handle_http_exceptions(status_error_handler):
async with aconnect_sse(
async with _SSEEventSource(
httpx_client, method, url, **kwargs
) as event_source:
try:
Expand All @@ -98,3 +98,39 @@ async def send_http_stream_request(
if not sse.data:
continue
yield sse.data


class _SSEEventSource:
"""Class-based replacement for ``httpx_sse.aconnect_sse``.

``aconnect_sse`` is an ``@asynccontextmanager`` whose internal async
generator gets tracked by the event loop. When the enclosing async
generator is abandoned, the event loop's generator cleanup collides
with the cascading cleanup — see https://bugs.python.org/issue38559.

Plain ``__aenter__``/``__aexit__`` coroutines avoid this entirely.
"""

def __init__(
self,
client: httpx.AsyncClient,
method: str,
url: str,
**kwargs: Any,
) -> None:
headers = httpx.Headers(kwargs.pop('headers', None))
headers.setdefault('Accept', 'text/event-stream')
headers.setdefault('Cache-Control', 'no-store')
self._request = client.build_request(
method, url, headers=headers, **kwargs
)
self._client = client
self._response: httpx.Response | None = None

async def __aenter__(self) -> EventSource:
self._response = await self._client.send(self._request, stream=True)
return EventSource(self._response)

async def __aexit__(self, *args: object) -> None:
if self._response is not None:
await self._response.aclose()
8 changes: 4 additions & 4 deletions tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ async def test_close(self, transport, mock_httpx_client):

class TestStreamingErrors:
@pytest.mark.asyncio
@patch('a2a.client.transports.http_helpers.aconnect_sse')
@patch('a2a.client.transports.http_helpers._SSEEventSource')
async def test_send_message_streaming_sse_error(
self,
mock_aconnect_sse: AsyncMock,
Expand All @@ -457,7 +457,7 @@ async def test_send_message_streaming_sse_error(
pass

@pytest.mark.asyncio
@patch('a2a.client.transports.http_helpers.aconnect_sse')
@patch('a2a.client.transports.http_helpers._SSEEventSource')
async def test_send_message_streaming_request_error(
self,
mock_aconnect_sse: AsyncMock,
Expand All @@ -483,7 +483,7 @@ async def test_send_message_streaming_request_error(
pass

@pytest.mark.asyncio
@patch('a2a.client.transports.http_helpers.aconnect_sse')
@patch('a2a.client.transports.http_helpers._SSEEventSource')
async def test_send_message_streaming_timeout(
self,
mock_aconnect_sse: AsyncMock,
Expand Down Expand Up @@ -560,7 +560,7 @@ async def test_extensions_added_to_request(
)

@pytest.mark.asyncio
@patch('a2a.client.transports.http_helpers.aconnect_sse')
@patch('a2a.client.transports.http_helpers._SSEEventSource')
async def test_send_message_streaming_server_error_propagates(
self,
mock_aconnect_sse: AsyncMock,
Expand Down
8 changes: 4 additions & 4 deletions tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _assert_extensions_header(mock_kwargs: dict, expected_extensions: set[str]):

class TestRestTransport:
@pytest.mark.asyncio
@patch('a2a.client.transports.http_helpers.aconnect_sse')
@patch('a2a.client.transports.http_helpers._SSEEventSource')
async def test_send_message_streaming_timeout(
self,
mock_aconnect_sse: AsyncMock,
Expand Down Expand Up @@ -280,7 +280,7 @@ async def test_send_message_with_default_extensions(
)

@pytest.mark.asyncio
@patch('a2a.client.transports.http_helpers.aconnect_sse')
@patch('a2a.client.transports.http_helpers._SSEEventSource')
async def test_send_message_streaming_with_new_extensions(
self,
mock_aconnect_sse: AsyncMock,
Expand Down Expand Up @@ -329,7 +329,7 @@ async def test_send_message_streaming_with_new_extensions(
)

@pytest.mark.asyncio
@patch('a2a.client.transports.http_helpers.aconnect_sse')
@patch('a2a.client.transports.http_helpers._SSEEventSource')
async def test_send_message_streaming_server_error_propagates(
self,
mock_aconnect_sse: AsyncMock,
Expand Down Expand Up @@ -693,7 +693,7 @@ async def test_rest_get_task_prepend_empty_tenant(
],
)
@pytest.mark.asyncio
@patch('a2a.client.transports.http_helpers.aconnect_sse')
@patch('a2a.client.transports.http_helpers._SSEEventSource')
async def test_rest_streaming_methods_prepend_tenant( # noqa: PLR0913
self,
mock_aconnect_sse,
Expand Down
135 changes: 135 additions & 0 deletions tests/integration/test_stream_generator_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Test that streaming SSE responses clean up without athrow() errors.

Reproduces https://github.com/a2aproject/a2a-python/issues/911 —
``RuntimeError: athrow(): asynchronous generator is already running``
during event-loop shutdown after consuming a streaming response.
"""

import asyncio
import gc

from typing import Any
from uuid import uuid4

import httpx
import pytest

from starlette.applications import Starlette

from a2a.client.base_client import BaseClient
from a2a.client.client import ClientConfig
from a2a.client.client_factory import ClientFactory
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentInterface,
Message,
Part,
Role,
SendMessageRequest,
)
from a2a.utils import TransportProtocol


class _MessageExecutor(AgentExecutor):
"""Responds with a single Message event."""

async def execute(self, ctx: RequestContext, eq: EventQueue) -> None:
await eq.enqueue_event(
Message(
role=Role.ROLE_AGENT,
message_id=str(uuid4()),
parts=[Part(text='Hello')],
context_id=ctx.context_id,
task_id=ctx.task_id,
)
)

async def cancel(self, ctx: RequestContext, eq: EventQueue) -> None:
pass


@pytest.fixture
def client():
"""Creates a JSON-RPC client backed by an in-process ASGI server."""
card = AgentCard(
name='T',
description='T',
version='1',
capabilities=AgentCapabilities(streaming=True),
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
supported_interfaces=[
AgentInterface(
protocol_binding=TransportProtocol.JSONRPC,
url='http://test',
),
],
)
handler = DefaultRequestHandler(
agent_executor=_MessageExecutor(),
task_store=InMemoryTaskStore(),
queue_manager=InMemoryQueueManager(),
)
app = Starlette(
routes=[
*create_agent_card_routes(agent_card=card, card_url='/card'),
*create_jsonrpc_routes(
agent_card=card,
request_handler=handler,
extended_agent_card=card,
rpc_url='/',
),
]
)
return ClientFactory(
config=ClientConfig(
httpx_client=httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url='http://test',
)
)
).create(card)


@pytest.mark.asyncio
async def test_stream_message_no_athrow(client: BaseClient) -> None:
"""Consuming a streamed Message must not leave broken async generators."""
errors: list[dict[str, Any]] = []
loop = asyncio.get_event_loop()
orig = loop.get_exception_handler()
loop.set_exception_handler(lambda _l, ctx: errors.append(ctx))

try:
msg = Message(
role=Role.ROLE_USER,
message_id=f'msg-{uuid4()}',
parts=[Part(text='hi')],
)
events = [
e
async for e in client.send_message(
request=SendMessageRequest(message=msg)
)
]
assert events
assert events[0][0].HasField('message')

gc.collect()
await loop.shutdown_asyncgens()

bad = [
e
for e in errors
if 'asynchronous generator' in str(e.get('message', ''))
]
assert not bad, '\n'.join(str(e.get('message', '')) for e in bad)
finally:
loop.set_exception_handler(orig)
await client.close()
Loading