Skip to content

Commit c066d3e

Browse files
committed
Add uds arg to BaseClient and select tcp or uds in HttpConnection
1 parent 1e40664 commit c066d3e

7 files changed

Lines changed: 69 additions & 4 deletions

File tree

httpx/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
app: typing.Callable = None,
7575
backend: ConcurrencyBackend = None,
7676
trust_env: bool = True,
77+
uds: str = None,
7778
):
7879
if backend is None:
7980
backend = AsyncioBackend()
@@ -99,6 +100,7 @@ def __init__(
99100
pool_limits=pool_limits,
100101
backend=backend,
101102
trust_env=self.trust_env,
103+
uds=uds,
102104
)
103105
elif isinstance(dispatch, Dispatcher):
104106
async_dispatch = ThreadedDispatcher(dispatch, backend)
@@ -721,6 +723,7 @@ class Client(BaseClient):
721723
async requests.
722724
* **trust_env** - *(optional)* Enables or disables usage of environment
723725
variables for configuration.
726+
* **uds** - *(optional)* A path to a Unix domain socket to connect through
724727
"""
725728

726729
def check_concurrency_backend(self, backend: ConcurrencyBackend) -> None:

httpx/concurrency/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ async def open_uds_stream(
130130
hostname: typing.Optional[str],
131131
ssl_context: typing.Optional[ssl.SSLContext],
132132
timeout: TimeoutConfig,
133-
) -> BaseTCPStream:
133+
) -> BaseSocketStream:
134134
raise NotImplementedError() # pragma: no cover
135135

136136
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:

httpx/dispatch/connection.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ def __init__(
3838
http_versions: HTTPVersionTypes = None,
3939
backend: ConcurrencyBackend = None,
4040
release_func: typing.Optional[ReleaseCallback] = None,
41+
uds: typing.Optional[str] = None,
4142
):
4243
self.origin = Origin(origin) if isinstance(origin, str) else origin
4344
self.ssl = SSLConfig(cert=cert, verify=verify, trust_env=trust_env)
4445
self.timeout = TimeoutConfig(timeout)
4546
self.http_versions = HTTPVersionConfig(http_versions)
4647
self.backend = AsyncioBackend() if backend is None else backend
4748
self.release_func = release_func
49+
self.uds = uds
4850
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
4951
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
5052

@@ -84,8 +86,21 @@ async def connect(
8486
else:
8587
on_release = functools.partial(self.release_func, self)
8688

87-
logger.trace(f"start_connect host={host!r} port={port!r} timeout={timeout!r}")
88-
stream = await self.backend.open_tcp_stream(host, port, ssl_context, timeout)
89+
if self.uds is None:
90+
logger.trace(
91+
f"start_connect tcp host={host!r} port={port!r} timeout={timeout!r}"
92+
)
93+
stream = await self.backend.open_tcp_stream(
94+
host, port, ssl_context, timeout
95+
)
96+
else:
97+
logger.trace(
98+
f"start_connect uds path={self.uds!r} host={host!r} timeout={timeout!r}"
99+
)
100+
stream = await self.backend.open_uds_stream(
101+
self.uds, host, ssl_context, timeout
102+
)
103+
89104
http_version = stream.get_http_version()
90105
logger.trace(f"connected http_version={http_version!r}")
91106

httpx/dispatch/connection_pool.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(
8989
pool_limits: PoolLimits = DEFAULT_POOL_LIMITS,
9090
http_versions: HTTPVersionTypes = None,
9191
backend: ConcurrencyBackend = None,
92+
uds: typing.Optional[str] = None,
9293
):
9394
self.verify = verify
9495
self.cert = cert
@@ -97,6 +98,7 @@ def __init__(
9798
self.http_versions = http_versions
9899
self.is_closed = False
99100
self.trust_env = trust_env
101+
self.uds = uds
100102

101103
self.keepalive_connections = ConnectionStore()
102104
self.active_connections = ConnectionStore()
@@ -142,6 +144,7 @@ async def acquire_connection(self, origin: Origin) -> HTTPConnection:
142144
backend=self.backend,
143145
release_func=self.release_connection,
144146
trust_env=self.trust_env,
147+
uds=self.uds,
145148
)
146149
logger.trace(f"new_connection connection={connection!r}")
147150
else:

tests/client/test_async_client.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,17 @@ async def test_100_continue(server, backend):
146146

147147
assert response.status_code == 200
148148
assert response.content == data
149+
150+
151+
async def test_uds(uds_server, backend):
152+
url = uds_server.url
153+
uds = uds_server.config.uds
154+
assert uds is not None
155+
async with httpx.AsyncClient(backend=backend, uds=uds) as client:
156+
response = await client.get(url)
157+
assert response.status_code == 200
158+
assert response.text == "Hello, world!"
159+
assert response.http_version == "HTTP/1.1"
160+
assert response.headers
161+
assert repr(response) == "<Response [200 OK]>"
162+
assert response.elapsed > timedelta(seconds=0)

tests/client/test_client.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,25 @@ def test_base_url(server):
138138
assert response.url == base_url
139139

140140

141+
def test_uds(uds_server):
142+
url = uds_server.url
143+
uds = uds_server.config.uds
144+
assert uds is not None
145+
with httpx.Client(uds=uds) as http:
146+
response = http.get(url)
147+
assert response.status_code == 200
148+
assert response.url == url
149+
assert response.content == b"Hello, world!"
150+
assert response.text == "Hello, world!"
151+
assert response.http_version == "HTTP/1.1"
152+
assert response.encoding == "iso-8859-1"
153+
assert response.request.url == url
154+
assert response.headers
155+
assert response.is_redirect is False
156+
assert repr(response) == "<Response [200 OK]>"
157+
assert response.elapsed > timedelta(0)
158+
159+
141160
def test_merge_url():
142161
client = httpx.Client(base_url="https://www.paypal.com/")
143162
url = client.merge_url("http://www.paypal.com")

tests/conftest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,15 @@ def server():
288288
yield from serve_in_thread(server)
289289

290290

291+
@pytest.fixture(scope=SERVER_SCOPE)
292+
def uds_server():
293+
uds = "test_server.sock"
294+
config = Config(app=app, lifespan="off", loop="asyncio", uds=uds)
295+
server = TestServer(config=config)
296+
yield from serve_in_thread(server)
297+
os.remove(uds)
298+
299+
291300
@pytest.fixture(scope=SERVER_SCOPE)
292301
def https_server(cert_pem_file, cert_private_key_file):
293302
config = Config(
@@ -305,13 +314,15 @@ def https_server(cert_pem_file, cert_private_key_file):
305314

306315
@pytest.fixture(scope=SERVER_SCOPE)
307316
def https_uds_server(cert_pem_file, cert_private_key_file):
317+
uds = "https_test_server.sock"
308318
config = Config(
309319
app=app,
310320
lifespan="off",
311321
ssl_certfile=cert_pem_file,
312322
ssl_keyfile=cert_private_key_file,
313-
uds="https_test_server.sock",
323+
uds=uds,
314324
loop="asyncio",
315325
)
316326
server = TestServer(config=config)
317327
yield from serve_in_thread(server)
328+
os.remove(uds)

0 commit comments

Comments
 (0)