Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions httpcore/_async/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ssl import SSLContext
from typing import List, Optional, Tuple, Union

from .._backends.auto import AsyncLock, AutoBackend
from .._backends.auto import AsyncLock, AsyncSocketStream, AutoBackend
from .._types import URL, Headers, Origin, TimeoutDict
from .base import (
AsyncByteStream,
Expand All @@ -15,11 +15,16 @@

class AsyncHTTPConnection(AsyncHTTPTransport):
def __init__(
self, origin: Origin, http2: bool = False, ssl_context: SSLContext = None,
self,
origin: Origin,
http2: bool = False,
ssl_context: SSLContext = None,
socket: AsyncSocketStream = None,
):
self.origin = origin
self.http2 = http2
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -48,14 +53,11 @@ async def request(
timeout: TimeoutDict = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], AsyncByteStream]:
assert url[:3] == self.origin

async with self.request_lock:
if self.state == ConnectionState.PENDING:
try:
await self._connect(timeout)
except Exception:
self.connect_failed = True
raise
if not self.socket:
self.socket = await self._open_socket(timeout)
self._create_connection(self.socket)
elif self.state in (ConnectionState.READY, ConnectionState.IDLE):
pass
elif self.state == ConnectionState.ACTIVE and self.is_http2:
Expand All @@ -66,20 +68,30 @@ async def request(
assert self.connection is not None
return await self.connection.request(method, url, headers, stream, timeout)

async def _connect(self, timeout: TimeoutDict = None) -> None:
async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream:
scheme, hostname, port = self.origin
timeout = {} if timeout is None else timeout
ssl_context = self.ssl_context if scheme == b"https" else None
socket = await self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
)
try:
return await self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
)
except Exception:
self.connect_failed = True
raise

def _create_connection(self, socket: AsyncSocketStream) -> None:
http_version = socket.get_http_version()
if http_version == "HTTP/2":
self.is_http2 = True
self.connection = AsyncHTTP2Connection(socket=socket, backend=self.backend)
self.connection = AsyncHTTP2Connection(
socket=socket, backend=self.backend, ssl_context=self.ssl_context
Comment thread
florimondmanca marked this conversation as resolved.
)
else:
self.is_http11 = True
self.connection = AsyncHTTP11Connection(socket=socket)
self.connection = AsyncHTTP11Connection(
socket=socket, ssl_context=self.ssl_context
)

@property
def state(self) -> ConnectionState:
Expand All @@ -99,3 +111,4 @@ def mark_as_ready(self) -> None:
async def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
if self.connection is not None:
await self.connection.start_tls(hostname, timeout)
self.socket = self.connection.socket
Comment thread
florimondmanca marked this conversation as resolved.
2 changes: 1 addition & 1 deletion httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def _receive_response_data(
event = await self._receive_event(timeout)
if isinstance(event, h11.Data):
yield bytes(event.data)
elif isinstance(event, h11.EndOfMessage):
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
Comment thread
yeraydiazdiaz marked this conversation as resolved.
break

async def _receive_event(self, timeout: TimeoutDict) -> H11Event:
Expand Down
54 changes: 25 additions & 29 deletions httpcore/_async/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@
from .connection_pool import AsyncConnectionPool, ResponseByteStream


async def read_body(stream: AsyncByteStream) -> bytes:
try:
return b"".join([chunk async for chunk in stream])
finally:
await stream.aclose()


class AsyncHTTPProxy(AsyncConnectionPool):
"""
A connection pool for making HTTP requests via an HTTP proxy.
Expand All @@ -26,7 +19,8 @@ class AsyncHTTPProxy(AsyncConnectionPool):
* **proxy_headers** - `Optional[List[Tuple[bytes, bytes]]]` - A list of
proxy headers to include.
* **proxy_mode** - `str` - A proxy mode to operate in. May be "DEFAULT",
"FORWARD_ONLY", or "TUNNEL_ONLY".
Comment thread
yeraydiazdiaz marked this conversation as resolved.
"FORWARD_ONLY", or "TUNNEL_ONLY". "DEFAULT" is identical to "FORWARD_ONLY"
but is kept for backward compatibility purposes.
* **ssl_context** - `Optional[SSLContext]` - An SSL context to use for
verifying connections.
* **max_connections** - `Optional[int]` - The maximum number of concurrent
Expand All @@ -39,8 +33,8 @@ class AsyncHTTPProxy(AsyncConnectionPool):
def __init__(
self,
proxy_origin: Origin,
proxy_mode: str,
proxy_headers: Headers = None,
proxy_mode: str = "DEFAULT",
Comment thread
yeraydiazdiaz marked this conversation as resolved.
ssl_context: SSLContext = None,
max_connections: int = None,
max_keepalive: int = None,
Expand Down Expand Up @@ -140,47 +134,49 @@ async def _tunnel_request(
connection = await self._get_connection_from_pool(origin)

if connection is None:
connection = AsyncHTTPConnection(
origin=origin, http2=False, ssl_context=self._ssl_context,
# First, create a connection to the proxy server
proxy_connection = AsyncHTTPConnection(
origin=self.proxy_origin, http2=False, ssl_context=self._ssl_context,
)
async with self._thread_lock:
self._connections.setdefault(origin, set())
self._connections[origin].add(connection)

# Establish the connection by issuing a CONNECT request...
# Issue a CONNECT request...

# CONNECT www.example.org:80 HTTP/1.1
# [proxy-headers]
target = b"%b:%d" % (url[1], url[2])
connect_url = self.proxy_origin + (target,)
connect_headers = self.proxy_headers
proxy_response = await connection.request(
b"CONNECT", connect_url, headers=connect_headers, timeout=timeout
proxy_response = await proxy_connection.request(
b"CONNECT", connect_url, headers=self.proxy_headers, timeout=timeout
)
proxy_status_code = proxy_response[1]
proxy_reason_phrase = proxy_response[2]
proxy_stream = proxy_response[4]

# Ingest any request body.
await read_body(proxy_stream)
# Read the response data without closing the socket
async for _ in proxy_stream:
pass

# If the proxy responds with an error, then drop the connection
# from the pool, and raise an exception.
# See if the tunnel was successfully established.
if proxy_status_code < 200 or proxy_status_code > 299:
async with self._thread_lock:
self._connections[connection.origin].remove(connection)
if not self._connections[connection.origin]:
del self._connections[connection.origin]
msg = "%d %s" % (proxy_status_code, proxy_reason_phrase.decode("ascii"))
raise ProxyError(msg)

# Upgrade to TLS.
await connection.start_tls(target, timeout)
# The CONNECT request is successful, so we have now SWITCHED PROTOCOLS.
# This means the proxy connection is now unusable, and we must create
# a new one for regular requests, making sure to use the same socket to
# retain the tunnel.
connection = AsyncHTTPConnection(
origin=origin,
http2=False,
ssl_context=self._ssl_context,
socket=proxy_connection.socket,
)
await self._add_to_pool(connection)

# Once the connection has been established we can send requests on
# it as normal.
response = await connection.request(
method, url, headers=headers, stream=stream, timeout=timeout
method, url, headers=headers, stream=stream, timeout=timeout,
)
wrapped_stream = ResponseByteStream(
response[4], connection=connection, callback=self._response_closed
Expand Down
6 changes: 3 additions & 3 deletions httpcore/_backends/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ async def start_tls(

transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
transport,
protocol,
ssl_context,
Comment thread
florimondmanca marked this conversation as resolved.
server_hostname=hostname.decode("ascii"),
),
timeout=timeout.get("connect"),
Expand Down
41 changes: 27 additions & 14 deletions httpcore/_sync/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ssl import SSLContext
from typing import List, Optional, Tuple, Union

from .._backends.auto import SyncLock, SyncBackend
from .._backends.auto import SyncLock, SyncSocketStream, SyncBackend
from .._types import URL, Headers, Origin, TimeoutDict
from .base import (
SyncByteStream,
Expand All @@ -15,11 +15,16 @@

class SyncHTTPConnection(SyncHTTPTransport):
def __init__(
self, origin: Origin, http2: bool = False, ssl_context: SSLContext = None,
self,
origin: Origin,
http2: bool = False,
ssl_context: SSLContext = None,
socket: SyncSocketStream = None,
):
self.origin = origin
self.http2 = http2
self.ssl_context = SSLContext() if ssl_context is None else ssl_context
self.socket = socket

if self.http2:
self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])
Expand Down Expand Up @@ -48,14 +53,11 @@ def request(
timeout: TimeoutDict = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], SyncByteStream]:
assert url[:3] == self.origin

with self.request_lock:
if self.state == ConnectionState.PENDING:
try:
self._connect(timeout)
except Exception:
self.connect_failed = True
raise
if not self.socket:
self.socket = self._open_socket(timeout)
self._create_connection(self.socket)
elif self.state in (ConnectionState.READY, ConnectionState.IDLE):
pass
elif self.state == ConnectionState.ACTIVE and self.is_http2:
Expand All @@ -66,20 +68,30 @@ def request(
assert self.connection is not None
return self.connection.request(method, url, headers, stream, timeout)

def _connect(self, timeout: TimeoutDict = None) -> None:
def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
scheme, hostname, port = self.origin
timeout = {} if timeout is None else timeout
ssl_context = self.ssl_context if scheme == b"https" else None
socket = self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
)
try:
return self.backend.open_tcp_stream(
hostname, port, ssl_context, timeout
)
except Exception:
self.connect_failed = True
raise

def _create_connection(self, socket: SyncSocketStream) -> None:
http_version = socket.get_http_version()
if http_version == "HTTP/2":
self.is_http2 = True
self.connection = SyncHTTP2Connection(socket=socket, backend=self.backend)
self.connection = SyncHTTP2Connection(
socket=socket, backend=self.backend, ssl_context=self.ssl_context
)
else:
self.is_http11 = True
self.connection = SyncHTTP11Connection(socket=socket)
self.connection = SyncHTTP11Connection(
socket=socket, ssl_context=self.ssl_context
)

@property
def state(self) -> ConnectionState:
Expand All @@ -99,3 +111,4 @@ def mark_as_ready(self) -> None:
def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
if self.connection is not None:
self.connection.start_tls(hostname, timeout)
self.socket = self.connection.socket
2 changes: 1 addition & 1 deletion httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _receive_response_data(
event = self._receive_event(timeout)
if isinstance(event, h11.Data):
yield bytes(event.data)
elif isinstance(event, h11.EndOfMessage):
elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)):
break

def _receive_event(self, timeout: TimeoutDict) -> H11Event:
Expand Down
Loading