Skip to content

Commit cfc5035

Browse files
authored
wasi-sockets: Return StreamError::Closed when the underlying socket is shut down. (#9055)
* Use proper I/O readiness to wait for the shutdown signal instead of an arbitrary timeout. * Move generic socket setup & teardown code out of the test case. * Make TcpRead/WriteStream private. * Return StreamError::Closed when the underlying socket is shut down. Added an additional LastWrite::Closed status to ensure that the OutputStream _keeps_ closed.
1 parent 7179592 commit cfc5035

2 files changed

Lines changed: 121 additions & 47 deletions

File tree

Lines changed: 85 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,96 @@
1-
use std::time::Duration;
2-
3-
use test_programs::wasi::io::streams::StreamError;
1+
use test_programs::wasi::io::streams::{InputStream, OutputStream, StreamError};
42
use test_programs::wasi::sockets::network::{IpAddress, IpAddressFamily, IpSocketAddress, Network};
53
use test_programs::wasi::sockets::tcp::{ShutdownType, TcpSocket};
64

75
/// InputStream::read should return `StreamError::Closed` after the connection has been shut down by the server.
8-
fn test_tcp_read_from_closed_input_stream(net: &Network, family: IpAddressFamily) {
9-
// Set up server & client sockets:
6+
fn test_tcp_input_stream_should_be_closed_by_remote_shutdown(
7+
net: &Network,
8+
family: IpAddressFamily,
9+
) {
10+
setup(net, family, |server, client| {
11+
// Shut down the connection from the server side:
12+
server.socket.shutdown(ShutdownType::Both).unwrap();
13+
drop(server);
14+
15+
// Wait for the shutdown signal to reach the client:
16+
client.input.subscribe().block();
17+
18+
// The input stream should immediately signal StreamError::Closed.
19+
// Notably, it should _not_ return an empty list (the wasi-io equivalent of EWOULDBLOCK)
20+
// See: https://github.com/bytecodealliance/wasmtime/pull/8968
21+
assert!(matches!(client.input.read(10), Err(StreamError::Closed)));
22+
23+
// Stream should still be closed, even when requesting 0 bytes:
24+
assert!(matches!(client.input.read(0), Err(StreamError::Closed)));
25+
});
26+
}
27+
28+
/// OutputStream should return `StreamError::Closed` after the connection has been locally shut down for sending.
29+
fn test_tcp_output_stream_should_be_closed_by_local_shutdown(
30+
net: &Network,
31+
family: IpAddressFamily,
32+
) {
33+
setup(net, family, |_server, client| {
34+
let message = b"Hi!";
35+
36+
// The stream should be writable:
37+
assert!(client.output.check_write().unwrap() as usize >= message.len());
38+
39+
// Perform the shutdown
40+
client.socket.shutdown(ShutdownType::Send).unwrap();
41+
42+
// Stream should be closed:
43+
assert!(matches!(
44+
client.output.write(message),
45+
Err(StreamError::Closed)
46+
));
47+
48+
// The stream should remain closed:
49+
assert!(matches!(
50+
client.output.check_write(),
51+
Err(StreamError::Closed)
52+
));
53+
assert!(matches!(client.output.flush(), Err(StreamError::Closed)));
54+
});
55+
}
56+
57+
fn main() {
58+
let net = Network::default();
59+
60+
test_tcp_input_stream_should_be_closed_by_remote_shutdown(&net, IpAddressFamily::Ipv4);
61+
test_tcp_input_stream_should_be_closed_by_remote_shutdown(&net, IpAddressFamily::Ipv6);
62+
63+
test_tcp_output_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv4);
64+
test_tcp_output_stream_should_be_closed_by_local_shutdown(&net, IpAddressFamily::Ipv6);
65+
}
66+
67+
struct Connection {
68+
input: InputStream,
69+
output: OutputStream,
70+
socket: TcpSocket,
71+
}
72+
73+
/// Set up a connected pair of sockets
74+
fn setup(net: &Network, family: IpAddressFamily, body: impl FnOnce(Connection, Connection)) {
1075
let bind_address = IpSocketAddress::new(IpAddress::new_loopback(family), 0);
1176
let listener = TcpSocket::new(family).unwrap();
1277
listener.blocking_bind(&net, bind_address).unwrap();
1378
listener.blocking_listen().unwrap();
1479
let bound_address = listener.local_address().unwrap();
15-
let client = TcpSocket::new(family).unwrap();
16-
let (connected_input, connected_output) = client.blocking_connect(net, bound_address).unwrap();
17-
let (accepted, accepted_input, accepted_output) = listener.blocking_accept().unwrap();
18-
19-
// Shut down the connection from the server side and give the kernel a bit
20-
// of time to propagate the shutdown signal from the server socket to the
21-
// client socket.
22-
accepted.shutdown(ShutdownType::Both).unwrap();
23-
drop(accepted_input);
24-
drop(accepted_output);
25-
drop(accepted);
26-
std::thread::sleep(Duration::from_millis(50));
27-
28-
// And now the actual test:
29-
30-
// The input stream should immediately signal StreamError::Closed.
31-
// Notably, it should _not_ return an empty list (the wasi-io equivalent of EWOULDBLOCK)
32-
// See: https://github.com/bytecodealliance/wasmtime/pull/8968
33-
assert!(matches!(connected_input.read(10), Err(StreamError::Closed))); // If this randomly fails, try tweaking the timeout above.
34-
35-
// Stream should still be closed, even when requesting 0 bytes:
36-
assert!(matches!(connected_input.read(0), Err(StreamError::Closed)));
37-
38-
drop(connected_input);
39-
drop(connected_output);
40-
drop(client);
41-
}
42-
43-
fn main() {
44-
let net = Network::default();
80+
let client_socket = TcpSocket::new(family).unwrap();
81+
let (client_input, client_output) = client_socket.blocking_connect(net, bound_address).unwrap();
82+
let (accepted_socket, accepted_input, accepted_output) = listener.blocking_accept().unwrap();
4583

46-
test_tcp_read_from_closed_input_stream(&net, IpAddressFamily::Ipv4);
47-
test_tcp_read_from_closed_input_stream(&net, IpAddressFamily::Ipv6);
84+
body(
85+
Connection {
86+
input: accepted_input,
87+
output: accepted_output,
88+
socket: accepted_socket,
89+
},
90+
Connection {
91+
input: client_input,
92+
output: client_output,
93+
socket: client_socket,
94+
},
95+
);
4896
}

crates/wasi/src/tcp.rs

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -665,13 +665,13 @@ impl Subscribe for TcpSocket {
665665
}
666666
}
667667

668-
pub(crate) struct TcpReadStream {
668+
struct TcpReadStream {
669669
stream: Arc<tokio::net::TcpStream>,
670670
closed: bool,
671671
}
672672

673673
impl TcpReadStream {
674-
pub(crate) fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
674+
fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
675675
Self {
676676
stream,
677677
closed: false,
@@ -725,7 +725,7 @@ impl Subscribe for TcpReadStream {
725725

726726
const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024;
727727

728-
pub(crate) struct TcpWriteStream {
728+
struct TcpWriteStream {
729729
stream: Arc<tokio::net::TcpStream>,
730730
last_write: LastWrite,
731731
}
@@ -734,16 +734,31 @@ enum LastWrite {
734734
Waiting(AbortOnDropJoinHandle<Result<()>>),
735735
Error(Error),
736736
Done,
737+
Closed,
737738
}
738739

739740
impl TcpWriteStream {
740-
pub(crate) fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
741+
fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
741742
Self {
742743
stream,
743744
last_write: LastWrite::Done,
744745
}
745746
}
746747

748+
fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
749+
stream.try_write(buf).map_err(|error| {
750+
match Errno::from_io_error(&error) {
751+
// Windows returns `WSAESHUTDOWN` when writing to a shut down socket.
752+
// We normalize this to EPIPE, because that is what the other platforms return.
753+
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-send#:~:text=WSAESHUTDOWN
754+
#[cfg(windows)]
755+
Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),
756+
757+
_ => error,
758+
}
759+
})
760+
}
761+
747762
/// Write `bytes` in a background task, remembering the task handle for use in a future call to
748763
/// `write_ready`
749764
fn background_write(&mut self, mut bytes: bytes::Bytes) {
@@ -758,7 +773,7 @@ impl TcpWriteStream {
758773
// to flush.
759774
while !bytes.is_empty() {
760775
stream.writable().await?;
761-
match stream.try_write(&bytes) {
776+
match Self::try_write_portable(&stream, &bytes) {
762777
Ok(n) => {
763778
let _ = bytes.split_to(n);
764779
}
@@ -776,14 +791,14 @@ impl HostOutputStream for TcpWriteStream {
776791
fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
777792
match self.last_write {
778793
LastWrite::Done => {}
779-
LastWrite::Waiting(_) | LastWrite::Error(_) => {
794+
LastWrite::Waiting(_) | LastWrite::Error(_) | LastWrite::Closed => {
780795
return Err(StreamError::Trap(anyhow::anyhow!(
781796
"unpermitted: must call check_write first"
782797
)));
783798
}
784799
}
785800
while !bytes.is_empty() {
786-
match self.stream.try_write(&bytes) {
801+
match Self::try_write_portable(&self.stream, &bytes) {
787802
Ok(n) => {
788803
let _ = bytes.split_to(n);
789804
}
@@ -796,6 +811,11 @@ impl HostOutputStream for TcpWriteStream {
796811
return Ok(());
797812
}
798813

814+
Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
815+
self.last_write = LastWrite::Closed;
816+
return Err(StreamError::Closed);
817+
}
818+
799819
Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
800820
}
801821
}
@@ -807,16 +827,22 @@ impl HostOutputStream for TcpWriteStream {
807827
// `flush` is a no-op here, as we're not managing any internal buffer. Additionally,
808828
// `write_ready` will join the background write task if it's active, so following `flush`
809829
// with `write_ready` will have the desired effect.
810-
Ok(())
830+
match self.last_write {
831+
LastWrite::Done | LastWrite::Waiting(_) | LastWrite::Error(_) => Ok(()),
832+
LastWrite::Closed => Err(StreamError::Closed),
833+
}
811834
}
812835

813836
fn check_write(&mut self) -> Result<usize, StreamError> {
814-
match mem::replace(&mut self.last_write, LastWrite::Done) {
837+
match mem::replace(&mut self.last_write, LastWrite::Closed) {
815838
LastWrite::Waiting(task) => {
816839
self.last_write = LastWrite::Waiting(task);
817840
return Ok(0);
818841
}
819-
LastWrite::Done => {}
842+
LastWrite::Done => {
843+
self.last_write = LastWrite::Done;
844+
}
845+
LastWrite::Closed => return Err(StreamError::Closed),
820846
LastWrite::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
821847
}
822848

0 commit comments

Comments
 (0)