diff --git a/crates/test-programs/src/bin/preview2_tcp_streams.rs b/crates/test-programs/src/bin/preview2_tcp_streams.rs index a9a01b5ec692..688827be661e 100644 --- a/crates/test-programs/src/bin/preview2_tcp_streams.rs +++ b/crates/test-programs/src/bin/preview2_tcp_streams.rs @@ -1,48 +1,96 @@ -use std::time::Duration; - -use test_programs::wasi::io::streams::StreamError; +use test_programs::wasi::io::streams::{InputStream, OutputStream, StreamError}; use test_programs::wasi::sockets::network::{IpAddress, IpAddressFamily, IpSocketAddress, Network}; use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket}; /// InputStream::read should return `StreamError::Closed` after the connection has been shut down by the server. -fn test_tcp_read_from_closed_input_stream(net: &Network, family: IpAddressFamily) { - // Set up server & client sockets: +fn test_tcp_input_stream_should_be_closed_by_remote_shutdown( + net: &Network, + family: IpAddressFamily, +) { + setup(net, family, |server, client| { + // Shut down the connection from the server side: + server.socket.shutdown(ShutdownType::Both).unwrap(); + drop(server); + + // Wait for the shutdown signal to reach the client: + client.input.subscribe().block(); + + // The input stream should immediately signal StreamError::Closed. + // Notably, it should _not_ return an empty list (the wasi-io equivalent of EWOULDBLOCK) + // See: https://github.com/bytecodealliance/wasmtime/pull/8968 + assert!(matches!(client.input.read(10), Err(StreamError::Closed))); + + // Stream should still be closed, even when requesting 0 bytes: + assert!(matches!(client.input.read(0), Err(StreamError::Closed))); + }); +} + +/// OutputStream should return `StreamError::Closed` after the connection has been locally shut down for sending. +fn test_tcp_output_stream_should_be_closed_by_local_shutdown( + net: &Network, + family: IpAddressFamily, +) { + setup(net, family, |_server, client| { + let message = b"Hi!"; + + // The stream should be writable: + assert!(client.output.check_write().unwrap() as usize >= message.len()); + + // Perform the shutdown + client.socket.shutdown(ShutdownType::Send).unwrap(); + + // Stream should be closed: + assert!(matches!( + client.output.write(message), + Err(StreamError::Closed) + )); + + // The stream should remain closed: + assert!(matches!( + client.output.check_write(), + Err(StreamError::Closed) + )); + assert!(matches!(client.output.flush(), Err(StreamError::Closed))); + }); +} + +fn main() { + let net = Network::default(); + + test_tcp_input_stream_should_be_closed_by_remote_shutdown(&net, IpAddressFamily::Ipv4); + test_tcp_input_stream_should_be_closed_by_remote_shutdown(&net, IpAddressFamily::Ipv6); + + test_tcp_output_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv4); + test_tcp_output_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv6); +} + +struct Connection { + input: InputStream, + output: OutputStream, + socket: TcpSocket, +} + +/// Set up a connected pair of sockets +fn setup(net: &Network, family: IpAddressFamily, body: impl FnOnce(Connection, Connection)) { let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0); let listener = TcpSocket::new(family).unwrap(); listener.blocking_bind(&net, bind_address).unwrap(); listener.blocking_listen().unwrap(); let bound_address = listener.local_address().unwrap(); - let client = TcpSocket::new(family).unwrap(); - let (connected_input, connected_output) = client.blocking_connect(net, bound_address).unwrap(); - let (accepted, accepted_input, accepted_output) = listener.blocking_accept().unwrap(); - - // Shut down the connection from the server side and give the kernel a bit - // of time to propagate the shutdown signal from the server socket to the - // client socket. - accepted.shutdown(ShutdownType::Both).unwrap(); - drop(accepted_input); - drop(accepted_output); - drop(accepted); - std::thread::sleep(Duration::from_millis(50)); - - // And now the actual test: - - // The input stream should immediately signal StreamError::Closed. - // Notably, it should _not_ return an empty list (the wasi-io equivalent of EWOULDBLOCK) - // See: https://github.com/bytecodealliance/wasmtime/pull/8968 - assert!(matches!(connected_input.read(10), Err(StreamError::Closed))); // If this randomly fails, try tweaking the timeout above. - - // Stream should still be closed, even when requesting 0 bytes: - assert!(matches!(connected_input.read(0), Err(StreamError::Closed))); - - drop(connected_input); - drop(connected_output); - drop(client); -} - -fn main() { - let net = Network::default(); + let client_socket = TcpSocket::new(family).unwrap(); + let (client_input, client_output) = client_socket.blocking_connect(net, bound_address).unwrap(); + let (accepted_socket, accepted_input, accepted_output) = listener.blocking_accept().unwrap(); - test_tcp_read_from_closed_input_stream(&net, IpAddressFamily::Ipv4); - test_tcp_read_from_closed_input_stream(&net, IpAddressFamily::Ipv6); + body( + Connection { + input: accepted_input, + output: accepted_output, + socket: accepted_socket, + }, + Connection { + input: client_input, + output: client_output, + socket: client_socket, + }, + ); } diff --git a/crates/wasi/src/tcp.rs b/crates/wasi/src/tcp.rs index d755a316d944..4bc811f49f03 100644 --- a/crates/wasi/src/tcp.rs +++ b/crates/wasi/src/tcp.rs @@ -665,13 +665,13 @@ impl Subscribe for TcpSocket { } } -pub(crate) struct TcpReadStream { +struct TcpReadStream { stream: Arc, closed: bool, } impl TcpReadStream { - pub(crate) fn new(stream: Arc) -> Self { + fn new(stream: Arc) -> Self { Self { stream, closed: false, @@ -725,7 +725,7 @@ impl Subscribe for TcpReadStream { const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024; -pub(crate) struct TcpWriteStream { +struct TcpWriteStream { stream: Arc, last_write: LastWrite, } @@ -734,16 +734,31 @@ enum LastWrite { Waiting(AbortOnDropJoinHandle>), Error(Error), Done, + Closed, } impl TcpWriteStream { - pub(crate) fn new(stream: Arc) -> Self { + fn new(stream: Arc) -> Self { Self { stream, last_write: LastWrite::Done, } } + fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result { + stream.try_write(buf).map_err(|error| { + match Errno::from_io_error(&error) { + // Windows returns `WSAESHUTDOWN` when writing to a shut down socket. + // We normalize this to EPIPE, because that is what the other platforms return. + // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-send#:~:text=WSAESHUTDOWN + #[cfg(windows)] + Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error), + + _ => error, + } + }) + } + /// Write `bytes` in a background task, remembering the task handle for use in a future call to /// `write_ready` fn background_write(&mut self, mut bytes: bytes::Bytes) { @@ -758,7 +773,7 @@ impl TcpWriteStream { // to flush. while !bytes.is_empty() { stream.writable().await?; - match stream.try_write(&bytes) { + match Self::try_write_portable(&stream, &bytes) { Ok(n) => { let _ = bytes.split_to(n); } @@ -776,14 +791,14 @@ impl HostOutputStream for TcpWriteStream { fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> { match self.last_write { LastWrite::Done => {} - LastWrite::Waiting(_) | LastWrite::Error(_) => { + LastWrite::Waiting(_) | LastWrite::Error(_) | LastWrite::Closed => { return Err(StreamError::Trap(anyhow::anyhow!( "unpermitted: must call check_write first" ))); } } while !bytes.is_empty() { - match self.stream.try_write(&bytes) { + match Self::try_write_portable(&self.stream, &bytes) { Ok(n) => { let _ = bytes.split_to(n); } @@ -796,6 +811,11 @@ impl HostOutputStream for TcpWriteStream { return Ok(()); } + Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => { + self.last_write = LastWrite::Closed; + return Err(StreamError::Closed); + } + Err(e) => return Err(StreamError::LastOperationFailed(e.into())), } } @@ -807,16 +827,22 @@ impl HostOutputStream for TcpWriteStream { // `flush` is a no-op here, as we're not managing any internal buffer. Additionally, // `write_ready` will join the background write task if it's active, so following `flush` // with `write_ready` will have the desired effect. - Ok(()) + match self.last_write { + LastWrite::Done | LastWrite::Waiting(_) | LastWrite::Error(_) => Ok(()), + LastWrite::Closed => Err(StreamError::Closed), + } } fn check_write(&mut self) -> Result { - match mem::replace(&mut self.last_write, LastWrite::Done) { + match mem::replace(&mut self.last_write, LastWrite::Closed) { LastWrite::Waiting(task) => { self.last_write = LastWrite::Waiting(task); return Ok(0); } - LastWrite::Done => {} + LastWrite::Done => { + self.last_write = LastWrite::Done; + } + LastWrite::Closed => return Err(StreamError::Closed), LastWrite::Error(e) => return Err(StreamError::LastOperationFailed(e.into())), }