diff --git a/src/fd/eventfd.rs b/src/fd/eventfd.rs index e35f41f23f..7244d6622c 100644 --- a/src/fd/eventfd.rs +++ b/src/fd/eventfd.rs @@ -1,7 +1,7 @@ use alloc::boxed::Box; use alloc::collections::vec_deque::VecDeque; use core::future::{self, Future}; -use core::mem; +use core::mem::{self, MaybeUninit}; use core::task::{Poll, Waker, ready}; use async_lock::Mutex; @@ -45,7 +45,7 @@ impl EventFd { #[async_trait] impl ObjectInterface for EventFd { - async fn read(&self, buf: &mut [u8]) -> io::Result { + async fn read(&self, buf: &mut [MaybeUninit]) -> io::Result { let len = mem::size_of::(); if buf.len() < len { @@ -58,8 +58,7 @@ impl ObjectInterface for EventFd { let mut guard = ready!(pinned.as_mut().poll(cx)); if guard.counter > 0 { guard.counter -= 1; - let tmp = u64::to_ne_bytes(1); - buf[..len].copy_from_slice(&tmp); + buf[..len].write_copy_of_slice(&u64::to_ne_bytes(1)); if let Some(cx) = guard.write_queue.pop_front() { cx.wake_by_ref(); } @@ -74,7 +73,7 @@ impl ObjectInterface for EventFd { let tmp = guard.counter; if tmp > 0 { guard.counter = 0; - buf[..len].copy_from_slice(&u64::to_ne_bytes(tmp)); + buf[..len].write_copy_of_slice(&u64::to_ne_bytes(tmp)); if let Some(cx) = guard.read_queue.pop_front() { cx.wake_by_ref(); } diff --git a/src/fd/mod.rs b/src/fd/mod.rs index f6c9b9777d..a7367e7eff 100644 --- a/src/fd/mod.rs +++ b/src/fd/mod.rs @@ -2,6 +2,7 @@ use alloc::boxed::Box; use alloc::sync::Arc; use alloc::vec::Vec; use core::future::{self, Future}; +use core::mem::MaybeUninit; use core::task::Poll::{Pending, Ready}; use core::time::Duration; @@ -152,7 +153,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug { /// `async_read` attempts to read `len` bytes from the object references /// by the descriptor - async fn read(&self, _buf: &mut [u8]) -> io::Result { + async fn read(&self, _buf: &mut [MaybeUninit]) -> io::Result { Err(io::Error::ENOSYS) } @@ -230,7 +231,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug { /// receive a message from a socket #[cfg(any(feature = "tcp", feature = "udp", feature = "vsock"))] - async fn recvfrom(&self, _buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> { + async fn recvfrom(&self, _buffer: &mut [MaybeUninit]) -> io::Result<(usize, Endpoint)> { Err(io::Error::ENOSYS) } @@ -264,7 +265,7 @@ pub(crate) trait ObjectInterface: Sync + Send + core::fmt::Debug { } } -pub(crate) fn read(fd: FileDescriptor, buf: &mut [u8]) -> io::Result { +pub(crate) fn read(fd: FileDescriptor, buf: &mut [MaybeUninit]) -> io::Result { let obj = get_object(fd)?; if buf.is_empty() { diff --git a/src/fd/socket/tcp.rs b/src/fd/socket/tcp.rs index 46a8b4ed54..2c70f37bd6 100644 --- a/src/fd/socket/tcp.rs +++ b/src/fd/socket/tcp.rs @@ -2,6 +2,7 @@ use alloc::boxed::Box; use alloc::collections::BTreeSet; use alloc::sync::Arc; use core::future; +use core::mem::MaybeUninit; use core::sync::atomic::{AtomicU16, Ordering}; use core::task::Poll; @@ -171,7 +172,7 @@ impl Socket { .await } - async fn read(&self, buffer: &mut [u8]) -> io::Result { + async fn read(&self, buffer: &mut [MaybeUninit]) -> io::Result { future::poll_fn(|cx| { self.with(|socket| { let state = socket.state(); @@ -187,7 +188,7 @@ impl Socket { socket .recv(|data| { let len = core::cmp::min(buffer.len(), data.len()); - buffer[..len].copy_from_slice(&data[..len]); + buffer[..len].write_copy_of_slice(&data[..len]); (len, len) }) .map_err(|_| io::Error::EIO), @@ -468,7 +469,7 @@ impl ObjectInterface for async_lock::RwLock { self.read().await.poll(event).await } - async fn read(&self, buffer: &mut [u8]) -> io::Result { + async fn read(&self, buffer: &mut [MaybeUninit]) -> io::Result { self.read().await.read(buffer).await } diff --git a/src/fd/socket/udp.rs b/src/fd/socket/udp.rs index 86792c97bd..ad63be363c 100644 --- a/src/fd/socket/udp.rs +++ b/src/fd/socket/udp.rs @@ -1,5 +1,6 @@ use alloc::boxed::Box; use core::future; +use core::mem::MaybeUninit; use core::task::Poll; use async_trait::async_trait; @@ -141,24 +142,23 @@ impl Socket { } } - async fn recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> { + async fn recvfrom(&self, buffer: &mut [MaybeUninit]) -> io::Result<(usize, Endpoint)> { future::poll_fn(|cx| { self.with(|socket| { if socket.is_open() { if socket.can_recv() { - match socket.recv_slice(buffer) { - Ok((len, meta)) => match self.endpoint { - Some(ep) => { - if meta.endpoint == ep { - Poll::Ready(Ok((len, meta.endpoint))) - } else { - buffer[..len].iter_mut().for_each(|x| *x = 0); - socket.register_recv_waker(cx.waker()); - Poll::Pending - } + match socket.recv() { + // Drop the packet when the provided buffer cannot + // fit the payload. + Ok((data, meta)) if data.len() <= buffer.len() => { + if self.endpoint.is_none_or(|ep| meta.endpoint == ep) { + buffer[..data.len()].write_copy_of_slice(data); + Poll::Ready(Ok((data.len(), meta.endpoint))) + } else { + socket.register_recv_waker(cx.waker()); + Poll::Pending } - None => Poll::Ready(Ok((len, meta.endpoint))), - }, + } _ => Poll::Ready(Err(io::Error::EIO)), } } else { @@ -174,24 +174,23 @@ impl Socket { .map(|(len, endpoint)| (len, Endpoint::Ip(endpoint))) } - async fn read(&self, buffer: &mut [u8]) -> io::Result { + async fn read(&self, buffer: &mut [MaybeUninit]) -> io::Result { future::poll_fn(|cx| { self.with(|socket| { if socket.is_open() { if socket.can_recv() { - match socket.recv_slice(buffer) { - Ok((len, meta)) => match self.endpoint { - Some(ep) => { - if meta.endpoint == ep { - Poll::Ready(Ok(len)) - } else { - buffer[..len].iter_mut().for_each(|x| *x = 0); - socket.register_recv_waker(cx.waker()); - Poll::Pending - } + match socket.recv() { + // Drop the packet when the provided buffer cannot + // fit the payload. + Ok((data, meta)) if data.len() <= buffer.len() => { + if self.endpoint.is_none_or(|ep| meta.endpoint == ep) { + buffer[..data.len()].write_copy_of_slice(data); + Poll::Ready(Ok(data.len())) + } else { + socket.register_recv_waker(cx.waker()); + Poll::Pending } - None => Poll::Ready(Ok(len)), - }, + } _ => Poll::Ready(Err(io::Error::EIO)), } } else { @@ -257,11 +256,11 @@ impl ObjectInterface for async_lock::RwLock { self.read().await.sendto(buffer, endpoint).await } - async fn recvfrom(&self, buffer: &mut [u8]) -> io::Result<(usize, Endpoint)> { + async fn recvfrom(&self, buffer: &mut [MaybeUninit]) -> io::Result<(usize, Endpoint)> { self.read().await.recvfrom(buffer).await } - async fn read(&self, buffer: &mut [u8]) -> io::Result { + async fn read(&self, buffer: &mut [MaybeUninit]) -> io::Result { self.read().await.read(buffer).await } diff --git a/src/fd/socket/vsock.rs b/src/fd/socket/vsock.rs index 30f957182d..c10c350a59 100644 --- a/src/fd/socket/vsock.rs +++ b/src/fd/socket/vsock.rs @@ -2,6 +2,7 @@ use alloc::boxed::Box; use alloc::sync::Arc; use alloc::vec::Vec; use core::future; +use core::mem::MaybeUninit; use core::task::Poll; use async_trait::async_trait; @@ -312,7 +313,7 @@ impl Socket { } } - async fn read(&self, buffer: &mut [u8]) -> io::Result { + async fn read(&self, buffer: &mut [MaybeUninit]) -> io::Result { let port = self.port; future::poll_fn(|cx| { let mut guard = VSOCK_MAP.lock(); @@ -331,7 +332,7 @@ impl Socket { } } else { let tmp: Vec<_> = raw.buffer.drain(..len).collect(); - buffer[..len].copy_from_slice(tmp.as_slice()); + buffer[..len].write_copy_of_slice(tmp.as_slice()); Poll::Ready(Ok(len)) } @@ -343,7 +344,7 @@ impl Socket { Poll::Ready(Ok(0)) } else { let tmp: Vec<_> = raw.buffer.drain(..len).collect(); - buffer[..len].copy_from_slice(tmp.as_slice()); + buffer[..len].write_copy_of_slice(tmp.as_slice()); Poll::Ready(Ok(len)) } @@ -424,7 +425,7 @@ impl ObjectInterface for async_lock::RwLock { self.read().await.poll(event).await } - async fn read(&self, buffer: &mut [u8]) -> io::Result { + async fn read(&self, buffer: &mut [MaybeUninit]) -> io::Result { self.read().await.read(buffer).await } diff --git a/src/fd/stdio.rs b/src/fd/stdio.rs index 18b43b4446..c24647dec7 100644 --- a/src/fd/stdio.rs +++ b/src/fd/stdio.rs @@ -1,5 +1,6 @@ use alloc::boxed::Box; use core::future; +use core::mem::MaybeUninit; use core::task::Poll; use async_trait::async_trait; @@ -27,7 +28,7 @@ impl ObjectInterface for GenericStdin { Ok(event & available) } - async fn read(&self, buf: &mut [u8]) -> io::Result { + async fn read(&self, buf: &mut [MaybeUninit]) -> io::Result { future::poll_fn(|cx| { let mut read_bytes = 0; let mut guard = CONSOLE.lock(); @@ -36,7 +37,7 @@ impl ObjectInterface for GenericStdin { let c = unsafe { char::from_u32_unchecked(byte.into()) }; guard.write(c.as_bytes()); - buf[read_bytes] = byte; + buf[read_bytes].write(byte); read_bytes += 1; if read_bytes >= buf.len() { diff --git a/src/fs/fuse.rs b/src/fs/fuse.rs index 83d599d8f6..fd5d2e6ac2 100644 --- a/src/fs/fuse.rs +++ b/src/fs/fuse.rs @@ -4,6 +4,7 @@ use alloc::ffi::CString; use alloc::string::String; use alloc::sync::Arc; use alloc::vec::Vec; +use core::mem::MaybeUninit; use core::sync::atomic::{AtomicU64, Ordering}; use core::task::Poll; use core::{future, mem}; @@ -629,7 +630,7 @@ impl FuseFileHandleInner { .await } - fn read(&mut self, buf: &mut [u8]) -> io::Result { + fn read(&mut self, buf: &mut [MaybeUninit]) -> io::Result { let mut len = buf.len(); if len > MAX_READ_LEN { debug!("Reading longer than max_read_len: {}", len); @@ -651,7 +652,7 @@ impl FuseFileHandleInner { }; self.offset += len; - buf[..len].copy_from_slice(&rsp.payload.unwrap()[..len]); + buf[..len].write_copy_of_slice(&rsp.payload.unwrap()[..len]); Ok(len) } else { @@ -767,7 +768,7 @@ impl ObjectInterface for FuseFileHandle { self.0.lock().await.poll(event).await } - async fn read(&self, buf: &mut [u8]) -> io::Result { + async fn read(&self, buf: &mut [MaybeUninit]) -> io::Result { self.0.lock().await.read(buf) } diff --git a/src/fs/mem.rs b/src/fs/mem.rs index 74bc213a87..86e2314050 100644 --- a/src/fs/mem.rs +++ b/src/fs/mem.rs @@ -14,6 +14,7 @@ use alloc::collections::BTreeMap; use alloc::string::{String, ToString}; use alloc::sync::Arc; use alloc::vec::Vec; +use core::mem::MaybeUninit; use async_lock::{Mutex, RwLock}; use async_trait::async_trait; @@ -59,7 +60,7 @@ impl ObjectInterface for RomFileInterface { Ok(ret) } - async fn read(&self, buf: &mut [u8]) -> io::Result { + async fn read(&self, buf: &mut [MaybeUninit]) -> io::Result { { let microseconds = arch::kernel::systemtime::now_micros(); let t = timespec::from_usec(microseconds as i64); @@ -81,7 +82,7 @@ impl ObjectInterface for RomFileInterface { buf.len() }; - buf[0..len].clone_from_slice(&vec[pos..pos + len]); + buf[..len].write_copy_of_slice(&vec[pos..pos + len]); *pos_guard = pos + len; Ok(len) @@ -170,7 +171,7 @@ impl ObjectInterface for RamFileInterface { Ok(event & available) } - async fn read(&self, buf: &mut [u8]) -> io::Result { + async fn read(&self, buf: &mut [MaybeUninit]) -> io::Result { { let microseconds = arch::kernel::systemtime::now_micros(); let t = timespec::from_usec(microseconds as i64); @@ -192,7 +193,7 @@ impl ObjectInterface for RamFileInterface { buf.len() }; - buf[0..len].clone_from_slice(&guard.data[pos..pos + len]); + buf[..len].write_copy_of_slice(&guard.data[pos..pos + len]); *pos_guard = pos + len; Ok(len) @@ -214,7 +215,7 @@ impl ObjectInterface for RamFileInterface { guard.attr.st_mtim = t; guard.attr.st_ctim = t; - guard.data[pos..pos + buf.len()].clone_from_slice(buf); + guard.data[pos..pos + buf.len()].copy_from_slice(buf); *pos_guard = pos + buf.len(); Ok(buf.len()) diff --git a/src/fs/mod.rs b/src/fs/mod.rs index d5ca949f4d..153d72719a 100644 --- a/src/fs/mod.rs +++ b/src/fs/mod.rs @@ -496,6 +496,7 @@ impl File { impl crate::io::Read for File { fn read(&mut self, buf: &mut [u8]) -> io::Result { + let buf = unsafe { core::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) }; fd::read(self.fd, buf) } } diff --git a/src/fs/uhyve.rs b/src/fs/uhyve.rs index 1cb380cbd8..5d0b6fb88c 100644 --- a/src/fs/uhyve.rs +++ b/src/fs/uhyve.rs @@ -4,6 +4,7 @@ use alloc::ffi::CString; use alloc::string::{String, ToString}; use alloc::sync::Arc; use alloc::vec::Vec; +use core::mem::MaybeUninit; use async_lock::Mutex; use async_trait::async_trait; @@ -29,7 +30,7 @@ impl UhyveFileHandleInner { Self(fd) } - fn read(&mut self, buf: &mut [u8]) -> io::Result { + fn read(&mut self, buf: &mut [MaybeUninit]) -> io::Result { let mut read_params = ReadParams { fd: self.0, buf: GuestVirtAddr::new(buf.as_mut_ptr() as u64), @@ -94,7 +95,7 @@ impl UhyveFileHandle { #[async_trait] impl ObjectInterface for UhyveFileHandle { - async fn read(&self, buf: &mut [u8]) -> io::Result { + async fn read(&self, buf: &mut [MaybeUninit]) -> io::Result { self.0.lock().await.read(buf) } diff --git a/src/lib.rs b/src/lib.rs index abe0d171d9..6380670126 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ #![feature(map_try_insert)] #![feature(maybe_uninit_as_bytes)] #![feature(maybe_uninit_slice)] +#![feature(maybe_uninit_write_slice)] #![feature(naked_functions)] #![feature(never_type)] #![feature(slice_from_ptr_range)] diff --git a/src/syscalls/mod.rs b/src/syscalls/mod.rs index 93c8426797..ffdb0beb9a 100644 --- a/src/syscalls/mod.rs +++ b/src/syscalls/mod.rs @@ -387,7 +387,7 @@ pub extern "C" fn sys_close(fd: FileDescriptor) -> i32 { #[hermit_macro::system] #[unsafe(no_mangle)] pub unsafe extern "C" fn sys_read(fd: FileDescriptor, buf: *mut u8, len: usize) -> isize { - let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) }; + let slice = unsafe { core::slice::from_raw_parts_mut(buf.cast(), len) }; crate::fd::read(fd, slice).map_or_else( |e| -num::ToPrimitive::to_isize(&e).unwrap(), |v| v.try_into().unwrap(), @@ -420,7 +420,9 @@ pub unsafe extern "C" fn sys_readv(fd: i32, iov: *const iovec, iovcnt: usize) -> let iovec_buffers = unsafe { core::slice::from_raw_parts(iov, iovcnt) }; for iovec_buf in iovec_buffers { - let buf = unsafe { core::slice::from_raw_parts_mut(iovec_buf.iov_base, iovec_buf.iov_len) }; + let buf = unsafe { + core::slice::from_raw_parts_mut(iovec_buf.iov_base.cast(), iovec_buf.iov_len) + }; let len = crate::fd::read(fd, buf).map_or_else( |e| -num::ToPrimitive::to_isize(&e).unwrap(), diff --git a/src/syscalls/socket.rs b/src/syscalls/socket.rs index 0f50e31157..7ab8eca183 100644 --- a/src/syscalls/socket.rs +++ b/src/syscalls/socket.rs @@ -884,7 +884,7 @@ pub extern "C" fn sys_shutdown_socket(fd: i32, how: i32) -> i32 { #[unsafe(no_mangle)] pub unsafe extern "C" fn sys_recv(fd: i32, buf: *mut u8, len: usize, flags: i32) -> isize { if flags == 0 { - let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) }; + let slice = unsafe { core::slice::from_raw_parts_mut(buf.cast(), len) }; crate::fd::read(fd, slice).map_or_else( |e| -num::ToPrimitive::to_isize(&e).unwrap(), |v| v.try_into().unwrap(), @@ -962,7 +962,7 @@ pub unsafe extern "C" fn sys_recvfrom( addr: *mut sockaddr, addrlen: *mut socklen_t, ) -> isize { - let slice = unsafe { core::slice::from_raw_parts_mut(buf, len) }; + let slice = unsafe { core::slice::from_raw_parts_mut(buf.cast(), len) }; let obj = get_object(fd); obj.map_or_else( |e| -num::ToPrimitive::to_isize(&e).unwrap(),