Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 85 additions & 37 deletions crates/test-programs/src/bin/preview2_tcp_streams.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,96 @@
use std::time::Duration;

use test_programs::wasi::io::streams::StreamError;
use test_programs::wasi::io::streams::{InputStream, OutputStream, StreamError};
use test_programs::wasi::sockets::network::{IpAddress, IpAddressFamily, IpSocketAddress, Network};
use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket};

/// InputStream::read should return `StreamError::Closed` after the connection has been shut down by the server.
fn test_tcp_read_from_closed_input_stream(net: &Network, family: IpAddressFamily) {
// Set up server & client sockets:
fn test_tcp_input_stream_should_be_closed_by_remote_shutdown(
net: &Network,
family: IpAddressFamily,
) {
setup(net, family, |server, client| {
// Shut down the connection from the server side:
server.socket.shutdown(ShutdownType::Both).unwrap();
drop(server);

// Wait for the shutdown signal to reach the client:
client.input.subscribe().block();

// The input stream should immediately signal StreamError::Closed.
// Notably, it should _not_ return an empty list (the wasi-io equivalent of EWOULDBLOCK)
// See: https://github.com/bytecodealliance/wasmtime/pull/8968
assert!(matches!(client.input.read(10), Err(StreamError::Closed)));

// Stream should still be closed, even when requesting 0 bytes:
assert!(matches!(client.input.read(0), Err(StreamError::Closed)));
});
}

/// OutputStream should return `StreamError::Closed` after the connection has been locally shut down for sending.
fn test_tcp_output_stream_should_be_closed_by_local_shutdown(
net: &Network,
family: IpAddressFamily,
) {
setup(net, family, |_server, client| {
let message = b"Hi!";

// The stream should be writable:
assert!(client.output.check_write().unwrap() as usize >= message.len());

// Perform the shutdown
client.socket.shutdown(ShutdownType::Send).unwrap();

// Stream should be closed:
assert!(matches!(
client.output.write(message),
Err(StreamError::Closed)
));

// The stream should remain closed:
assert!(matches!(
client.output.check_write(),
Err(StreamError::Closed)
));
assert!(matches!(client.output.flush(), Err(StreamError::Closed)));
});
}

fn main() {
let net = Network::default();

test_tcp_input_stream_should_be_closed_by_remote_shutdown(&net, IpAddressFamily::Ipv4);
test_tcp_input_stream_should_be_closed_by_remote_shutdown(&net, IpAddressFamily::Ipv6);

test_tcp_output_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv4);
test_tcp_output_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv6);
}

struct Connection {
input: InputStream,
output: OutputStream,
socket: TcpSocket,
}

/// Set up a connected pair of sockets
fn setup(net: &Network, family: IpAddressFamily, body: impl FnOnce(Connection, Connection)) {
let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0);
let listener = TcpSocket::new(family).unwrap();
listener.blocking_bind(&net, bind_address).unwrap();
listener.blocking_listen().unwrap();
let bound_address = listener.local_address().unwrap();
let client = TcpSocket::new(family).unwrap();
let (connected_input, connected_output) = client.blocking_connect(net, bound_address).unwrap();
let (accepted, accepted_input, accepted_output) = listener.blocking_accept().unwrap();

// Shut down the connection from the server side and give the kernel a bit
// of time to propagate the shutdown signal from the server socket to the
// client socket.
accepted.shutdown(ShutdownType::Both).unwrap();
drop(accepted_input);
drop(accepted_output);
drop(accepted);
std::thread::sleep(Duration::from_millis(50));

// And now the actual test:

// The input stream should immediately signal StreamError::Closed.
// Notably, it should _not_ return an empty list (the wasi-io equivalent of EWOULDBLOCK)
// See: https://github.com/bytecodealliance/wasmtime/pull/8968
assert!(matches!(connected_input.read(10), Err(StreamError::Closed))); // If this randomly fails, try tweaking the timeout above.

// Stream should still be closed, even when requesting 0 bytes:
assert!(matches!(connected_input.read(0), Err(StreamError::Closed)));

drop(connected_input);
drop(connected_output);
drop(client);
}

fn main() {
let net = Network::default();
let client_socket = TcpSocket::new(family).unwrap();
let (client_input, client_output) = client_socket.blocking_connect(net, bound_address).unwrap();
let (accepted_socket, accepted_input, accepted_output) = listener.blocking_accept().unwrap();

test_tcp_read_from_closed_input_stream(&net, IpAddressFamily::Ipv4);
test_tcp_read_from_closed_input_stream(&net, IpAddressFamily::Ipv6);
body(
Connection {
input: accepted_input,
output: accepted_output,
socket: accepted_socket,
},
Connection {
input: client_input,
output: client_output,
socket: client_socket,
},
);
}
46 changes: 36 additions & 10 deletions crates/wasi/src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,13 +665,13 @@ impl Subscribe for TcpSocket {
}
}

pub(crate) struct TcpReadStream {
struct TcpReadStream {
stream: Arc<tokio::net::TcpStream>,
closed: bool,
}

impl TcpReadStream {
pub(crate) fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
Self {
stream,
closed: false,
Expand Down Expand Up @@ -725,7 +725,7 @@ impl Subscribe for TcpReadStream {

const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024;

pub(crate) struct TcpWriteStream {
struct TcpWriteStream {
stream: Arc<tokio::net::TcpStream>,
last_write: LastWrite,
}
Expand All @@ -734,16 +734,31 @@ enum LastWrite {
Waiting(AbortOnDropJoinHandle<Result<()>>),
Error(Error),
Done,
Closed,
}

impl TcpWriteStream {
pub(crate) fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
Self {
stream,
last_write: LastWrite::Done,
}
}

fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
stream.try_write(buf).map_err(|error| {
match Errno::from_io_error(&error) {
// Windows returns `WSAESHUTDOWN` when writing to a shut down socket.
// We normalize this to EPIPE, because that is what the other platforms return.
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-send#:~:text=WSAESHUTDOWN
#[cfg(windows)]
Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),

_ => error,
}
})
}

/// Write `bytes` in a background task, remembering the task handle for use in a future call to
/// `write_ready`
fn background_write(&mut self, mut bytes: bytes::Bytes) {
Expand All @@ -758,7 +773,7 @@ impl TcpWriteStream {
// to flush.
while !bytes.is_empty() {
stream.writable().await?;
match stream.try_write(&bytes) {
match Self::try_write_portable(&stream, &bytes) {
Ok(n) => {
let _ = bytes.split_to(n);
}
Expand All @@ -776,14 +791,14 @@ impl HostOutputStream for TcpWriteStream {
fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
match self.last_write {
LastWrite::Done => {}
LastWrite::Waiting(_) | LastWrite::Error(_) => {
LastWrite::Waiting(_) | LastWrite::Error(_) | LastWrite::Closed => {
return Err(StreamError::Trap(anyhow::anyhow!(
"unpermitted: must call check_write first"
)));
}
}
while !bytes.is_empty() {
match self.stream.try_write(&bytes) {
match Self::try_write_portable(&self.stream, &bytes) {
Ok(n) => {
let _ = bytes.split_to(n);
}
Expand All @@ -796,6 +811,11 @@ impl HostOutputStream for TcpWriteStream {
return Ok(());
}

Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
self.last_write = LastWrite::Closed;
return Err(StreamError::Closed);
}

Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
}
}
Expand All @@ -807,16 +827,22 @@ impl HostOutputStream for TcpWriteStream {
// `flush` is a no-op here, as we're not managing any internal buffer. Additionally,
// `write_ready` will join the background write task if it's active, so following `flush`
// with `write_ready` will have the desired effect.
Ok(())
match self.last_write {
LastWrite::Done | LastWrite::Waiting(_) | LastWrite::Error(_) => Ok(()),
LastWrite::Closed => Err(StreamError::Closed),
}
}

fn check_write(&mut self) -> Result<usize, StreamError> {
match mem::replace(&mut self.last_write, LastWrite::Done) {
match mem::replace(&mut self.last_write, LastWrite::Closed) {
LastWrite::Waiting(task) => {
self.last_write = LastWrite::Waiting(task);
return Ok(0);
}
LastWrite::Done => {}
LastWrite::Done => {
self.last_write = LastWrite::Done;
}
LastWrite::Closed => return Err(StreamError::Closed),
LastWrite::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
}

Expand Down