Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 38 additions & 38 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,44 @@ def __init__(
self.stream_writer = stream_writer
self.timeout = timeout

self._inner: typing.Optional[TCPStream] = None

async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> BaseTCPStream:
loop = asyncio.get_event_loop()
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)

stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = self.stream_writer.transport

loop_start_tls = loop.start_tls # type: ignore
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
timeout=timeout.connect_timeout,
)

stream_reader.set_transport(transport)
stream_writer = asyncio.StreamWriter(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)

ssl_stream = TCPStream(stream_reader, stream_writer, self.timeout)
# When we return a new TCPStream with new StreamReader/StreamWriter instances,
# we need to keep references to the old StreamReader/StreamWriter so that they
# are not garbage collected and closed while we're still using them.
ssl_stream._inner = self
return ssl_stream

def get_http_version(self) -> str:
ssl_object = self.stream_writer.get_extra_info("ssl_object")

Expand Down Expand Up @@ -201,44 +239,6 @@ async def open_tcp_stream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)

async def start_tls(
self,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:

loop = self.loop
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)

assert isinstance(stream, TCPStream)

stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = stream.stream_writer.transport

loop_start_tls = loop.start_tls # type: ignore
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
timeout=timeout.connect_timeout,
)

stream_reader.set_transport(transport)
stream.stream_reader = stream_reader
stream.stream_writer = asyncio.StreamWriter(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)
return stream

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
Expand Down
14 changes: 5 additions & 9 deletions httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class BaseTCPStream:
def get_http_version(self) -> str:
raise NotImplementedError() # pragma: no cover

async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> "BaseTCPStream":
raise NotImplementedError() # pragma: no cover

async def read(
self, n: int, timeout: TimeoutConfig = None, flag: typing.Any = None
) -> bytes:
Expand Down Expand Up @@ -119,15 +124,6 @@ async def open_tcp_stream(
) -> BaseTCPStream:
raise NotImplementedError() # pragma: no cover

async def start_tls(
self,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
raise NotImplementedError() # pragma: no cover

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
raise NotImplementedError() # pragma: no cover

Expand Down
44 changes: 20 additions & 24 deletions httpx/concurrency/trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@ def __init__(
self.write_buffer = b""
self.write_lock = trio.Lock()

async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> BaseTCPStream:
# Check that the write buffer is empty. We should never start a TLS stream
# while there is still pending data to write.
assert self.write_buffer == b""

connect_timeout = _or_inf(timeout.connect_timeout)
ssl_stream = trio.SSLStream(
self.stream, ssl_context=ssl_context, server_hostname=hostname
)

with trio.move_on_after(connect_timeout) as cancel_scope:
await ssl_stream.do_handshake()

if cancel_scope.cancelled_caught:
raise ConnectTimeout()

return TCPStream(ssl_stream, self.timeout)

def get_http_version(self) -> str:
if not isinstance(self.stream, trio.SSLStream):
return "HTTP/1.1"
Expand Down Expand Up @@ -171,30 +191,6 @@ async def open_tcp_stream(

return TCPStream(stream=stream, timeout=timeout)

async def start_tls(
self,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
assert isinstance(stream, TCPStream)

connect_timeout = _or_inf(timeout.connect_timeout)
ssl_stream = trio.SSLStream(
stream.stream, ssl_context=ssl_context, server_hostname=hostname
)

with trio.move_on_after(connect_timeout) as cancel_scope:
await ssl_stream.do_handshake()

if cancel_scope.cancelled_caught:
raise ConnectTimeout()

stream.stream = ssl_stream

return stream

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
Expand Down
7 changes: 2 additions & 5 deletions httpx/dispatch/proxy_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,8 @@ async def tunnel_start_tls(
f"proxy_url={self.proxy_url!r} "
f"origin={origin!r}"
)
stream = await self.backend.start_tls(
stream=stream,
hostname=origin.host,
ssl_context=ssl_context,
timeout=timeout,
stream = await stream.start_tls(
hostname=origin.host, ssl_context=ssl_context, timeout=timeout
)
http_version = stream.get_http_version()
logger.debug(
Expand Down
16 changes: 6 additions & 10 deletions tests/dispatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,6 @@ async def open_tcp_stream(
)
return self.stream

async def start_tls(
self,
stream: BaseTCPStream,
hostname: str,
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:
self.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
return self.stream

# Defer all other attributes and methods to the underlying backend.
def __getattr__(self, name: str) -> typing.Any:
return getattr(self.backend, name)
Expand All @@ -203,6 +193,12 @@ class MockRawSocketStream(BaseTCPStream):
def __init__(self, backend: MockRawSocketBackend):
self.backend = backend

async def start_tls(
self, hostname: str, ssl_context: ssl.SSLContext, timeout: TimeoutConfig
) -> BaseTCPStream:
self.backend.received_data.append(b"--- START_TLS(%s) ---" % hostname.encode())
return MockRawSocketStream(self.backend)

def get_http_version(self) -> str:
return "HTTP/1.1"

Expand Down
2 changes: 1 addition & 1 deletion tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def test_start_tls_on_socket_stream(https_server, backend, get_cipher):
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is None

stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
stream = await stream.start_tls(https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
assert get_cipher(stream) is not None

Expand Down