diff --git a/crates/test-programs/src/bin/api_proxy.rs b/crates/test-programs/src/bin/api_proxy.rs index ae75c1d3bf49..1bcdc783dd41 100644 --- a/crates/test-programs/src/bin/api_proxy.rs +++ b/crates/test-programs/src/bin/api_proxy.rs @@ -16,7 +16,7 @@ struct T; impl bindings::exports::wasi::http::incoming_handler::Guest for T { fn handle(_request: IncomingRequest, outparam: ResponseOutparam) { - let hdrs = bindings::wasi::http::types::Headers::new(&[]); + let hdrs = bindings::wasi::http::types::Headers::new(); let resp = bindings::wasi::http::types::OutgoingResponse::new(200, hdrs); let body = resp.body().expect("outgoing response"); diff --git a/crates/test-programs/src/bin/api_proxy_streaming.rs b/crates/test-programs/src/bin/api_proxy_streaming.rs index e63223058c55..c2cde5c155e0 100644 --- a/crates/test-programs/src/bin/api_proxy_streaming.rs +++ b/crates/test-programs/src/bin/api_proxy_streaming.rs @@ -51,7 +51,7 @@ async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam let response = OutgoingResponse::new( 200, - Fields::new(&[("content-type".to_string(), b"text/plain".to_vec())]), + Fields::from_list(&[("content-type".to_string(), b"text/plain".to_vec())]).unwrap(), ); let mut body = @@ -75,12 +75,13 @@ async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam (Method::Post, Some("/echo")) => { let response = OutgoingResponse::new( 200, - Fields::new( + Fields::from_list( &headers .into_iter() .filter_map(|(k, v)| (k == "content-type").then_some((k, v))) .collect::>(), - ), + ) + .unwrap(), ); let mut body = @@ -108,7 +109,7 @@ async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam } _ => { - let response = OutgoingResponse::new(405, Fields::new(&[])); + let response = OutgoingResponse::new(405, Fields::new()); let body = response.body().expect("response should be writable"); @@ -137,7 +138,7 @@ async fn hash(url: &Url) -> Result { String::new() } )), - Fields::new(&[]), + Fields::new(), ); let response = executor::outgoing_request_send(request).await?; diff --git a/crates/test-programs/src/bin/http_outbound_request_invalid_header.rs b/crates/test-programs/src/bin/http_outbound_request_invalid_header.rs new file mode 100644 index 000000000000..09f1732983ba --- /dev/null +++ b/crates/test-programs/src/bin/http_outbound_request_invalid_header.rs @@ -0,0 +1,60 @@ +use test_programs::wasi::http::types::{HeaderError, Headers}; + +fn main() { + let hdrs = Headers::new(); + assert!(matches!( + hdrs.append(&"malformed header name".to_owned(), &b"ok value".to_vec()), + Err(HeaderError::InvalidSyntax) + )); + + assert!(matches!( + hdrs.append(&"ok-header-name".to_owned(), &b"ok value".to_vec()), + Ok(()) + )); + + assert!(matches!( + hdrs.append(&"ok-header-name".to_owned(), &b"bad\nvalue".to_vec()), + Err(HeaderError::InvalidSyntax) + )); + + assert!(matches!( + hdrs.append(&"Connection".to_owned(), &b"keep-alive".to_vec()), + Err(HeaderError::Forbidden) + )); + + assert!(matches!( + hdrs.append(&"Keep-Alive".to_owned(), &b"stuff".to_vec()), + Err(HeaderError::Forbidden) + )); + + assert!(matches!( + hdrs.append( + &"custom-forbidden-header".to_owned(), + &b"keep-alive".to_vec() + ), + Err(HeaderError::Forbidden) + )); + + assert!(matches!( + hdrs.append( + &"Custom-Forbidden-Header".to_owned(), + &b"keep-alive".to_vec() + ), + Err(HeaderError::Forbidden) + )); + + assert!(matches!( + Headers::from_list(&[("bad header".to_owned(), b"value".to_vec())]), + Err(HeaderError::InvalidSyntax) + )); + + assert!(matches!( + Headers::from_list(&[("custom-forbidden-header".to_owned(), b"value".to_vec())]), + Err(HeaderError::Forbidden) + )); + + assert!(matches!( + Headers::from_list(&[("ok-header-name".to_owned(), b"bad\nvalue".to_vec())]), + Err(HeaderError::InvalidSyntax) + )); +} diff --git a/crates/test-programs/src/bin/http_outbound_request_response_build.rs b/crates/test-programs/src/bin/http_outbound_request_response_build.rs index d9a29822d33c..8a50ff6fa693 100644 --- a/crates/test-programs/src/bin/http_outbound_request_response_build.rs +++ b/crates/test-programs/src/bin/http_outbound_request_response_build.rs @@ -3,10 +3,11 @@ use test_programs::wasi::http::types as http_types; fn main() { println!("Called _start"); { - let headers = http_types::Headers::new(&[( + let headers = http_types::Headers::from_list(&[( "Content-Type".to_string(), "application/json".to_string().into_bytes(), - )]); + )]) + .unwrap(); let request = http_types::OutgoingRequest::new( &http_types::Method::Get, None, @@ -21,10 +22,11 @@ fn main() { .unwrap(); } { - let headers = http_types::Headers::new(&[( + let headers = http_types::Headers::from_list(&[( "Content-Type".to_string(), "application/text".to_string().into_bytes(), - )]); + )]) + .unwrap(); let response = http_types::OutgoingResponse::new(200, headers); let outgoing_body = response.body().unwrap(); let response_body = outgoing_body.write().unwrap(); diff --git a/crates/test-programs/src/http.rs b/crates/test-programs/src/http.rs index f065b45ee701..4637ec713570 100644 --- a/crates/test-programs/src/http.rs +++ b/crates/test-programs/src/http.rs @@ -42,7 +42,7 @@ pub fn request( fn header_val(v: &str) -> Vec { v.to_string().into_bytes() } - let headers = http_types::Headers::new( + let headers = http_types::Headers::from_list( &[ &[ ("User-agent".to_string(), header_val("WASI-HTTP/0.0.1")), @@ -51,7 +51,7 @@ pub fn request( additional_headers.unwrap_or(&[]), ] .concat(), - ); + )?; let request = http_types::OutgoingRequest::new( &method, diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index 50119b98977b..0facf6245dd2 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -7,6 +7,7 @@ use crate::{ }; use anyhow::Context; use http_body_util::BodyExt; +use hyper::header::HeaderName; use std::any::Any; use std::time::Duration; use tokio::net::TcpStream; @@ -65,6 +66,10 @@ pub trait WasiHttpView: Send { { default_send_request(self, request) } + + fn is_forbidden_header(&mut self, _name: &HeaderName) -> bool { + false + } } pub fn default_send_request( diff --git a/crates/wasi-http/src/types_impl.rs b/crates/wasi-http/src/types_impl.rs index d1c8eabb6df3..2b21d7329b9e 100644 --- a/crates/wasi-http/src/types_impl.rs +++ b/crates/wasi-http/src/types_impl.rs @@ -1,4 +1,6 @@ -use crate::bindings::http::types::{self, Error, Headers, Method, Scheme, StatusCode, Trailers}; +use crate::bindings::http::types::{ + self, Error, HeaderError, Headers, Method, Scheme, StatusCode, Trailers, +}; use crate::body::{FinishMessage, HostFutureTrailers, HostFutureTrailersState}; use crate::types::{HostIncomingRequest, HostOutgoingResponse}; use crate::WasiHttpView; @@ -10,6 +12,7 @@ use crate::{ }, }; use anyhow::Context; +use hyper::header::HeaderName; use std::any::Any; use wasmtime::component::Resource; use wasmtime_wasi::preview2::{ @@ -51,13 +54,55 @@ fn get_fields_mut<'a>( } } +fn is_forbidden_header(view: &mut T, name: &HeaderName) -> bool { + static FORBIDDEN_HEADERS: [HeaderName; 9] = [ + hyper::header::CONNECTION, + HeaderName::from_static("keep-alive"), + hyper::header::PROXY_AUTHENTICATE, + hyper::header::PROXY_AUTHORIZATION, + HeaderName::from_static("proxy-connection"), + hyper::header::TE, + hyper::header::TRANSFER_ENCODING, + hyper::header::UPGRADE, + HeaderName::from_static("http2-settings"), + ]; + + FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name) +} + impl crate::bindings::http::types::HostFields for T { - fn new(&mut self, entries: Vec<(String, Vec)>) -> wasmtime::Result> { + fn new(&mut self) -> wasmtime::Result> { + let id = self + .table() + .push(HostFields::Owned { + fields: hyper::HeaderMap::new(), + }) + .context("[new_fields] pushing fields")?; + + Ok(id) + } + + fn from_list( + &mut self, + entries: Vec<(String, Vec)>, + ) -> wasmtime::Result, HeaderError>> { let mut map = hyper::HeaderMap::new(); for (header, value) in entries { - let header = hyper::header::HeaderName::from_bytes(header.as_bytes())?; - let value = hyper::header::HeaderValue::from_bytes(&value)?; + let header = match hyper::header::HeaderName::from_bytes(header.as_bytes()) { + Ok(header) => header, + Err(_) => return Ok(Err(HeaderError::InvalidSyntax)), + }; + + if is_forbidden_header(self, &header) { + return Ok(Err(HeaderError::Forbidden)); + } + + let value = match hyper::header::HeaderValue::from_bytes(&value) { + Ok(value) => value, + Err(_) => return Ok(Err(HeaderError::InvalidSyntax)), + }; + map.append(header, value); } @@ -66,7 +111,7 @@ impl crate::bindings::http::types::HostFields for T { .push(HostFields::Owned { fields: map }) .context("[new_fields] pushing fields")?; - Ok(id) + Ok(Ok(id)) } fn drop(&mut self, fields: Resource) -> wasmtime::Result<()> { @@ -81,9 +126,14 @@ impl crate::bindings::http::types::HostFields for T { fields: Resource, name: String, ) -> wasmtime::Result>> { + let header = match hyper::header::HeaderName::from_bytes(name.as_bytes()) { + Ok(header) => header, + Err(_) => return Ok(vec![]), + }; + let res = get_fields_mut(self.table(), &fields) .context("[fields_get] getting fields")? - .get_all(hyper::header::HeaderName::from_bytes(name.as_bytes())?) + .get_all(header) .into_iter() .map(|val| val.as_bytes().to_owned()) .collect(); @@ -94,24 +144,42 @@ impl crate::bindings::http::types::HostFields for T { &mut self, fields: Resource, name: String, - values: Vec>, - ) -> wasmtime::Result<()> { - let m = get_fields_mut(self.table(), &fields)?; + byte_values: Vec>, + ) -> wasmtime::Result> { + let header = match hyper::header::HeaderName::from_bytes(name.as_bytes()) { + Ok(header) => header, + Err(_) => return Ok(Err(HeaderError::InvalidSyntax)), + }; + + if is_forbidden_header(self, &header) { + return Ok(Err(HeaderError::Forbidden)); + } - let header = hyper::header::HeaderName::from_bytes(name.as_bytes())?; + let mut values = Vec::with_capacity(byte_values.len()); + for value in byte_values { + match hyper::header::HeaderValue::from_bytes(&value) { + Ok(value) => values.push(value), + Err(_) => return Ok(Err(HeaderError::InvalidSyntax)), + } + } + let m = + get_fields_mut(self.table(), &fields).context("[fields_set] getting mutable fields")?; m.remove(&header); for value in values { - let value = hyper::header::HeaderValue::from_bytes(&value)?; m.append(&header, value); } - Ok(()) + Ok(Ok(())) } fn delete(&mut self, fields: Resource, name: String) -> wasmtime::Result<()> { + let header = match hyper::header::HeaderName::from_bytes(name.as_bytes()) { + Ok(header) => header, + Err(_) => return Ok(()), + }; + let m = get_fields_mut(self.table(), &fields)?; - let header = hyper::header::HeaderName::from_bytes(name.as_bytes())?; m.remove(header); Ok(()) } @@ -121,13 +189,26 @@ impl crate::bindings::http::types::HostFields for T { fields: Resource, name: String, value: Vec, - ) -> wasmtime::Result<()> { + ) -> wasmtime::Result> { + let header = match hyper::header::HeaderName::from_bytes(name.as_bytes()) { + Ok(header) => header, + Err(_) => return Ok(Err(HeaderError::InvalidSyntax)), + }; + + if is_forbidden_header(self, &header) { + return Ok(Err(HeaderError::Forbidden)); + } + + let value = match hyper::header::HeaderValue::from_bytes(&value) { + Ok(value) => value, + Err(_) => return Ok(Err(HeaderError::InvalidSyntax)), + }; + let m = get_fields_mut(self.table(), &fields) .context("[fields_append] getting mutable fields")?; - let header = hyper::header::HeaderName::from_bytes(name.as_bytes())?; - let value = hyper::header::HeaderValue::from_bytes(&value)?; + m.append(header, value); - Ok(()) + Ok(Ok(())) } fn entries( diff --git a/crates/wasi-http/tests/all/async_.rs b/crates/wasi-http/tests/all/async_.rs index 128c1f25fce9..3b74482e56d3 100644 --- a/crates/wasi-http/tests/all/async_.rs +++ b/crates/wasi-http/tests/all/async_.rs @@ -49,6 +49,12 @@ async fn http_outbound_request_invalid_version() -> Result<()> { run(HTTP_OUTBOUND_REQUEST_INVALID_VERSION_COMPONENT, &server).await } +#[test_log::test(tokio::test(flavor = "multi_thread"))] +async fn http_outbound_request_invalid_header() -> Result<()> { + let server = Server::http2()?; + run(HTTP_OUTBOUND_REQUEST_INVALID_HEADER_COMPONENT, &server).await +} + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn http_outbound_request_unknown_method() -> Result<()> { let server = Server::http1()?; diff --git a/crates/wasi-http/tests/all/main.rs b/crates/wasi-http/tests/all/main.rs index f9c68da0a9b7..84e3ef12df91 100644 --- a/crates/wasi-http/tests/all/main.rs +++ b/crates/wasi-http/tests/all/main.rs @@ -72,6 +72,10 @@ impl WasiHttpView for Ctx { types::default_send_request(self, request) } } + + fn is_forbidden_header(&mut self, name: &hyper::header::HeaderName) -> bool { + name.as_str() == "custom-forbidden-header" + } } fn store(engine: &Engine, server: &Server) -> Store { diff --git a/crates/wasi-http/tests/all/sync.rs b/crates/wasi-http/tests/all/sync.rs index 729b1e427a38..ea9328aff73e 100644 --- a/crates/wasi-http/tests/all/sync.rs +++ b/crates/wasi-http/tests/all/sync.rs @@ -48,6 +48,12 @@ fn http_outbound_request_invalid_version() -> Result<()> { run(HTTP_OUTBOUND_REQUEST_INVALID_VERSION_COMPONENT, &server) } +#[test_log::test] +fn http_outbound_request_invalid_header() -> Result<()> { + let server = Server::http2()?; + run(HTTP_OUTBOUND_REQUEST_INVALID_HEADER_COMPONENT, &server) +} + #[test_log::test] fn http_outbound_request_unknown_method() -> Result<()> { let server = Server::http1()?; diff --git a/crates/wasi-http/wit/deps/http/types.wit b/crates/wasi-http/wit/deps/http/types.wit index 9cace67ab294..1ab5bd11741d 100644 --- a/crates/wasi-http/wit/deps/http/types.wit +++ b/crates/wasi-http/wit/deps/http/types.wit @@ -36,6 +36,13 @@ interface types { unexpected-error(string) } + /// This tyep enumerates the different kinds of errors that may occur when + /// setting or appending to a `fields` resource. + variant header-error { + invalid-syntax, + forbidden, + } + /// Field keys are always strings. type field-key = string; @@ -49,6 +56,9 @@ interface types { /// Headers and Trailers. resource fields { + /// Construct an empty HTTP Fields. + constructor(); + /// Construct an HTTP Fields. /// /// The list represents each key-value pair in the Fields. Keys @@ -59,14 +69,22 @@ interface types { /// Value, represented as a list of bytes. In a valid Fields, all keys /// and values are valid UTF-8 strings. However, values are not always /// well-formed, so they are represented as a raw list of bytes. - constructor(entries: list>); + /// + /// An error result will be returned if any header or value was + /// syntactically invalid, or if a header was forbidden. + from-list: static func( + entries: list> + ) -> result; /// Get all of the values corresponding to a key. get: func(name: field-key) -> list; /// Set all of the values for a key. Clears any existing values for that /// key, if they have been set. - set: func(name: field-key, value: list); + /// + /// The operation can fail if the name or value arguments are invalid, or if + /// the name is forbidden. + set: func(name: field-key, value: list) -> result<_, header-error>; /// Delete all values for a key. Does nothing if no values for the key /// exist. @@ -74,7 +92,10 @@ interface types { /// Append a value for a key. Does not change or delete any existing /// values for that key. - append: func(name: field-key, value: field-value); + /// + /// The operation can fail if the name or value arguments are invalid, or if + /// the name is forbidden. + append: func(name: field-key, value: field-value) -> result<_, header-error>; /// Retrieve the full set of keys and values in the Fields. Like the diff --git a/crates/wasi/wit/deps/http/types.wit b/crates/wasi/wit/deps/http/types.wit index 9cace67ab294..1ab5bd11741d 100644 --- a/crates/wasi/wit/deps/http/types.wit +++ b/crates/wasi/wit/deps/http/types.wit @@ -36,6 +36,13 @@ interface types { unexpected-error(string) } + /// This tyep enumerates the different kinds of errors that may occur when + /// setting or appending to a `fields` resource. + variant header-error { + invalid-syntax, + forbidden, + } + /// Field keys are always strings. type field-key = string; @@ -49,6 +56,9 @@ interface types { /// Headers and Trailers. resource fields { + /// Construct an empty HTTP Fields. + constructor(); + /// Construct an HTTP Fields. /// /// The list represents each key-value pair in the Fields. Keys @@ -59,14 +69,22 @@ interface types { /// Value, represented as a list of bytes. In a valid Fields, all keys /// and values are valid UTF-8 strings. However, values are not always /// well-formed, so they are represented as a raw list of bytes. - constructor(entries: list>); + /// + /// An error result will be returned if any header or value was + /// syntactically invalid, or if a header was forbidden. + from-list: static func( + entries: list> + ) -> result; /// Get all of the values corresponding to a key. get: func(name: field-key) -> list; /// Set all of the values for a key. Clears any existing values for that /// key, if they have been set. - set: func(name: field-key, value: list); + /// + /// The operation can fail if the name or value arguments are invalid, or if + /// the name is forbidden. + set: func(name: field-key, value: list) -> result<_, header-error>; /// Delete all values for a key. Does nothing if no values for the key /// exist. @@ -74,7 +92,10 @@ interface types { /// Append a value for a key. Does not change or delete any existing /// values for that key. - append: func(name: field-key, value: field-value); + /// + /// The operation can fail if the name or value arguments are invalid, or if + /// the name is forbidden. + append: func(name: field-key, value: field-value) -> result<_, header-error>; /// Retrieve the full set of keys and values in the Fields. Like the