diff --git a/Cargo.lock b/Cargo.lock index 25457ee41d9c..944fe518213b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4492,6 +4492,7 @@ version = "12.0.0" dependencies = [ "anyhow", "heck", + "indexmap 2.0.0", "wit-parser", ] diff --git a/crates/component-macro/src/bindgen.rs b/crates/component-macro/src/bindgen.rs index 7415241fe191..3dfd7f846ea4 100644 --- a/crates/component-macro/src/bindgen.rs +++ b/crates/component-macro/src/bindgen.rs @@ -251,18 +251,9 @@ impl Parse for Opt { input.parse::()?; let contents; let _lbrace = braced!(contents in input); - let fields: Punctuated<(String, String, String), Token![,]> = + let fields: Punctuated<_, Token![,]> = contents.parse_terminated(trappable_error_field_parse, Token![,])?; - Ok(Opt::TrappableErrorType( - fields - .into_iter() - .map(|(wit_owner, wit_name, rust_name)| TrappableError { - wit_owner: Some(wit_owner), - wit_name, - rust_name, - }) - .collect(), - )) + Ok(Opt::TrappableErrorType(Vec::from_iter(fields.into_iter()))) } else if l.peek(kw::interfaces) { input.parse::()?; input.parse::()?; @@ -281,7 +272,7 @@ impl Parse for Opt { } } -fn trappable_error_field_parse(input: ParseStream<'_>) -> Result<(String, String, String)> { +fn trappable_error_field_parse(input: ParseStream<'_>) -> Result { // Accept a Rust identifier or a string literal. This is required // because not all wit identifiers are Rust identifiers, so we can // smuggle the invalid ones inside quotes. @@ -296,12 +287,16 @@ fn trappable_error_field_parse(input: ParseStream<'_>) -> Result<(String, String } } - let interface = ident_or_str(input)?; + let wit_package_path = input.parse::()?.value(); input.parse::()?; - let type_ = ident_or_str(input)?; + let wit_type_name = ident_or_str(input)?; input.parse::()?; - let rust_type = input.parse::()?.to_string(); - Ok((interface, type_, rust_type)) + let rust_type_name = input.parse::()?.to_string(); + Ok(TrappableError { + wit_package_path, + wit_type_name, + rust_type_name, + }) } fn with_field_parse(input: ParseStream<'_>) -> Result<(String, String)> { diff --git a/crates/wasi/src/preview2/command.rs b/crates/wasi/src/preview2/command.rs index fccfeaf52ed6..4b49491ea063 100644 --- a/crates/wasi/src/preview2/command.rs +++ b/crates/wasi/src/preview2/command.rs @@ -5,8 +5,8 @@ wasmtime::component::bindgen!({ tracing: true, async: true, trappable_error_type: { - "filesystem"::"error-code": Error, - "streams"::"stream-error": Error, + "wasi:filesystem/filesystem"::"error-code": Error, + "wasi:io/streams"::"stream-error": Error, }, with: { "wasi:filesystem/filesystem": crate::preview2::bindings::filesystem::filesystem, @@ -50,8 +50,8 @@ pub mod sync { tracing: true, async: false, trappable_error_type: { - "filesystem"::"error-code": Error, - "streams"::"stream-error": Error, + "wasi:filesystem/filesystem"::"error-code": Error, + "wasi:io/streams"::"stream-error": Error, }, with: { "wasi:filesystem/filesystem": crate::preview2::bindings::sync_io::filesystem::filesystem, diff --git a/crates/wasi/src/preview2/mod.rs b/crates/wasi/src/preview2/mod.rs index 7ab0a1b430b7..16390462e26e 100644 --- a/crates/wasi/src/preview2/mod.rs +++ b/crates/wasi/src/preview2/mod.rs @@ -53,8 +53,8 @@ pub mod bindings { ", tracing: true, trappable_error_type: { - "streams"::"stream-error": Error, - "filesystem"::"error-code": Error, + "wasi:io/streams"::"stream-error": Error, + "wasi:filesystem/filesystem"::"error-code": Error, }, with: { "wasi:clocks/wall-clock": crate::preview2::bindings::clocks::wall_clock, @@ -104,8 +104,8 @@ pub mod bindings { tracing: true, async: true, trappable_error_type: { - "streams"::"stream-error": Error, - "filesystem"::"error-code": Error, + "wasi:io/streams"::"stream-error": Error, + "wasi:filesystem/filesystem"::"error-code": Error, }, with: { "wasi:clocks/wall-clock": crate::preview2::bindings::clocks::wall_clock, @@ -133,8 +133,8 @@ pub mod bindings { ", tracing: true, trappable_error_type: { - "filesystem"::"error-code": Error, - "streams"::"stream-error": Error, + "wasi:filesystem/filesystem"::"error-code": Error, + "wasi:io/streams"::"stream-error": Error, }, with: { "wasi:clocks/wall-clock": crate::preview2::bindings::clocks::wall_clock, diff --git a/crates/wit-bindgen/Cargo.toml b/crates/wit-bindgen/Cargo.toml index fc50af8b42fa..cc4b27260bbe 100644 --- a/crates/wit-bindgen/Cargo.toml +++ b/crates/wit-bindgen/Cargo.toml @@ -12,3 +12,4 @@ edition.workspace = true anyhow = { workspace = true } heck = { workspace = true } wit-parser = { workspace = true } +indexmap = { workspace = true } diff --git a/crates/wit-bindgen/src/lib.rs b/crates/wit-bindgen/src/lib.rs index ff2ca10a437b..2d86124f1a8d 100644 --- a/crates/wit-bindgen/src/lib.rs +++ b/crates/wit-bindgen/src/lib.rs @@ -1,6 +1,8 @@ use crate::rust::{to_rust_ident, to_rust_upper_camel_case, RustGenerator, TypeMode}; use crate::types::{TypeInfo, Types}; +use anyhow::{anyhow, bail, Context}; use heck::*; +use indexmap::IndexMap; use std::collections::{BTreeMap, HashMap}; use std::fmt::Write as _; use std::io::{Read, Write}; @@ -110,19 +112,14 @@ pub struct Opts { #[derive(Debug, Clone)] pub struct TrappableError { - /// The name of the error in WIT that is being mapped. - pub wit_name: String, + /// The package and interface that define the error type being mapped. + pub wit_package_path: String, - /// The owner container of the error in WIT of the error that's being - /// mapped. - /// - /// This is, for example, the name of the WIT interface or the WIT world - /// which owns the type. If this is set to `None` then any error type with - /// `wit_name` is remapped to `rust_name`. - pub wit_owner: Option, + /// The name of the error type in WIT that is being mapped. + pub wit_type_name: String, /// The name, in Rust, of the error type to generate. - pub rust_name: String, + pub rust_type_name: String, } impl Opts { @@ -191,7 +188,7 @@ impl Wasmtime { let mut gen = InterfaceGenerator::new(self, resolve); match item { WorldItem::Function(func) => { - gen.generate_function_trait_sig(TypeOwner::None, func); + gen.generate_function_trait_sig(func); let sig = mem::take(&mut gen.src).into(); gen.generate_add_function_to_linker(TypeOwner::None, func, "linker"); let add_to_linker = gen.src.into(); @@ -204,7 +201,6 @@ impl Wasmtime { } gen.current_interface = Some((*id, name, false)); gen.types(*id); - gen.generate_trappable_error_types(TypeOwner::Interface(*id)); let key_name = resolve.name_world_key(name); gen.generate_add_to_linker(*id, &key_name); @@ -271,7 +267,6 @@ impl Wasmtime { gen.gen.name_interface(resolve, *id, name); gen.current_interface = Some((*id, name, true)); gen.types(*id); - gen.generate_trappable_error_types(TypeOwner::Interface(*id)); let iface = &resolve.interfaces[*id]; let iface_name = match name { WorldKey::Name(name) => name, @@ -643,26 +638,120 @@ impl Wasmtime { } } +fn resolve_type_in_package( + resolve: &Resolve, + package_path: &str, + type_name: &str, +) -> anyhow::Result { + // foo:bar/baz + + let (namespace, rest) = package_path + .split_once(':') + .ok_or_else(|| anyhow!("Invalid package path: missing package identifier"))?; + + let (package_name, iface_name) = rest + .split_once('/') + .ok_or_else(|| anyhow!("Invalid package path: missing namespace separator"))?; + + // TODO: we should handle version annotations + if package_name.contains('@') { + bail!("Invalid package path: version parsing is not currently handled"); + } + + let packages = Vec::from_iter( + resolve + .package_names + .iter() + .filter(|(pname, _)| pname.namespace == namespace && pname.name == package_name), + ); + + if packages.len() != 1 { + if packages.is_empty() { + bail!("No package named `{}`", namespace); + } else { + // Getting here is a bug, parsing version identifiers would disambiguate the intended + // package. + bail!( + "Multiple packages named `{}` found ({:?})", + namespace, + packages + ); + } + } + + let (_, &package_id) = packages[0]; + let package = &resolve.packages[package_id]; + + let (_, &iface_id) = package + .interfaces + .iter() + .find(|(name, _)| name.as_str() == iface_name) + .ok_or_else(|| { + anyhow!( + "Unknown interface `{}` in package `{}`", + iface_name, + package_path + ) + })?; + + let iface = &resolve.interfaces[iface_id]; + + let (_, &type_id) = iface + .types + .iter() + .find(|(n, _)| n.as_str() == type_name) + .ok_or_else(|| { + anyhow!( + "No type named `{}` in package `{}`", + package_name, + package_path + ) + })?; + + Ok(type_id) +} + struct InterfaceGenerator<'a> { src: Source, gen: &'a mut Wasmtime, resolve: &'a Resolve, current_interface: Option<(InterfaceId, &'a WorldKey, bool)>, + + /// A mapping of wit types to their rust type name equivalent. This is the pre-processed + /// version of `gen.opts.trappable_error_types`, where the types have been eagerly resolved. + trappable_errors: IndexMap, } impl<'a> InterfaceGenerator<'a> { fn new(gen: &'a mut Wasmtime, resolve: &'a Resolve) -> InterfaceGenerator<'a> { + let trappable_errors = gen + .opts + .trappable_error_type + .iter() + .map(|te| { + let id = resolve_type_in_package(resolve, &te.wit_package_path, &te.wit_type_name) + .context(format!("resolving {:?}", te))?; + Ok((id, te.rust_type_name.clone())) + }) + .collect::>>() + .unwrap(); + InterfaceGenerator { src: Source::default(), gen, resolve, current_interface: None, + trappable_errors, } } fn types(&mut self, id: InterfaceId) { for (name, id) in self.resolve.interfaces[id].types.iter() { self.define_type(name, *id); + + if let Some(rust_name) = self.trappable_errors.get(id) { + self.define_trappable_error_type(*id, rust_name.clone()) + } } } @@ -1140,9 +1229,8 @@ impl<'a> InterfaceGenerator<'a> { fn special_case_trappable_error( &self, - owner: TypeOwner, results: &Results, - ) -> Option<(&'a Result_, String)> { + ) -> Option<(&'a Result_, TypeId, String)> { // We fillin a special trappable error type in the case when a function has just one // result, which is itself a `result`, and the `e` is *not* a primitive // (i.e. defined in std) type, and matches the typename given by the user. @@ -1163,9 +1251,9 @@ impl<'a> InterfaceGenerator<'a> { _ => return None, }; - self.trappable_error_types(owner) - .find(|(wit_error_typeid, _)| error_typeid == *wit_error_typeid) - .map(|(_, rust_errortype)| (result, rust_errortype)) + let rust_type = self.trappable_errors.get(&error_typeid)?; + + Some((result, error_typeid, rust_type.clone())) } fn generate_add_to_linker(&mut self, id: InterfaceId, name: &str) { @@ -1179,7 +1267,7 @@ impl<'a> InterfaceGenerator<'a> { // this import. uwriteln!(self.src, "pub trait Host {{"); for (_, func) in iface.functions.iter() { - self.generate_function_trait_sig(owner, func); + self.generate_function_trait_sig(func); } uwriteln!(self.src, "}}"); @@ -1302,10 +1390,7 @@ impl<'a> InterfaceGenerator<'a> { ); } - if self - .special_case_trappable_error(owner, &func.results) - .is_some() - { + if self.special_case_trappable_error(&func.results).is_some() { uwrite!( self.src, "match r {{ @@ -1330,7 +1415,7 @@ impl<'a> InterfaceGenerator<'a> { } } - fn generate_function_trait_sig(&mut self, owner: TypeOwner, func: &Function) { + fn generate_function_trait_sig(&mut self, func: &Function) { self.rustdoc(&func.docs); if self.gen.opts.async_ { @@ -1349,7 +1434,9 @@ impl<'a> InterfaceGenerator<'a> { self.push_str(")"); self.push_str(" -> "); - if let Some((r, error_typename)) = self.special_case_trappable_error(owner, &func.results) { + if let Some((r, error_id, error_typename)) = + self.special_case_trappable_error(&func.results) + { // Functions which have a single result `result` get special // cased to use the host_wasmtime_rust::Error, making it possible // for them to trap or use `?` to propogate their errors @@ -1360,6 +1447,12 @@ impl<'a> InterfaceGenerator<'a> { self.push_str("()"); } self.push_str(","); + if let TypeOwner::Interface(id) = self.resolve.types[error_id].owner { + if let Some(path) = self.path_to_interface(id) { + self.push_str(&path); + self.push_str("::"); + } + } self.push_str(&error_typename); self.push_str(">"); } else { @@ -1498,68 +1591,31 @@ impl<'a> InterfaceGenerator<'a> { self.src.push_str("}\n"); } - fn trappable_error_types( - &self, - owner: TypeOwner, - ) -> impl Iterator + '_ { - let resolve = self.resolve; - self.gen - .opts - .trappable_error_type - .iter() - .filter_map(move |trappable| { - if let Some(name) = &trappable.wit_owner { - let owner_name = match owner { - TypeOwner::Interface(id) => resolve.interfaces[id].name.as_deref()?, - TypeOwner::World(id) => &resolve.worlds[id].name, - TypeOwner::None => return None, - }; - if owner_name != name { - return None; - } - } - let id = match owner { - TypeOwner::Interface(id) => { - *resolve.interfaces[id].types.get(&trappable.wit_name)? - } - // TODO: right now worlds can't have types defined within - // them but that's just a temporary limitation of - // `wit-parser`. Once that's filled in this should be - // replaced with a type-lookup in the world. - TypeOwner::World(_id) => unimplemented!(), - TypeOwner::None => return None, - }; - - Some((id, trappable.rust_name.clone())) - }) - } - - fn generate_trappable_error_types(&mut self, owner: TypeOwner) { - for (wit_type, trappable_type) in self.trappable_error_types(owner).collect::>() { - let info = self.info(wit_type); - if self.lifetime_for(&info, TypeMode::Owned).is_some() { - panic!("wit error for {trappable_type} is not 'static") - } - let abi_type = self.param_name(wit_type); + fn define_trappable_error_type(&mut self, id: TypeId, rust_name: String) { + let info = self.info(id); + if self.lifetime_for(&info, TypeMode::Owned).is_some() { + panic!("wit error for {rust_name} is not 'static") + } + let abi_type = self.param_name(id); - uwriteln!( - self.src, - " + uwriteln!( + self.src, + " #[derive(Debug)] - pub struct {trappable_type} {{ + pub struct {rust_name} {{ inner: anyhow::Error, }} - impl std::fmt::Display for {trappable_type} {{ + impl std::fmt::Display for {rust_name} {{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{ write!(f, \"{{}}\", self.inner) }} }} - impl std::error::Error for {trappable_type} {{ + impl std::error::Error for {rust_name} {{ fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {{ self.inner.source() }} }} - impl {trappable_type} {{ + impl {rust_name} {{ pub fn trap(inner: anyhow::Error) -> Self {{ Self {{ inner }} }} @@ -1573,14 +1629,13 @@ impl<'a> InterfaceGenerator<'a> { Self {{ inner: self.inner.context(s.into()) }} }} }} - impl From<{abi_type}> for {trappable_type} {{ - fn from(abi: {abi_type}) -> {trappable_type} {{ - {trappable_type} {{ inner: anyhow::Error::from(abi) }} + impl From<{abi_type}> for {rust_name} {{ + fn from(abi: {abi_type}) -> {rust_name} {{ + {rust_name} {{ inner: anyhow::Error::from(abi) }} }} }} " - ); - } + ); } fn rustdoc(&mut self, docs: &Docs) { diff --git a/tests/all/component_model/bindgen/results.rs b/tests/all/component_model/bindgen/results.rs index d0c49e40245e..dc11558661ef 100644 --- a/tests/all/component_model/bindgen/results.rs +++ b/tests/all/component_model/bindgen/results.rs @@ -238,7 +238,7 @@ mod enum_error { enum-error: func(a: float64) -> result } }", - trappable_error_type: { imports::e1: TrappableE1 } + trappable_error_type: { "inline:inline/imports"::e1: TrappableE1 } }); #[test] @@ -402,7 +402,7 @@ mod record_error { }", // Literal strings can be used for the interface and typename fields instead of // identifiers, because wit identifiers arent always Rust identifiers. - trappable_error_type: { "imports"::"e2": TrappableE2 } + trappable_error_type: { "inline:inline/imports"::"e2": TrappableE2 } }); #[test] @@ -556,7 +556,7 @@ mod variant_error { variant-error: func(a: float64) -> result } }", - trappable_error_type: { imports::e3: TrappableE3 } + trappable_error_type: { "inline:inline/imports"::e3: TrappableE3 } }); #[test]