Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ async def write(
)
break
except asyncio.TimeoutError:
# We check our flag at the possible moment, in order to
# We check our flag at the first possible moment, in order to
# allow us to suppress write timeouts, if we've since
# switched over to read-timeout mode.
should_raise = flag is None or flag.raise_on_write_timeout
Expand Down
33 changes: 0 additions & 33 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,8 @@
import asyncio
import functools

import pytest

import httpx


def threadpool(func):
"""
Our sync tests should run in separate thread to the uvicorn server.
"""

@functools.wraps(func)
async def wrapped(*args, **kwargs):
nonlocal func

loop = asyncio.get_event_loop()
if kwargs:
func = functools.partial(func, **kwargs)
await loop.run_in_executor(None, func, *args)

return pytest.mark.asyncio(wrapped)


@threadpool
def test_get(server):
url = "http://127.0.0.1:8000/"
with httpx.Client() as http:
Expand All @@ -40,23 +19,20 @@ def test_get(server):
assert repr(response) == "<Response [200 OK]>"


@threadpool
def test_post(server):
with httpx.Client() as http:
response = http.post("http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.status_code == 200
assert response.reason_phrase == "OK"


@threadpool
def test_post_json(server):
with httpx.Client() as http:
response = http.post("http://127.0.0.1:8000/", json={"text": "Hello, world!"})
assert response.status_code == 200
assert response.reason_phrase == "OK"


@threadpool
def test_stream_response(server):
with httpx.Client() as http:
response = http.get("http://127.0.0.1:8000/", stream=True)
Expand All @@ -65,7 +41,6 @@ def test_stream_response(server):
assert content == b"Hello, world!"


@threadpool
def test_stream_iterator(server):
with httpx.Client() as http:
response = http.get("http://127.0.0.1:8000/", stream=True)
Expand All @@ -76,7 +51,6 @@ def test_stream_iterator(server):
assert body == b"Hello, world!"


@threadpool
def test_raw_iterator(server):
with httpx.Client() as http:
response = http.get("http://127.0.0.1:8000/", stream=True)
Expand All @@ -88,7 +62,6 @@ def test_raw_iterator(server):
response.close() # TODO: should Response be available as context managers?


@threadpool
def test_raise_for_status(server):
with httpx.Client() as client:
for status_code in (200, 400, 404, 500, 505):
Expand All @@ -103,47 +76,41 @@ def test_raise_for_status(server):
assert response.raise_for_status() is None


@threadpool
def test_options(server):
with httpx.Client() as http:
response = http.options("http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.reason_phrase == "OK"


@threadpool
def test_head(server):
with httpx.Client() as http:
response = http.head("http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.reason_phrase == "OK"


@threadpool
def test_put(server):
with httpx.Client() as http:
response = http.put("http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.status_code == 200
assert response.reason_phrase == "OK"


@threadpool
def test_patch(server):
with httpx.Client() as http:
response = http.patch("http://127.0.0.1:8000/", data=b"Hello, world!")
assert response.status_code == 200
assert response.reason_phrase == "OK"


@threadpool
def test_delete(server):
with httpx.Client() as http:
response = http.delete("http://127.0.0.1:8000/")
assert response.status_code == 200
assert response.reason_phrase == "OK"


@threadpool
def test_base_url(server):
base_url = "http://127.0.0.1:8000/"
with httpx.Client(base_url=base_url) as http:
Expand Down
119 changes: 82 additions & 37 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import threading
import time

import pytest
import trustme
Expand Down Expand Up @@ -101,76 +103,119 @@ def encrypted_private_key_pem(self):
)


@pytest.fixture
SERVER_SCOPE = "session"


@pytest.fixture(scope=SERVER_SCOPE)
def example_cert():
ca = CAWithPKEncryption()
ca.issue_cert("example.org")
return ca


@pytest.fixture
@pytest.fixture(scope=SERVER_SCOPE)
def cert_pem_file(example_cert):
with example_cert.cert_pem.tempfile() as tmp:
yield tmp


@pytest.fixture
@pytest.fixture(scope=SERVER_SCOPE)
def cert_private_key_file(example_cert):
with example_cert.private_key_pem.tempfile() as tmp:
yield tmp


@pytest.fixture
@pytest.fixture(scope=SERVER_SCOPE)
def cert_encrypted_private_key_file(example_cert):
with example_cert.encrypted_private_key_pem.tempfile() as tmp:
yield tmp


class TestServer(Server):
def install_signal_handlers(self) -> None:
# Disable the default installation of handlers for signals such as SIGTERM,
# because it can only be done in the main thread.
pass

async def serve(self, sockets=None):
self.restart_requested = asyncio.Event()

loop = asyncio.get_event_loop()
tasks = {
loop.create_task(super().serve(sockets=sockets)),
loop.create_task(self.watch_restarts()),
}

await asyncio.wait(tasks)

async def restart(self) -> None:
# Ensure we are in an asyncio environment.
assert asyncio.get_event_loop() is not None
# This may be called from a different thread than the one the server is
# running on. For this reason, we use an event to coordinate with the server
# instead of calling shutdown()/startup() directly.
self.restart_requested.set()
self.started = False
while not self.started:
await asyncio.sleep(0.5)

async def watch_restarts(self):
while True:
if self.should_exit:
return

try:
await asyncio.wait_for(self.restart_requested.wait(), timeout=0.1)
except asyncio.TimeoutError:
continue

self.restart_requested.clear()
await self.shutdown()
await self.startup()


@pytest.fixture
async def server():
config = Config(app=app, lifespan="off")
server = Server(config=config)
task = asyncio.ensure_future(server.serve())
def restart(backend):
"""Restart the running server from an async test function.

This fixture deals with possible differences between the environment of the
test function and that of the server.
"""

async def restart(server):
await backend.run_in_threadpool(AsyncioBackend().run, server.restart)

return restart


def serve_in_thread(server: Server):
thread = threading.Thread(target=server.run)
thread.start()
try:
while not server.started:
await asyncio.sleep(0.0001)
time.sleep(1e-3)
yield server
finally:
server.should_exit = True
await task
thread.join()


@pytest.fixture
async def https_server(cert_pem_file, cert_private_key_file):
@pytest.fixture(scope=SERVER_SCOPE)
def server():
config = Config(app=app, lifespan="off", loop="asyncio")
server = TestServer(config=config)
yield from serve_in_thread(server)


@pytest.fixture(scope=SERVER_SCOPE)
def https_server(cert_pem_file, cert_private_key_file):
config = Config(
app=app,
lifespan="off",
ssl_certfile=cert_pem_file,
ssl_keyfile=cert_private_key_file,
port=8001,
loop="asyncio",
)
server = Server(config=config)
task = asyncio.ensure_future(server.serve())
try:
while not server.started:
await asyncio.sleep(0.0001)
yield server
finally:
server.should_exit = True
await task


@pytest.fixture
def restart(backend):
async def asyncio_restart(server):
await server.shutdown()
await server.startup()

if isinstance(backend, AsyncioBackend):
return asyncio_restart

# The uvicorn server runs under asyncio, so we will need to figure out
# how to restart it under a different I/O library.
# This will most likely require running `asyncio_restart` in the threadpool,
# but that might not be sufficient.
raise NotImplementedError
server = TestServer(config=config)
yield from serve_in_thread(server)
4 changes: 2 additions & 2 deletions tests/dispatch/test_connection_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()

# shutdown the server to close the keep-alive connection
# Shutdown the server to close the keep-alive connection
await restart(server)

response = await http.request("GET", "http://127.0.0.1:8000/")
Expand All @@ -154,7 +154,7 @@ async def test_keepalive_http2_connection_closed_by_server_is_reestablished(
response = await http.request("GET", "http://127.0.0.1:8000/")
await response.read()

# shutdown the server to close the keep-alive connection
# Shutdown the server to close the keep-alive connection
await restart(server)

response = await http.request("GET", "http://127.0.0.1:8000/")
Expand Down
44 changes: 23 additions & 21 deletions tests/dispatch/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,42 @@


async def test_get(server, backend):
conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend)
response = await conn.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
async with HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) as conn:
response = await conn.request("GET", "http://127.0.0.1:8000/")
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"


async def test_post(server, backend):
conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend)
response = await conn.request(
"GET", "http://127.0.0.1:8000/", data=b"Hello, world!"
)
assert response.status_code == 200
async with HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) as conn:
response = await conn.request(
"GET", "http://127.0.0.1:8000/", data=b"Hello, world!"
)
assert response.status_code == 200


async def test_https_get_with_ssl_defaults(https_server, backend):
"""
An HTTPS request, with default SSL configuration set on the client.
"""
conn = HTTPConnection(
async with HTTPConnection(
origin="https://127.0.0.1:8001/", verify=False, backend=backend
)
response = await conn.request("GET", "https://127.0.0.1:8001/")
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
) as conn:
response = await conn.request("GET", "https://127.0.0.1:8001/")
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"


async def test_https_get_with_sll_overrides(https_server, backend):
"""
An HTTPS request, with SSL configuration set on the request.
"""
conn = HTTPConnection(origin="https://127.0.0.1:8001/", backend=backend)
response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
async with HTTPConnection(
origin="https://127.0.0.1:8001/", backend=backend
) as conn:
response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False)
await response.read()
assert response.status_code == 200
assert response.content == b"Hello, world!"
Loading