diff --git a/Cargo.toml b/Cargo.toml index 167a46a657ee..c64eaf6907a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -206,7 +206,7 @@ fs-set-times = "0.20.0" system-interface = { version = "0.26.0", features = ["cap_std_impls"] } io-lifetimes = { version = "2.0.2", default-features = false } io-extras = "0.18.0" -rustix = "0.38.8" +rustix = "0.38.21" is-terminal = "0.4.0" # wit-bindgen: wit-bindgen = { version = "0.13.0", default-features = false } diff --git a/crates/test-programs/src/bin/preview2_udp_connect.rs b/crates/test-programs/src/bin/preview2_udp_connect.rs new file mode 100644 index 000000000000..5100525e82ab --- /dev/null +++ b/crates/test-programs/src/bin/preview2_udp_connect.rs @@ -0,0 +1,41 @@ +use test_programs::wasi::sockets::network::{ + ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, Network, +}; +use test_programs::wasi::sockets::udp::UdpSocket; + +fn test_udp_connect_disconnect_reconnect(net: &Network, family: IpAddressFamily) { + let unspecified_addr = IpSocketAddress::new(IpAddress::new_unspecified(family), 0); + let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321); + let remote2 = IpSocketAddress::new(IpAddress::new_loopback(family), 4320); + + let client = UdpSocket::new(family).unwrap(); + client.blocking_bind(&net, unspecified_addr).unwrap(); + + _ = client.stream(None).unwrap(); + assert_eq!(client.remote_address(), Err(ErrorCode::InvalidState)); + + _ = client.stream(None).unwrap(); + assert_eq!(client.remote_address(), Err(ErrorCode::InvalidState)); + + _ = client.stream(Some(remote1)).unwrap(); + assert_eq!(client.remote_address(), Ok(remote1)); + + _ = client.stream(Some(remote1)).unwrap(); + assert_eq!(client.remote_address(), Ok(remote1)); + + _ = client.stream(Some(remote2)).unwrap(); + assert_eq!(client.remote_address(), Ok(remote2)); + + _ = client.stream(None).unwrap(); + assert_eq!(client.remote_address(), Err(ErrorCode::InvalidState)); + + _ = client.stream(Some(remote1)).unwrap(); + assert_eq!(client.remote_address(), Ok(remote1)); +} + +fn main() { + let net = Network::default(); + + test_udp_connect_disconnect_reconnect(&net, IpAddressFamily::Ipv4); + test_udp_connect_disconnect_reconnect(&net, IpAddressFamily::Ipv6); +} diff --git a/crates/wasi/src/preview2/host/udp.rs b/crates/wasi/src/preview2/host/udp.rs index a97bd1bcd187..e2d1e239fa6d 100644 --- a/crates/wasi/src/preview2/host/udp.rs +++ b/crates/wasi/src/preview2/host/udp.rs @@ -97,13 +97,19 @@ impl udp::HostUdpSocket for T { _ => return Err(ErrorCode::InvalidState.into()), } + // We disconnect & (re)connect in two distinct steps for two reasons: + // - To leave our socket instance in a consistent state in case the + // connect fails. + // - When reconnecting to a different address, Linux sometimes fails + // if there isn't a disconnect in between. + + // Step #1: Disconnect if let UdpState::Connected = socket.udp_state { - // FIXME: Allow multiple (dis)connects. This needs to be supported by rustix first. - // rustix::net::disconnect(socket.udp_socket())?; - // socket.udp_state = UdpState::Bound; - return Err(ErrorCode::NotSupported.into()); + disconnect(socket.udp_socket())?; + socket.udp_state = UdpState::Bound; } + // Step #2: (Re)connect if let Some(connect_addr) = remote_address { rustix::net::connect(socket.udp_socket(), &connect_addr)?; socket.udp_state = UdpState::Connected; @@ -481,3 +487,12 @@ impl Subscribe for OutgoingDatagramStream { } } } + +fn disconnect(sockfd: Fd) -> rustix::io::Result<()> { + match rustix::net::connect_unspec(sockfd) { + // BSD platforms return an error even if the socket was disconnected successfully. + #[cfg(target_os = "macos")] + Err(rustix::io::Errno::INVAL | rustix::io::Errno::AFNOSUPPORT) => Ok(()), + r => r, + } +} diff --git a/crates/wasi/tests/all/async_.rs b/crates/wasi/tests/all/async_.rs index c1813476bf74..553e3846f220 100644 --- a/crates/wasi/tests/all/async_.rs +++ b/crates/wasi/tests/all/async_.rs @@ -319,6 +319,10 @@ async fn preview2_tcp_bind() { run(PREVIEW2_TCP_BIND_COMPONENT, false).await.unwrap() } #[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn preview2_udp_connect() { + run(PREVIEW2_UDP_CONNECT_COMPONENT, false).await.unwrap() +} +#[test_log::test(tokio::test(flavor = "multi_thread"))] async fn preview2_udp_sample_application() { run(PREVIEW2_UDP_SAMPLE_APPLICATION_COMPONENT, false) .await diff --git a/crates/wasi/tests/all/sync.rs b/crates/wasi/tests/all/sync.rs index 4347e01d904a..0e087ab8ced7 100644 --- a/crates/wasi/tests/all/sync.rs +++ b/crates/wasi/tests/all/sync.rs @@ -262,6 +262,10 @@ fn preview2_tcp_bind() { run(PREVIEW2_TCP_BIND_COMPONENT, false).unwrap() } #[test_log::test] +fn preview2_udp_connect() { + run(PREVIEW2_UDP_CONNECT_COMPONENT, false).unwrap() +} +#[test_log::test] fn preview2_udp_sample_application() { run(PREVIEW2_UDP_SAMPLE_APPLICATION_COMPONENT, false).unwrap() }