Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ rustix = { version = "0.35.6", features = ["mm", "param"] }
# depend again on wasmtime to activate its default features for tests
wasmtime = { path = "crates/wasmtime", version = "0.41.0", features = ['component-model'] }
env_logger = "0.9.0"
log = "0.4.8"
filecheck = "0.5.0"
tempfile = "3.1.0"
test-programs = { path = "crates/test-programs" }
Expand Down
272 changes: 193 additions & 79 deletions crates/runtime/src/traphandlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{VMContext, VMRuntimeLimits};
use anyhow::Error;
use std::any::Any;
use std::cell::{Cell, UnsafeCell};
use std::mem::{self, MaybeUninit};
use std::mem::MaybeUninit;
use std::ptr;
use std::sync::Once;
use wasmtime_environ::TrapCode;
Expand Down Expand Up @@ -182,19 +182,7 @@ where
{
let limits = (*caller).instance().runtime_limits();

let old_last_wasm_exit_fp = mem::replace(&mut *(**limits).last_wasm_exit_fp.get(), 0);
let old_last_wasm_exit_pc = mem::replace(&mut *(**limits).last_wasm_exit_pc.get(), 0);
let old_last_wasm_entry_sp = mem::replace(&mut *(**limits).last_wasm_entry_sp.get(), 0);

let result = CallThreadState::new(
signal_handler,
capture_backtrace,
old_last_wasm_exit_fp,
old_last_wasm_exit_pc,
old_last_wasm_entry_sp,
*limits,
)
.with(|cx| {
let result = CallThreadState::new(signal_handler, capture_backtrace, *limits).with(|cx| {
wasmtime_setjmp(
cx.jmp_buf.as_ptr(),
call_closure::<F>,
Expand All @@ -203,10 +191,6 @@ where
)
});

*(**limits).last_wasm_exit_fp.get() = old_last_wasm_exit_fp;
*(**limits).last_wasm_exit_pc.get() = old_last_wasm_exit_pc;
*(**limits).last_wasm_entry_sp.get() = old_last_wasm_entry_sp;

return match result {
Ok(x) => Ok(x),
Err((UnwindReason::Trap(reason), backtrace)) => Err(Box::new(Trap { reason, backtrace })),
Expand All @@ -221,55 +205,171 @@ where
}
}

/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState {
unwind: UnsafeCell<MaybeUninit<(UnwindReason, Option<Backtrace>)>>,
jmp_buf: Cell<*const u8>,
handling_trap: Cell<bool>,
signal_handler: Option<*const SignalHandler<'static>>,
prev: Cell<tls::Ptr>,
capture_backtrace: bool,
pub(crate) old_last_wasm_exit_fp: usize,
pub(crate) old_last_wasm_exit_pc: usize,
pub(crate) old_last_wasm_entry_sp: usize,
pub(crate) limits: *const VMRuntimeLimits,
// Module to hide visibility of the `CallThreadState::prev` field and force
// usage of its accessor methods.
mod call_thread_state {
use super::*;
use std::mem;

/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState {
pub(super) unwind: UnsafeCell<MaybeUninit<(UnwindReason, Option<Backtrace>)>>,
pub(super) jmp_buf: Cell<*const u8>,
pub(super) handling_trap: Cell<bool>,
pub(super) signal_handler: Option<*const SignalHandler<'static>>,
pub(super) capture_backtrace: bool,

pub(crate) limits: *const VMRuntimeLimits,

prev: Cell<tls::Ptr>,

// The values of `VMRuntimeLimits::last_wasm_{exit_{pc,fp},entry_sp}` for
// the *previous* `CallThreadState`. Our *current* last wasm PC/FP/SP are
// saved in `self.limits`. We save a copy of the old registers here because
// the `VMRuntimeLimits` typically doesn't change across nested calls into
// Wasm (i.e. they are typically calls back into the same store and
// `self.limits == self.prev.limits`) and we must to maintain the list of
// contiguous-Wasm-frames stack regions for backtracing purposes.
old_last_wasm_exit_fp: Cell<usize>,
old_last_wasm_exit_pc: Cell<usize>,
old_last_wasm_entry_sp: Cell<usize>,
}

impl CallThreadState {
#[inline]
pub(super) fn new(
signal_handler: Option<*const SignalHandler<'static>>,
capture_backtrace: bool,
limits: *const VMRuntimeLimits,
) -> CallThreadState {
CallThreadState {
unwind: UnsafeCell::new(MaybeUninit::uninit()),
jmp_buf: Cell::new(ptr::null()),
handling_trap: Cell::new(false),
signal_handler,
capture_backtrace,
limits,
prev: Cell::new(ptr::null()),
old_last_wasm_exit_fp: Cell::new(0),
old_last_wasm_exit_pc: Cell::new(0),
old_last_wasm_entry_sp: Cell::new(0),
}
}

/// Get the saved FP upon exit from Wasm for the previous `CallThreadState`.
pub fn old_last_wasm_exit_fp(&self) -> usize {
self.old_last_wasm_exit_fp.get()
}

/// Get the saved PC upon exit from Wasm for the previous `CallThreadState`.
pub fn old_last_wasm_exit_pc(&self) -> usize {
self.old_last_wasm_exit_pc.get()
}

/// Get the saved SP upon entry into Wasm for the previous `CallThreadState`.
pub fn old_last_wasm_entry_sp(&self) -> usize {
self.old_last_wasm_entry_sp.get()
}

/// Get the previous `CallThreadState`.
pub fn prev(&self) -> tls::Ptr {
self.prev.get()
}

/// Connect the link to the previous `CallThreadState`.
///
/// Synchronizes the last wasm FP, PC, and SP on `self` and the old
/// `self.prev` for the given new `prev`, and returns the old
/// `self.prev`.
pub unsafe fn set_prev(&self, prev: tls::Ptr) -> tls::Ptr {
let old_prev = self.prev.get();

// Restore the old `prev`'s saved registers in its
// `VMRuntimeLimits`. This is necessary for when we are async
// suspending the top `CallThreadState` and doing `set_prev(null)`
// on it, and so any stack walking we do subsequently will start at
// the old `prev` and look at its `VMRuntimeLimits` to get the
// initial saved registers.
if let Some(old_prev) = old_prev.as_ref() {
*(*old_prev.limits).last_wasm_exit_fp.get() = self.old_last_wasm_exit_fp();
*(*old_prev.limits).last_wasm_exit_pc.get() = self.old_last_wasm_exit_pc();
*(*old_prev.limits).last_wasm_entry_sp.get() = self.old_last_wasm_entry_sp();
}

self.prev.set(prev);

let mut old_last_wasm_exit_fp = 0;
let mut old_last_wasm_exit_pc = 0;
let mut old_last_wasm_entry_sp = 0;
if let Some(prev) = prev.as_ref() {
// We are entering a new `CallThreadState` or resuming a
// previously suspended one. This means we will push new Wasm
// frames that save the new Wasm FP/SP/PC registers into
// `VMRuntimeLimits`, we need to first save the old Wasm
// FP/SP/PC registers into this new `CallThreadState` to
// maintain our list of contiguous Wasm frame regions that we
// use when capturing stack traces.
//
// NB: the Wasm<--->host trampolines saved the Wasm FP/SP/PC
// registers in the active-at-that-time store's
// `VMRuntimeLimits`. For the most recent FP/PC/SP that is the
// `state.prev.limits` (since we haven't entered this
// `CallThreadState` yet). And that can be a different
// `VMRuntimeLimits` instance from the currently active
// `state.limits`, which will be used by the upcoming call into
// Wasm! Consider the case where we have multiple, nested calls
// across stores (with host code in between, by necessity, since
// only things in the same store can be linked directly
// together):
//
// | ... |
// | Host | |
// +-----------------+ | stack
// | Wasm in store A | | grows
// +-----------------+ | down
// | Host | |
// +-----------------+ |
// | Wasm in store B | V
// +-----------------+
//
// In this scenario `state.limits != state.prev.limits`,
// i.e. `B.limits != A.limits`! Therefore we must take care to
// read the old FP/SP/PC from `state.prev.limits`, rather than
// `state.limits`, and store those saved registers into the
// current `state`.
//
// See also the comment above the
// `CallThreadState::old_last_wasm_*` fields.
old_last_wasm_exit_fp =
mem::replace(&mut *(*prev.limits).last_wasm_exit_fp.get(), 0);
old_last_wasm_exit_pc =
mem::replace(&mut *(*prev.limits).last_wasm_exit_pc.get(), 0);
old_last_wasm_entry_sp =
mem::replace(&mut *(*prev.limits).last_wasm_entry_sp.get(), 0);
}

self.old_last_wasm_exit_fp.set(old_last_wasm_exit_fp);
self.old_last_wasm_exit_pc.set(old_last_wasm_exit_pc);
self.old_last_wasm_entry_sp.set(old_last_wasm_entry_sp);

old_prev
}
}
}
pub use call_thread_state::*;

enum UnwindReason {
Panic(Box<dyn Any + Send>),
Trap(TrapReason),
}

impl CallThreadState {
#[inline]
fn new(
signal_handler: Option<*const SignalHandler<'static>>,
capture_backtrace: bool,
old_last_wasm_exit_fp: usize,
old_last_wasm_exit_pc: usize,
old_last_wasm_entry_sp: usize,
limits: *const VMRuntimeLimits,
) -> CallThreadState {
CallThreadState {
unwind: UnsafeCell::new(MaybeUninit::uninit()),
jmp_buf: Cell::new(ptr::null()),
handling_trap: Cell::new(false),
signal_handler,
prev: Cell::new(ptr::null()),
capture_backtrace,
old_last_wasm_exit_fp,
old_last_wasm_exit_pc,
old_last_wasm_entry_sp,
limits,
}
}

fn with(
self,
mut self,
closure: impl FnOnce(&CallThreadState) -> i32,
) -> Result<(), (UnwindReason, Option<Backtrace>)> {
let ret = tls::set(&self, || closure(&self));
let ret = tls::set(&mut self, |me| closure(me));
if ret != 0 {
Ok(())
} else {
Expand Down Expand Up @@ -366,7 +466,7 @@ impl CallThreadState {
let mut state = Some(self);
std::iter::from_fn(move || {
let this = state?;
state = unsafe { this.prev.get().as_ref() };
state = unsafe { this.prev().as_ref() };
Some(this)
})
}
Expand Down Expand Up @@ -462,7 +562,9 @@ mod tls {

/// Opaque state used to help control TLS state across stack switches for
/// async support.
pub struct TlsRestore(raw::Ptr);
pub struct TlsRestore {
state: raw::Ptr,
}

impl TlsRestore {
/// Takes the TLS state that is currently configured and returns a
Expand All @@ -476,14 +578,16 @@ mod tls {
// removing ourselves from the call-stack, and in the process we
// null out our own previous field for safety in case it's
// accidentally used later.
let raw = raw::get();
if !raw.is_null() {
let prev = (*raw).prev.replace(ptr::null());
raw::replace(prev);
let state = raw::get();
if let Some(state) = state.as_ref() {
let prev_state = state.set_prev(ptr::null());
raw::replace(prev_state);
} else {
// Null case: we aren't in a wasm context, so theres no tls to
// save for restoration.
}
// Null case: we aren't in a wasm context, so theres no tls
// to save for restoration.
TlsRestore(raw)

TlsRestore { state }
}

/// Restores a previous tls state back into this thread's TLS.
Expand All @@ -493,40 +597,50 @@ mod tls {
pub unsafe fn replace(self) {
// Null case: we aren't in a wasm context, so theres no tls
// to restore.
if self.0.is_null() {
if self.state.is_null() {
return;
}

// We need to configure our previous TLS pointer to whatever is in
// TLS at this time, and then we set the current state to ourselves.
let prev = raw::get();
assert!((*self.0).prev.get().is_null());
(*self.0).prev.set(prev);
raw::replace(self.0);
assert!((*self.state).prev().is_null());
(*self.state).set_prev(prev);
raw::replace(self.state);
}
}

/// Configures thread local state such that for the duration of the
/// execution of `closure` any call to `with` will yield `ptr`, unless this
/// is recursively called again.
/// execution of `closure` any call to `with` will yield `state`, unless
/// this is recursively called again.
#[inline]
pub fn set<R>(state: &CallThreadState, closure: impl FnOnce() -> R) -> R {
struct Reset<'a>(&'a CallThreadState);
pub fn set<R>(state: &mut CallThreadState, closure: impl FnOnce(&CallThreadState) -> R) -> R {
struct Reset<'a> {
state: &'a CallThreadState,
}

impl Drop for Reset<'_> {
#[inline]
fn drop(&mut self) {
raw::replace(self.0.prev.replace(ptr::null()));
unsafe {
let prev = self.state.set_prev(ptr::null());
let old_state = raw::replace(prev);
debug_assert!(std::ptr::eq(old_state, self.state));
}
}
}

let prev = raw::replace(state);
state.prev.set(prev);
let _reset = Reset(state);
closure()

unsafe {
state.set_prev(prev);

let reset = Reset { state };
closure(reset.state)
}
}

/// Returns the last pointer configured with `set` above. Panics if `set`
/// has not been previously called.
/// Returns the last pointer configured with `set` above, if any.
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
let p = raw::get();
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
Expand Down
Loading