44import httpcore
55import sniffio
66
7- from .._content_streams import ByteStream
7+ from .._content_streams import AsyncIteratorStream , ByteStream
88from .._utils import warn_deprecated
99
1010if typing .TYPE_CHECKING : # pragma: no cover
@@ -25,6 +25,75 @@ def create_event() -> "Event":
2525 return asyncio .Event ()
2626
2727
28+ async def create_background_task (async_fn : typing .Callable ) -> typing .Callable :
29+ if sniffio .current_async_library () == "trio" :
30+ import trio
31+
32+ nursery_manager = trio .open_nursery ()
33+ nursery = await nursery_manager .__aenter__ ()
34+ nursery .start_soon (async_fn )
35+
36+ async def aclose (exc : Exception = None ) -> None :
37+ if exc is not None :
38+ await nursery_manager .__aexit__ (type (exc ), exc , exc .__traceback__ )
39+ else :
40+ await nursery_manager .__aexit__ (None , None , None )
41+
42+ return aclose
43+
44+ else :
45+ import asyncio
46+
47+ task = asyncio .create_task (async_fn ())
48+
49+ async def aclose (exc : Exception = None ) -> None :
50+ if not task .done ():
51+ task .cancel ()
52+
53+ return aclose
54+
55+
56+ def create_channel (
57+ capacity : int ,
58+ ) -> typing .Tuple [
59+ typing .Callable [[], typing .Awaitable [bytes ]],
60+ typing .Callable [[bytes ], typing .Awaitable [None ]],
61+ ]:
62+ if sniffio .current_async_library () == "trio" :
63+ import trio
64+
65+ send_channel , receive_channel = trio .open_memory_channel [bytes ](capacity )
66+ return receive_channel .receive , send_channel .send
67+
68+ else :
69+ import asyncio
70+
71+ queue : asyncio .Queue [bytes ] = asyncio .Queue (capacity )
72+ return queue .get , queue .put
73+
74+
75+ async def run_until_first_complete (* async_fns : typing .Callable ) -> None :
76+ if sniffio .current_async_library () == "trio" :
77+ import trio
78+
79+ async with trio .open_nursery () as nursery :
80+
81+ async def run (async_fn : typing .Callable ) -> None :
82+ await async_fn ()
83+ nursery .cancel_scope .cancel ()
84+
85+ for async_fn in async_fns :
86+ nursery .start_soon (run , async_fn )
87+
88+ else :
89+ import asyncio
90+
91+ coros = [async_fn () for async_fn in async_fns ]
92+ done , pending = await asyncio .wait (coros , return_when = asyncio .FIRST_COMPLETED )
93+ for task in pending :
94+ task .cancel ()
95+
96+
2897class ASGITransport (httpcore .AsyncHTTPTransport ):
2998 """
3099 A custom AsyncTransport that handles sending requests directly to an ASGI app.
@@ -95,18 +164,20 @@ async def request(
95164 }
96165 status_code = None
97166 response_headers = None
98- body_parts = []
167+ consume_response_body_chunk , produce_response_body_chunk = create_channel ( 1 )
99168 request_complete = False
100- response_started = False
169+ response_started = create_event ()
101170 response_complete = create_event ()
171+ app_crashed = create_event ()
172+ app_exception : typing .Optional [Exception ] = None
102173
103174 headers = [] if headers is None else headers
104175 stream = ByteStream (b"" ) if stream is None else stream
105176
106177 request_body_chunks = stream .__aiter__ ()
107178
108179 async def receive () -> dict :
109- nonlocal request_complete , response_complete
180+ nonlocal request_complete
110181
111182 if request_complete :
112183 await response_complete .wait ()
@@ -120,38 +191,76 @@ async def receive() -> dict:
120191 return {"type" : "http.request" , "body" : body , "more_body" : True }
121192
122193 async def send (message : dict ) -> None :
123- nonlocal status_code , response_headers , body_parts
124- nonlocal response_started , response_complete
194+ nonlocal status_code , response_headers
125195
126196 if message ["type" ] == "http.response.start" :
127- assert not response_started
197+ assert not response_started . is_set ()
128198
129199 status_code = message ["status" ]
130200 response_headers = message .get ("headers" , [])
131- response_started = True
201+ response_started . set ()
132202
133203 elif message ["type" ] == "http.response.body" :
134204 assert not response_complete .is_set ()
135205 body = message .get ("body" , b"" )
136206 more_body = message .get ("more_body" , False )
137207
138208 if body and method != b"HEAD" :
139- body_parts . append (body )
209+ await produce_response_body_chunk (body )
140210
141211 if not more_body :
142212 response_complete .set ()
143213
144- try :
145- await self .app (scope , receive , send )
146- except Exception :
147- if self .raise_app_exceptions or not response_complete :
148- raise
214+ async def run_app () -> None :
215+ nonlocal app_exception
216+ try :
217+ await self .app (scope , receive , send )
218+ except Exception as exc :
219+ app_exception = exc
220+ app_crashed .set ()
221+
222+ aclose_app = await create_background_task (run_app )
223+
224+ await run_until_first_complete (app_crashed .wait , response_started .wait )
149225
150- assert response_complete .is_set ()
226+ if app_crashed .is_set ():
227+ assert app_exception is not None
228+ await aclose_app (app_exception )
229+ if self .raise_app_exceptions or not response_started .is_set ():
230+ raise app_exception
231+
232+ assert response_started .is_set ()
151233 assert status_code is not None
152234 assert response_headers is not None
153235
154- stream = ByteStream (b"" .join (body_parts ))
236+ async def aiter_response_body_chunks () -> typing .AsyncIterator [bytes ]:
237+ chunk = b""
238+
239+ async def consume_chunk () -> None :
240+ nonlocal chunk
241+ chunk = await consume_response_body_chunk ()
242+
243+ while True :
244+ await run_until_first_complete (
245+ app_crashed .wait , consume_chunk , response_complete .wait
246+ )
247+
248+ if app_crashed .is_set ():
249+ assert app_exception is not None
250+ if self .raise_app_exceptions :
251+ raise app_exception
252+ else :
253+ break
254+
255+ yield chunk
256+
257+ if response_complete .is_set ():
258+ break
259+
260+ async def aclose () -> None :
261+ await aclose_app (app_exception )
262+
263+ stream = AsyncIteratorStream (aiter_response_body_chunks (), close_func = aclose )
155264
156265 return (b"HTTP/1.1" , status_code , b"" , response_headers , stream )
157266
0 commit comments