1- import contextlib
2- from typing import (
3- TYPE_CHECKING ,
4- AsyncIterator ,
5- Awaitable ,
6- Callable ,
7- List ,
8- Mapping ,
9- Optional ,
10- Tuple ,
11- Union ,
12- )
1+ import sys
2+ from typing import AsyncIterator , Callable , List , Mapping , Optional , Tuple
133
144import httpcore
15- import sniffio
165
17- if TYPE_CHECKING : # pragma: no cover
18- import asyncio
19-
20- import trio
21-
22- Event = Union [asyncio .Event , trio .Event ]
23-
24-
25- def create_event () -> "Event" :
26- if sniffio .current_async_library () == "trio" :
27- import trio
28-
29- return trio .Event ()
30- else :
31- import asyncio
32-
33- return asyncio .Event ()
34-
35-
36- async def create_background_task (
37- async_fn : Callable [[], Awaitable [None ]]
38- ) -> Callable [[], Awaitable [None ]]:
39- if sniffio .current_async_library () == "trio" :
40- import trio
41-
42- nursery_manager = trio .open_nursery ()
43- nursery = await nursery_manager .__aenter__ ()
44- nursery .start_soon (async_fn )
45-
46- async def aclose () -> None :
47- await nursery_manager .__aexit__ (None , None , None )
48-
49- return aclose
50-
51- else :
52- import asyncio
53-
54- loop = asyncio .get_event_loop ()
55- task = loop .create_task (async_fn ())
56-
57- async def aclose () -> None :
58- task .cancel ()
59- # Task must be awaited in all cases to avoid debug warnings.
60- with contextlib .suppress (asyncio .CancelledError ):
61- await task
62-
63- return aclose
64-
65-
66- def create_channel (
67- capacity : int ,
68- ) -> Tuple [
69- Callable [[bytes ], Awaitable [None ]],
70- Callable [[], Awaitable [None ]],
71- Callable [[], AsyncIterator [bytes ]],
72- ]:
73- """
74- Create an in-memory channel to pass data chunks between tasks.
75-
76- * `produce()`: send data through the channel, blocking if necessary.
77- * `consume()`: iterate over data in the channel.
78- * `aclose_produce()`: mark that no more data will be produced, causing
79- `consume()` to flush remaining data chunks then stop.
80- """
81- if sniffio .current_async_library () == "trio" :
82- import trio
83-
84- send_channel , receive_channel = trio .open_memory_channel [bytes ](capacity )
85-
86- async def consume () -> AsyncIterator [bytes ]:
87- async for chunk in receive_channel :
88- yield chunk
89-
90- return send_channel .send , send_channel .aclose , consume
91-
92- else :
93- import asyncio
94-
95- queue : asyncio .Queue [bytes ] = asyncio .Queue (capacity )
96- produce_closed = False
97-
98- async def produce (chunk : bytes ) -> None :
99- assert not produce_closed
100- await queue .put (chunk )
101-
102- async def aclose_produce () -> None :
103- nonlocal produce_closed
104- await queue .put (b"" ) # Make sure (*) doesn't block forever.
105- produce_closed = True
106-
107- async def consume () -> AsyncIterator [bytes ]:
108- while True :
109- if produce_closed and queue .empty ():
110- break
111- yield await queue .get () # (*)
112-
113- return produce , aclose_produce , consume
6+ try :
7+ from contextlib import asynccontextmanager # type: ignore # Python 3.6.
8+ except ImportError : # pragma: no cover # Python 3.6.
9+ from async_generator import asynccontextmanager # type: ignore
11410
11511
11612class ASGITransport (httpcore .AsyncHTTPTransport ):
@@ -153,6 +49,11 @@ def __init__(
15349 root_path : str = "" ,
15450 client : Tuple [str , int ] = ("127.0.0.1" , 123 ),
15551 ) -> None :
52+ try :
53+ import anyio # noqa
54+ except ImportError :
55+ raise ImportError ("ASGITransport requires anyio. (Hint: pip install anyio)" )
56+
15657 self .app = app
15758 self .raise_app_exceptions = raise_app_exceptions
15859 self .root_path = root_path
@@ -166,111 +67,120 @@ async def request(
16667 stream : httpcore .AsyncByteStream = None ,
16768 timeout : Mapping [str , Optional [float ]] = None ,
16869 ) -> Tuple [bytes , int , bytes , List [Tuple [bytes , bytes ]], httpcore .AsyncByteStream ]:
70+
16971 headers = [] if headers is None else headers
17072 stream = httpcore .PlainByteStream (content = b"" ) if stream is None else stream
17173
172- # ASGI scope.
173- scheme , host , port , full_path = url
174- path , _ , query = full_path .partition (b"?" )
175- scope = {
176- "type" : "http" ,
177- "asgi" : {"version" : "3.0" },
178- "http_version" : "1.1" ,
179- "method" : method .decode (),
180- "headers" : headers ,
181- "scheme" : scheme .decode ("ascii" ),
182- "path" : path .decode ("ascii" ),
183- "query_string" : query ,
184- "server" : (host .decode ("ascii" ), port ),
185- "client" : self .client ,
186- "root_path" : self .root_path ,
187- }
188-
189- # Request.
190- request_body_chunks = stream .__aiter__ ()
191- request_complete = False
192-
193- # Response.
194- status_code : Optional [int ] = None
195- response_headers : Optional [List [Tuple [bytes , bytes ]]] = None
196- produce_body , aclose_body , consume_body = create_channel (1 )
197- response_started_or_app_crashed = create_event ()
198- response_complete = create_event ()
199-
200- # ASGI callables.
201-
202- async def receive () -> dict :
203- nonlocal request_complete
204-
205- if request_complete :
206- await response_complete .wait ()
207- return {"type" : "http.disconnect" }
208-
209- try :
210- body = await request_body_chunks .__anext__ ()
211- except StopAsyncIteration :
212- request_complete = True
213- return {"type" : "http.request" , "body" : b"" , "more_body" : False }
214- return {"type" : "http.request" , "body" : body , "more_body" : True }
74+ app_context = run_asgi (
75+ self .app ,
76+ method ,
77+ url ,
78+ headers ,
79+ stream ,
80+ client = self .client ,
81+ root_path = self .root_path ,
82+ )
21583
216- async def send (message : dict ) -> None :
217- nonlocal status_code , response_headers
218- if message ["type" ] == "http.response.start" :
219- assert not response_started_or_app_crashed .is_set ()
220- status_code = message ["status" ]
221- response_headers = message .get ("headers" , [])
222- response_started_or_app_crashed .set ()
84+ status_code , response_headers , response_body = await app_context .__aenter__ ()
22385
224- elif message ["type" ] == "http.response.body" :
225- assert not response_complete .is_set ()
226- body = message .get ("body" , b"" )
227- more_body = message .get ("more_body" , False )
86+ async def aclose () -> None :
87+ await app_context .__aexit__ (* sys .exc_info ())
22888
229- if body and method != b"HEAD" :
230- await produce_body (body )
89+ stream = httpcore .AsyncIteratorByteStream (response_body , aclose_func = aclose )
23190
232- if not more_body :
233- await aclose_body ()
234- response_complete .set ()
91+ return (b"HTTP/1.1" , status_code , b"" , response_headers , stream )
23592
236- # Application wrapper.
23793
238- app_exception : Optional [Exception ] = None
94+ @asynccontextmanager
95+ async def run_asgi (
96+ app : Callable ,
97+ method : bytes ,
98+ url : Tuple [bytes , bytes , Optional [int ], bytes ],
99+ headers : List [Tuple [bytes , bytes ]],
100+ stream : httpcore .AsyncByteStream ,
101+ * ,
102+ client : str ,
103+ root_path : str ,
104+ ) -> AsyncIterator [Tuple [int , List [Tuple [bytes , bytes ]], AsyncIterator [bytes ]]]:
105+ import anyio
106+
107+ # ASGI scope.
108+ scheme , host , port , full_path = url
109+ path , _ , query = full_path .partition (b"?" )
110+ scope = {
111+ "type" : "http" ,
112+ "asgi" : {"version" : "3.0" },
113+ "http_version" : "1.1" ,
114+ "method" : method .decode (),
115+ "headers" : headers ,
116+ "scheme" : scheme .decode ("ascii" ),
117+ "path" : path .decode ("ascii" ),
118+ "query_string" : query ,
119+ "server" : (host .decode ("ascii" ), port ),
120+ "client" : client ,
121+ "root_path" : root_path ,
122+ }
123+
124+ # Request.
125+ request_body_chunks = stream .__aiter__ ()
126+ request_complete = False
127+
128+ # Response.
129+ status_code : Optional [int ] = None
130+ response_headers : Optional [List [Tuple [bytes , bytes ]]] = None
131+ response_body_queue = anyio .create_queue (1 )
132+ response_started = anyio .create_event ()
133+ response_complete = anyio .create_event ()
134+
135+ async def receive () -> dict :
136+ nonlocal request_complete
137+
138+ if request_complete :
139+ await response_complete .wait ()
140+ return {"type" : "http.disconnect" }
141+
142+ try :
143+ body = await request_body_chunks .__anext__ ()
144+ except StopAsyncIteration :
145+ request_complete = True
146+ return {"type" : "http.request" , "body" : b"" , "more_body" : False }
147+ else :
148+ return {"type" : "http.request" , "body" : body , "more_body" : True }
239149
240- async def run_app () -> None :
241- nonlocal app_exception
242- try :
243- await self .app (scope , receive , send )
244- except Exception as exc :
245- app_exception = exc
246- response_started_or_app_crashed .set ()
247- await aclose_body () # Stop response body consumer once flushed (*).
150+ async def send (message : dict ) -> None :
151+ nonlocal status_code , response_headers
248152
249- # Response body iterator.
153+ if message ["type" ] == "http.response.start" :
154+ assert not response_started .is_set ()
155+ status_code = message ["status" ]
156+ response_headers = message .get ("headers" , [])
157+ await response_started .set ()
250158
251- async def aiter_response_body () -> AsyncIterator [bytes ]:
252- async for chunk in consume_body (): # (*)
253- yield chunk
159+ elif message ["type" ] == "http.response.body" :
160+ assert not response_complete .is_set ()
161+ body = message .get ("body" , b"" )
162+ more_body = message .get ("more_body" , False )
254163
255- if app_exception is not None and self . raise_app_exceptions :
256- raise app_exception
164+ if body and method != b"HEAD" :
165+ await response_body_queue . put ( body )
257166
258- # Now we wire things up...
167+ if not more_body :
168+ await response_body_queue .put (None )
169+ await response_complete .set ()
259170
260- aclose = await create_background_task (run_app )
171+ async def body_iterator () -> AsyncIterator [bytes ]:
172+ while True :
173+ chunk = await response_body_queue .get ()
174+ if chunk is None :
175+ break
176+ yield chunk
261177
262- await response_started_or_app_crashed .wait ()
178+ async with anyio .create_task_group () as task_group :
179+ await task_group .spawn (app , scope , receive , send )
263180
264- if app_exception is not None :
265- await aclose ()
266- if self .raise_app_exceptions or not response_complete .is_set ():
267- raise app_exception
181+ await response_started .wait ()
268182
269183 assert status_code is not None
270184 assert response_headers is not None
271185
272- stream = httpcore .AsyncIteratorByteStream (
273- aiter_response_body (), aclose_func = aclose
274- )
275-
276- return (b"HTTP/1.1" , status_code , b"" , response_headers , stream )
186+ yield status_code , response_headers , body_iterator ()
0 commit comments