Skip to content

Commit dc940e7

Browse files
anyio solution
1 parent ccc2d19 commit dc940e7

File tree

3 files changed

+111
-199
lines changed

3 files changed

+111
-199
lines changed

httpx/_transports/asgi.py

Lines changed: 107 additions & 197 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,12 @@
1-
import contextlib
2-
from typing import (
3-
TYPE_CHECKING,
4-
AsyncIterator,
5-
Awaitable,
6-
Callable,
7-
List,
8-
Mapping,
9-
Optional,
10-
Tuple,
11-
Union,
12-
)
1+
import sys
2+
from typing import AsyncIterator, Callable, List, Mapping, Optional, Tuple
133

144
import httpcore
15-
import sniffio
165

17-
if TYPE_CHECKING: # pragma: no cover
18-
import asyncio
19-
20-
import trio
21-
22-
Event = Union[asyncio.Event, trio.Event]
23-
24-
25-
def create_event() -> "Event":
26-
if sniffio.current_async_library() == "trio":
27-
import trio
28-
29-
return trio.Event()
30-
else:
31-
import asyncio
32-
33-
return asyncio.Event()
34-
35-
36-
async def create_background_task(
37-
async_fn: Callable[[], Awaitable[None]]
38-
) -> Callable[[], Awaitable[None]]:
39-
if sniffio.current_async_library() == "trio":
40-
import trio
41-
42-
nursery_manager = trio.open_nursery()
43-
nursery = await nursery_manager.__aenter__()
44-
nursery.start_soon(async_fn)
45-
46-
async def aclose() -> None:
47-
await nursery_manager.__aexit__(None, None, None)
48-
49-
return aclose
50-
51-
else:
52-
import asyncio
53-
54-
loop = asyncio.get_event_loop()
55-
task = loop.create_task(async_fn())
56-
57-
async def aclose() -> None:
58-
task.cancel()
59-
# Task must be awaited in all cases to avoid debug warnings.
60-
with contextlib.suppress(asyncio.CancelledError):
61-
await task
62-
63-
return aclose
64-
65-
66-
def create_channel(
67-
capacity: int,
68-
) -> Tuple[
69-
Callable[[bytes], Awaitable[None]],
70-
Callable[[], Awaitable[None]],
71-
Callable[[], AsyncIterator[bytes]],
72-
]:
73-
"""
74-
Create an in-memory channel to pass data chunks between tasks.
75-
76-
* `produce()`: send data through the channel, blocking if necessary.
77-
* `consume()`: iterate over data in the channel.
78-
* `aclose_produce()`: mark that no more data will be produced, causing
79-
`consume()` to flush remaining data chunks then stop.
80-
"""
81-
if sniffio.current_async_library() == "trio":
82-
import trio
83-
84-
send_channel, receive_channel = trio.open_memory_channel[bytes](capacity)
85-
86-
async def consume() -> AsyncIterator[bytes]:
87-
async for chunk in receive_channel:
88-
yield chunk
89-
90-
return send_channel.send, send_channel.aclose, consume
91-
92-
else:
93-
import asyncio
94-
95-
queue: asyncio.Queue[bytes] = asyncio.Queue(capacity)
96-
produce_closed = False
97-
98-
async def produce(chunk: bytes) -> None:
99-
assert not produce_closed
100-
await queue.put(chunk)
101-
102-
async def aclose_produce() -> None:
103-
nonlocal produce_closed
104-
await queue.put(b"") # Make sure (*) doesn't block forever.
105-
produce_closed = True
106-
107-
async def consume() -> AsyncIterator[bytes]:
108-
while True:
109-
if produce_closed and queue.empty():
110-
break
111-
yield await queue.get() # (*)
112-
113-
return produce, aclose_produce, consume
6+
try:
7+
from contextlib import asynccontextmanager # type: ignore # Python 3.6.
8+
except ImportError: # pragma: no cover # Python 3.6.
9+
from async_generator import asynccontextmanager # type: ignore
11410

11511

11612
class ASGITransport(httpcore.AsyncHTTPTransport):
@@ -153,6 +49,11 @@ def __init__(
15349
root_path: str = "",
15450
client: Tuple[str, int] = ("127.0.0.1", 123),
15551
) -> None:
52+
try:
53+
import anyio # noqa
54+
except ImportError:
55+
raise ImportError("ASGITransport requires anyio. (Hint: pip install anyio)")
56+
15657
self.app = app
15758
self.raise_app_exceptions = raise_app_exceptions
15859
self.root_path = root_path
@@ -166,111 +67,120 @@ async def request(
16667
stream: httpcore.AsyncByteStream = None,
16768
timeout: Mapping[str, Optional[float]] = None,
16869
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
70+
16971
headers = [] if headers is None else headers
17072
stream = httpcore.PlainByteStream(content=b"") if stream is None else stream
17173

172-
# ASGI scope.
173-
scheme, host, port, full_path = url
174-
path, _, query = full_path.partition(b"?")
175-
scope = {
176-
"type": "http",
177-
"asgi": {"version": "3.0"},
178-
"http_version": "1.1",
179-
"method": method.decode(),
180-
"headers": headers,
181-
"scheme": scheme.decode("ascii"),
182-
"path": path.decode("ascii"),
183-
"query_string": query,
184-
"server": (host.decode("ascii"), port),
185-
"client": self.client,
186-
"root_path": self.root_path,
187-
}
188-
189-
# Request.
190-
request_body_chunks = stream.__aiter__()
191-
request_complete = False
192-
193-
# Response.
194-
status_code: Optional[int] = None
195-
response_headers: Optional[List[Tuple[bytes, bytes]]] = None
196-
produce_body, aclose_body, consume_body = create_channel(1)
197-
response_started_or_app_crashed = create_event()
198-
response_complete = create_event()
199-
200-
# ASGI callables.
201-
202-
async def receive() -> dict:
203-
nonlocal request_complete
204-
205-
if request_complete:
206-
await response_complete.wait()
207-
return {"type": "http.disconnect"}
208-
209-
try:
210-
body = await request_body_chunks.__anext__()
211-
except StopAsyncIteration:
212-
request_complete = True
213-
return {"type": "http.request", "body": b"", "more_body": False}
214-
return {"type": "http.request", "body": body, "more_body": True}
74+
app_context = run_asgi(
75+
self.app,
76+
method,
77+
url,
78+
headers,
79+
stream,
80+
client=self.client,
81+
root_path=self.root_path,
82+
)
21583

216-
async def send(message: dict) -> None:
217-
nonlocal status_code, response_headers
218-
if message["type"] == "http.response.start":
219-
assert not response_started_or_app_crashed.is_set()
220-
status_code = message["status"]
221-
response_headers = message.get("headers", [])
222-
response_started_or_app_crashed.set()
84+
status_code, response_headers, response_body = await app_context.__aenter__()
22385

224-
elif message["type"] == "http.response.body":
225-
assert not response_complete.is_set()
226-
body = message.get("body", b"")
227-
more_body = message.get("more_body", False)
86+
async def aclose() -> None:
87+
await app_context.__aexit__(*sys.exc_info())
22888

229-
if body and method != b"HEAD":
230-
await produce_body(body)
89+
stream = httpcore.AsyncIteratorByteStream(response_body, aclose_func=aclose)
23190

232-
if not more_body:
233-
await aclose_body()
234-
response_complete.set()
91+
return (b"HTTP/1.1", status_code, b"", response_headers, stream)
23592

236-
# Application wrapper.
23793

238-
app_exception: Optional[Exception] = None
94+
@asynccontextmanager
95+
async def run_asgi(
96+
app: Callable,
97+
method: bytes,
98+
url: Tuple[bytes, bytes, Optional[int], bytes],
99+
headers: List[Tuple[bytes, bytes]],
100+
stream: httpcore.AsyncByteStream,
101+
*,
102+
client: str,
103+
root_path: str,
104+
) -> AsyncIterator[Tuple[int, List[Tuple[bytes, bytes]], AsyncIterator[bytes]]]:
105+
import anyio
106+
107+
# ASGI scope.
108+
scheme, host, port, full_path = url
109+
path, _, query = full_path.partition(b"?")
110+
scope = {
111+
"type": "http",
112+
"asgi": {"version": "3.0"},
113+
"http_version": "1.1",
114+
"method": method.decode(),
115+
"headers": headers,
116+
"scheme": scheme.decode("ascii"),
117+
"path": path.decode("ascii"),
118+
"query_string": query,
119+
"server": (host.decode("ascii"), port),
120+
"client": client,
121+
"root_path": root_path,
122+
}
123+
124+
# Request.
125+
request_body_chunks = stream.__aiter__()
126+
request_complete = False
127+
128+
# Response.
129+
status_code: Optional[int] = None
130+
response_headers: Optional[List[Tuple[bytes, bytes]]] = None
131+
response_body_queue = anyio.create_queue(1)
132+
response_started = anyio.create_event()
133+
response_complete = anyio.create_event()
134+
135+
async def receive() -> dict:
136+
nonlocal request_complete
137+
138+
if request_complete:
139+
await response_complete.wait()
140+
return {"type": "http.disconnect"}
141+
142+
try:
143+
body = await request_body_chunks.__anext__()
144+
except StopAsyncIteration:
145+
request_complete = True
146+
return {"type": "http.request", "body": b"", "more_body": False}
147+
else:
148+
return {"type": "http.request", "body": body, "more_body": True}
239149

240-
async def run_app() -> None:
241-
nonlocal app_exception
242-
try:
243-
await self.app(scope, receive, send)
244-
except Exception as exc:
245-
app_exception = exc
246-
response_started_or_app_crashed.set()
247-
await aclose_body() # Stop response body consumer once flushed (*).
150+
async def send(message: dict) -> None:
151+
nonlocal status_code, response_headers
248152

249-
# Response body iterator.
153+
if message["type"] == "http.response.start":
154+
assert not response_started.is_set()
155+
status_code = message["status"]
156+
response_headers = message.get("headers", [])
157+
await response_started.set()
250158

251-
async def aiter_response_body() -> AsyncIterator[bytes]:
252-
async for chunk in consume_body(): # (*)
253-
yield chunk
159+
elif message["type"] == "http.response.body":
160+
assert not response_complete.is_set()
161+
body = message.get("body", b"")
162+
more_body = message.get("more_body", False)
254163

255-
if app_exception is not None and self.raise_app_exceptions:
256-
raise app_exception
164+
if body and method != b"HEAD":
165+
await response_body_queue.put(body)
257166

258-
# Now we wire things up...
167+
if not more_body:
168+
await response_body_queue.put(None)
169+
await response_complete.set()
259170

260-
aclose = await create_background_task(run_app)
171+
async def body_iterator() -> AsyncIterator[bytes]:
172+
while True:
173+
chunk = await response_body_queue.get()
174+
if chunk is None:
175+
break
176+
yield chunk
261177

262-
await response_started_or_app_crashed.wait()
178+
async with anyio.create_task_group() as task_group:
179+
await task_group.spawn(app, scope, receive, send)
263180

264-
if app_exception is not None:
265-
await aclose()
266-
if self.raise_app_exceptions or not response_complete.is_set():
267-
raise app_exception
181+
await response_started.wait()
268182

269183
assert status_code is not None
270184
assert response_headers is not None
271185

272-
stream = httpcore.AsyncIteratorByteStream(
273-
aiter_response_body(), aclose_func=aclose
274-
)
275-
276-
return (b"HTTP/1.1", status_code, b"", response_headers, stream)
186+
yield status_code, response_headers, body_iterator()

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
-e .[http2]
22

33
# Optional
4+
async_generator; python_version < '3.7'
5+
anyio
46
brotlipy==0.7.*
57

68
# Documentation

tests/test_asgi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ async def test_asgi_streaming_exc():
134134
@pytest.mark.usefixtures("async_environment")
135135
async def test_asgi_streaming_exc_after_response():
136136
client = httpx.AsyncClient(app=raise_exc_after_response)
137-
async with client.stream("GET", "http://www.example.org/") as response:
138-
with pytest.raises(ValueError):
137+
with pytest.raises(ValueError):
138+
async with client.stream("GET", "http://www.example.org/") as response:
139139
async for _ in response.aiter_bytes():
140140
pass # pragma: no cover

0 commit comments

Comments
 (0)