Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions src/a2a/client/transports/http_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import httpx

from httpx_sse import SSEError, aconnect_sse
from httpx_sse import EventSource, SSEError

from a2a.client.client import ClientCallContext
from a2a.client.errors import A2AClientError, A2AClientTimeoutError
Expand Down Expand Up @@ -75,7 +75,7 @@
) -> AsyncGenerator[str]:
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
with handle_http_exceptions(status_error_handler):
async with aconnect_sse(
async with _SSEEventSource(
httpx_client, method, url, **kwargs
) as event_source:
try:
Expand All @@ -98,3 +98,39 @@
if not sse.data:
continue
yield sse.data


class _SSEEventSource:
"""Class-based replacement for ``httpx_sse.aconnect_sse``.

``aconnect_sse`` is an ``@asynccontextmanager`` whose internal async
generator leaks into ``loop._asyncgens``. When the enclosing async

Check failure on line 107 in src/a2a/client/transports/http_helpers.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`asyncgens` is not a recognized word. (unrecognized-spelling)
generator is abandoned, ``shutdown_asyncgens`` collides with the

Check failure on line 108 in src/a2a/client/transports/http_helpers.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`asyncgens` is not a recognized word. (unrecognized-spelling)

Check warning on line 108 in src/a2a/client/transports/http_helpers.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`asyncgens` is not a recognized word. (unrecognized-spelling)
cascading ``athrow()`` cleanup — see https://bugs.python.org/issue38559.

Check failure on line 109 in src/a2a/client/transports/http_helpers.py

View workflow job for this annotation

GitHub Actions / Check Spelling

`athrow` is not a recognized word. (unrecognized-spelling)

Plain ``__aenter__``/``__aexit__`` coroutines avoid this entirely.
"""

def __init__(
self,
client: httpx.AsyncClient,
method: str,
url: str,
**kwargs: Any,
) -> None:
headers = kwargs.pop('headers', {})
headers['Accept'] = 'text/event-stream'
headers['Cache-Control'] = 'no-store'
Comment thread
ishymko marked this conversation as resolved.
Outdated
self._request = client.build_request(
method, url, headers=headers, **kwargs
)
self._client = client
self._response: httpx.Response | None = None

async def __aenter__(self) -> EventSource:
self._response = await self._client.send(self._request, stream=True)
return EventSource(self._response)

async def __aexit__(self, *args: object) -> None:
if self._response is not None:
await self._response.aclose()
135 changes: 135 additions & 0 deletions tests/integration/test_stream_generator_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Test that streaming SSE responses clean up without athrow() errors.

Reproduces https://github.com/a2aproject/a2a-python/issues/XXX —
Comment thread
ishymko marked this conversation as resolved.
Outdated
``RuntimeError: athrow(): asynchronous generator is already running``
during event-loop shutdown after consuming a streaming response.
"""

import asyncio
import gc

from typing import Any
from uuid import uuid4

import httpx
import pytest

from starlette.applications import Starlette

from a2a.client.base_client import BaseClient
from a2a.client.client import ClientConfig
from a2a.client.client_factory import ClientFactory
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentInterface,
Message,
Part,
Role,
SendMessageRequest,
)
from a2a.utils import TransportProtocol


class _MessageExecutor(AgentExecutor):
"""Responds with a single Message event."""

async def execute(self, ctx: RequestContext, eq: EventQueue) -> None:
await eq.enqueue_event(
Message(
role=Role.ROLE_AGENT,
message_id=str(uuid4()),
parts=[Part(text='Hello')],
context_id=ctx.context_id,
task_id=ctx.task_id,
)
)

async def cancel(self, ctx: RequestContext, eq: EventQueue) -> None:
pass


@pytest.fixture
def client():
"""Creates a JSON-RPC client backed by an in-process ASGI server."""
card = AgentCard(
name='T',
description='T',
version='1',
capabilities=AgentCapabilities(streaming=True),
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
supported_interfaces=[
AgentInterface(
protocol_binding=TransportProtocol.JSONRPC,
url='http://test',
),
],
)
handler = DefaultRequestHandler(
agent_executor=_MessageExecutor(),
task_store=InMemoryTaskStore(),
queue_manager=InMemoryQueueManager(),
)
app = Starlette(
routes=[
*create_agent_card_routes(agent_card=card, card_url='/card'),
*create_jsonrpc_routes(
agent_card=card,
request_handler=handler,
extended_agent_card=card,
rpc_url='/',
),
]
)
return ClientFactory(
config=ClientConfig(
httpx_client=httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url='http://test',
)
)
).create(card)


@pytest.mark.asyncio
async def test_stream_message_no_athrow(client: BaseClient) -> None:
"""Consuming a streamed Message must not leave broken async generators."""
errors: list[dict[str, Any]] = []
loop = asyncio.get_event_loop()
orig = loop.get_exception_handler()
loop.set_exception_handler(lambda _l, ctx: errors.append(ctx))

try:
msg = Message(
role=Role.ROLE_USER,
message_id=f'msg-{uuid4()}',
parts=[Part(text='hi')],
)
events = [
e
async for e in client.send_message(
request=SendMessageRequest(message=msg)
)
]
assert events
assert events[0][0].HasField('message')

gc.collect()
await loop.shutdown_asyncgens()

bad = [
e
for e in errors
if 'asynchronous generator' in str(e.get('message', ''))
]
assert not bad, '\n'.join(str(e.get('message', '')) for e in bad)
finally:
loop.set_exception_handler(orig)
await client.close()
Loading