Skip to content

Commit 2a1c6ef

Browse files
Implement curio backend (#168)
* Implemented curio backend (#94) * Fixing PR remarks (#94) * Fixing PR remarks. Mention curio in the same context when the pair of asyncio and trio are mentioned (#94) * Fixing PR remarks. Made pytest.mark.curio completely isolated from conftest.py (for now it's easier to add new backend with custom marks) (#94) * PR review. Updated tests/marks/curio.py (removed unnecessary fixture) Co-authored-by: Florimond Manca <florimond.manca@gmail.com> * Added "curio" test mark programmatically (#94) * Fixed PR remarks (#94) Added timeout handling to Semaphore::acquire and tried to avoid private API usage in SocketStream::get_http_version, also changed is_connection_dropped behaviour * Fixed PR remarks (#94) Rewrote _wrap_ssl_client using ssl.SSLContext::wrap_socket * PR review (#94) Co-authored-by: Florimond Manca <florimond.manca@gmail.com> * PR review (#94) Co-authored-by: Florimond Manca <florimond.manca@gmail.com> Co-authored-by: Florimond Manca <florimond.manca@gmail.com>
1 parent ccea853 commit 2a1c6ef

File tree

8 files changed

+274
-3
lines changed

8 files changed

+274
-3
lines changed

httpcore/_backends/auto.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ def backend(self) -> AsyncBackend:
2424
from .trio import TrioBackend
2525

2626
self._backend_implementation = TrioBackend()
27+
elif backend == "curio":
28+
from .curio import CurioBackend
29+
30+
self._backend_implementation = CurioBackend()
2731
else: # pragma: nocover
2832
raise RuntimeError(f"Unsupported concurrency backend {backend!r}")
2933
return self._backend_implementation

httpcore/_backends/curio.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import select
2+
from ssl import SSLContext, SSLSocket
3+
from typing import Optional
4+
5+
import curio
6+
import curio.io
7+
8+
from .._exceptions import (
9+
ConnectError,
10+
ConnectTimeout,
11+
ReadError,
12+
ReadTimeout,
13+
WriteError,
14+
WriteTimeout,
15+
map_exceptions,
16+
)
17+
from .._types import TimeoutDict
18+
from .._utils import get_logger
19+
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream
20+
21+
logger = get_logger(__name__)
22+
23+
ONE_DAY_IN_SECONDS = float(60 * 60 * 24)
24+
25+
26+
def convert_timeout(value: Optional[float]) -> float:
27+
return value if value is not None else ONE_DAY_IN_SECONDS
28+
29+
30+
class Lock(AsyncLock):
31+
def __init__(self) -> None:
32+
self._lock = curio.Lock()
33+
34+
async def acquire(self) -> None:
35+
await self._lock.acquire()
36+
37+
async def release(self) -> None:
38+
await self._lock.release()
39+
40+
41+
class Semaphore(AsyncSemaphore):
42+
def __init__(self, max_value: int, exc_class: type) -> None:
43+
self.max_value = max_value
44+
self.exc_class = exc_class
45+
46+
@property
47+
def semaphore(self) -> curio.Semaphore:
48+
if not hasattr(self, "_semaphore"):
49+
self._semaphore = curio.Semaphore(value=self.max_value)
50+
return self._semaphore
51+
52+
async def acquire(self, timeout: float = None) -> None:
53+
timeout = convert_timeout(timeout)
54+
55+
try:
56+
return await curio.timeout_after(timeout, self.semaphore.acquire())
57+
except curio.TaskTimeout:
58+
raise self.exc_class()
59+
60+
async def release(self) -> None:
61+
await self.semaphore.release()
62+
63+
64+
class SocketStream(AsyncSocketStream):
65+
def __init__(self, socket: curio.io.Socket) -> None:
66+
self.read_lock = curio.Lock()
67+
self.write_lock = curio.Lock()
68+
self.socket = socket
69+
self.stream = socket.as_stream()
70+
71+
def get_http_version(self) -> str:
72+
if hasattr(self.socket, "_socket"):
73+
raw_socket = self.socket._socket
74+
75+
if isinstance(raw_socket, SSLSocket):
76+
ident = raw_socket.selected_alpn_protocol()
77+
return "HTTP/2" if ident == "h2" else "HTTP/1.1"
78+
79+
return "HTTP/1.1"
80+
81+
async def start_tls(
82+
self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict
83+
) -> "AsyncSocketStream":
84+
connect_timeout = convert_timeout(timeout.get("connect"))
85+
exc_map = {
86+
curio.TaskTimeout: ConnectTimeout,
87+
curio.CurioError: ConnectError,
88+
OSError: ConnectError,
89+
}
90+
91+
with map_exceptions(exc_map):
92+
wrapped_sock = curio.io.Socket(
93+
ssl_context.wrap_socket(
94+
self.socket._socket,
95+
do_handshake_on_connect=False,
96+
server_hostname=hostname.decode("ascii"),
97+
)
98+
)
99+
100+
await curio.timeout_after(
101+
connect_timeout,
102+
wrapped_sock.do_handshake(),
103+
)
104+
105+
return SocketStream(wrapped_sock)
106+
107+
async def read(self, n: int, timeout: TimeoutDict) -> bytes:
108+
read_timeout = convert_timeout(timeout.get("read"))
109+
exc_map = {
110+
curio.TaskTimeout: ReadTimeout,
111+
curio.CurioError: ReadError,
112+
OSError: ReadError,
113+
}
114+
115+
with map_exceptions(exc_map):
116+
async with self.read_lock:
117+
return await curio.timeout_after(read_timeout, self.stream.read(n))
118+
119+
async def write(self, data: bytes, timeout: TimeoutDict) -> None:
120+
write_timeout = convert_timeout(timeout.get("write"))
121+
exc_map = {
122+
curio.TaskTimeout: WriteTimeout,
123+
curio.CurioError: WriteError,
124+
OSError: WriteError,
125+
}
126+
127+
with map_exceptions(exc_map):
128+
async with self.write_lock:
129+
await curio.timeout_after(write_timeout, self.stream.write(data))
130+
131+
async def aclose(self) -> None:
132+
await self.stream.close()
133+
await self.socket.close()
134+
135+
def is_connection_dropped(self) -> bool:
136+
rready, _, _ = select.select([self.socket.fileno()], [], [], 0)
137+
138+
return bool(rready)
139+
140+
141+
class CurioBackend(AsyncBackend):
142+
async def open_tcp_stream(
143+
self,
144+
hostname: bytes,
145+
port: int,
146+
ssl_context: Optional[SSLContext],
147+
timeout: TimeoutDict,
148+
*,
149+
local_address: Optional[str],
150+
) -> AsyncSocketStream:
151+
connect_timeout = convert_timeout(timeout.get("connect"))
152+
exc_map = {
153+
curio.TaskTimeout: ConnectTimeout,
154+
curio.CurioError: ConnectError,
155+
OSError: ConnectError,
156+
}
157+
host = hostname.decode("ascii")
158+
kwargs = (
159+
{} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host}
160+
)
161+
162+
with map_exceptions(exc_map):
163+
sock: curio.io.Socket = await curio.timeout_after(
164+
connect_timeout,
165+
curio.open_connection(hostname, port, **kwargs),
166+
)
167+
168+
return SocketStream(sock)
169+
170+
async def open_uds_stream(
171+
self,
172+
path: str,
173+
hostname: bytes,
174+
ssl_context: Optional[SSLContext],
175+
timeout: TimeoutDict,
176+
) -> AsyncSocketStream:
177+
connect_timeout = convert_timeout(timeout.get("connect"))
178+
exc_map = {
179+
curio.TaskTimeout: ConnectTimeout,
180+
curio.CurioError: ConnectError,
181+
OSError: ConnectError,
182+
}
183+
host = hostname.decode("ascii")
184+
kwargs = (
185+
{} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host}
186+
)
187+
188+
with map_exceptions(exc_map):
189+
sock: curio.io.Socket = await curio.timeout_after(
190+
connect_timeout, curio.open_unix_connection(path, **kwargs)
191+
)
192+
193+
return SocketStream(sock)
194+
195+
def create_lock(self) -> AsyncLock:
196+
return Lock()
197+
198+
def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
199+
return Semaphore(max_value, exc_class)
200+
201+
async def time(self) -> float:
202+
return await curio.clock()

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Optionals
44
trio
55
trio-typing
6+
curio
67

78
# Docs
89
mkautodoc

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def get_packages(package):
6666
"Topic :: Internet :: WWW/HTTP",
6767
"Framework :: AsyncIO",
6868
"Framework :: Trio",
69+
"Framework :: Curio",
6970
"Programming Language :: Python :: 3",
7071
"Programming Language :: Python :: 3.6",
7172
"Programming Language :: Python :: 3.7",

tests/conftest.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,42 @@
1414

1515
from httpcore._types import URL
1616

17+
from .marks.curio import curio_pytest_pycollect_makeitem, curio_pytest_pyfunc_call
18+
1719
PROXY_HOST = "127.0.0.1"
1820
PROXY_PORT = 8080
1921

2022

23+
def pytest_configure(config):
24+
config.addinivalue_line(
25+
"markers",
26+
"curio: mark the test as a coroutine, it will be run using a Curio kernel.",
27+
)
28+
29+
30+
@pytest.mark.tryfirst
31+
def pytest_pycollect_makeitem(collector, name, obj):
32+
curio_pytest_pycollect_makeitem(collector, name, obj)
33+
34+
35+
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
36+
def pytest_pyfunc_call(pyfuncitem):
37+
yield from curio_pytest_pyfunc_call(pyfuncitem)
38+
39+
2140
@pytest.fixture(
2241
params=[
2342
pytest.param("asyncio", marks=pytest.mark.asyncio),
2443
pytest.param("trio", marks=pytest.mark.trio),
44+
pytest.param("curio", marks=pytest.mark.curio),
2545
]
2646
)
2747
def async_environment(request: typing.Any) -> str:
2848
"""
29-
Mark a test function to be run on both asyncio and trio.
49+
Mark a test function to be run on asyncio, trio and curio.
3050
31-
Equivalent to having a pair of tests, each respectively marked with
32-
'@pytest.mark.asyncio' and '@pytest.mark.trio'.
51+
Equivalent to having three tests, each respectively marked with
52+
'@pytest.mark.asyncio', '@pytest.mark.trio' and '@pytest.mark.curio'.
3353
3454
Intended usage:
3555

tests/marks/__init__.py

Whitespace-only changes.

tests/marks/curio.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import functools
2+
import inspect
3+
4+
import curio
5+
import curio.debug
6+
import curio.meta
7+
import curio.monitor
8+
import pytest
9+
10+
11+
def _is_coroutine(obj):
12+
"""Check to see if an object is really a coroutine."""
13+
return curio.meta.iscoroutinefunction(obj) or inspect.isgeneratorfunction(obj)
14+
15+
16+
@pytest.mark.tryfirst
17+
def curio_pytest_pycollect_makeitem(collector, name, obj):
18+
"""A pytest hook to collect coroutines in a test module."""
19+
if collector.funcnamefilter(name) and _is_coroutine(obj):
20+
item = pytest.Function.from_parent(collector, name=name)
21+
if "curio" in item.keywords:
22+
return list(collector._genfunctions(name, obj)) # pragma: nocover
23+
24+
25+
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
26+
def curio_pytest_pyfunc_call(pyfuncitem):
27+
"""Run curio marked test functions in a Curio kernel
28+
instead of a normal function call."""
29+
if pyfuncitem.get_closest_marker("curio"):
30+
pyfuncitem.obj = wrap_in_sync(pyfuncitem.obj)
31+
yield
32+
33+
34+
def wrap_in_sync(func):
35+
"""Return a sync wrapper around an async function executing it in a Kernel."""
36+
37+
@functools.wraps(func)
38+
def inner(**kwargs):
39+
coro = func(**kwargs)
40+
curio.Kernel().run(coro, shutdown=True)
41+
42+
return inner

unasync.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
('__aiter__', '__iter__'),
2121
('@pytest.mark.asyncio', ''),
2222
('@pytest.mark.trio', ''),
23+
('@pytest.mark.curio', ''),
2324
('@pytest.mark.usefixtures.*', ''),
2425
]
2526
COMPILED_SUBS = [

0 commit comments

Comments
 (0)