Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/test-programs/src/bin/api_proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct T;

impl bindings::exports::wasi::http::incoming_handler::Guest for T {
fn handle(_request: IncomingRequest, outparam: ResponseOutparam) {
let hdrs = bindings::wasi::http::types::Headers::new(&[]);
let hdrs = bindings::wasi::http::types::Headers::new();
let resp = bindings::wasi::http::types::OutgoingResponse::new(200, hdrs);
let body = resp.body().expect("outgoing response");

Expand Down
11 changes: 6 additions & 5 deletions crates/test-programs/src/bin/api_proxy_streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam

let response = OutgoingResponse::new(
200,
Fields::new(&[("content-type".to_string(), b"text/plain".to_vec())]),
Fields::from_list(&[("content-type".to_string(), b"text/plain".to_vec())]).unwrap(),
);

let mut body =
Expand All @@ -75,12 +75,13 @@ async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam
(Method::Post, Some("/echo")) => {
let response = OutgoingResponse::new(
200,
Fields::new(
Fields::from_list(
&headers
.into_iter()
.filter_map(|(k, v)| (k == "content-type").then_some((k, v)))
.collect::<Vec<_>>(),
),
)
.unwrap(),
);

let mut body =
Expand Down Expand Up @@ -108,7 +109,7 @@ async fn handle_request(request: IncomingRequest, response_out: ResponseOutparam
}

_ => {
let response = OutgoingResponse::new(405, Fields::new(&[]));
let response = OutgoingResponse::new(405, Fields::new());

let body = response.body().expect("response should be writable");

Expand Down Expand Up @@ -137,7 +138,7 @@ async fn hash(url: &Url) -> Result<String> {
String::new()
}
)),
Fields::new(&[]),
Fields::new(),
);

let response = executor::outgoing_request_send(request).await?;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use test_programs::wasi::http::types::{HeaderError, Headers};

fn main() {
let hdrs = Headers::new();
assert!(matches!(
hdrs.append(&"malformed header name".to_owned(), &b"ok value".to_vec()),
Err(HeaderError::InvalidSyntax)
));

assert!(matches!(
hdrs.append(&"ok-header-name".to_owned(), &b"ok value".to_vec()),
Ok(())
));

assert!(matches!(
hdrs.append(&"ok-header-name".to_owned(), &b"bad\nvalue".to_vec()),
Err(HeaderError::InvalidSyntax)
));

assert!(matches!(
hdrs.append(&"Connection".to_owned(), &b"keep-alive".to_vec()),
Err(HeaderError::Forbidden)
));

assert!(matches!(
hdrs.append(&"Keep-Alive".to_owned(), &b"stuff".to_vec()),
Err(HeaderError::Forbidden)
));

assert!(matches!(
hdrs.append(
&"custom-forbidden-header".to_owned(),
&b"keep-alive".to_vec()
),
Err(HeaderError::Forbidden)
));

assert!(matches!(
hdrs.append(
&"Custom-Forbidden-Header".to_owned(),
&b"keep-alive".to_vec()
),
Err(HeaderError::Forbidden)
));

assert!(matches!(
Headers::from_list(&[("bad header".to_owned(), b"value".to_vec())]),
Err(HeaderError::InvalidSyntax)
));

assert!(matches!(
Headers::from_list(&[("custom-forbidden-header".to_owned(), b"value".to_vec())]),
Err(HeaderError::Forbidden)
));

assert!(matches!(
Headers::from_list(&[("ok-header-name".to_owned(), b"bad\nvalue".to_vec())]),
Err(HeaderError::InvalidSyntax)
));
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ use test_programs::wasi::http::types as http_types;
fn main() {
println!("Called _start");
{
let headers = http_types::Headers::new(&[(
let headers = http_types::Headers::from_list(&[(
"Content-Type".to_string(),
"application/json".to_string().into_bytes(),
)]);
)])
.unwrap();
let request = http_types::OutgoingRequest::new(
&http_types::Method::Get,
None,
Expand All @@ -21,10 +22,11 @@ fn main() {
.unwrap();
}
{
let headers = http_types::Headers::new(&[(
let headers = http_types::Headers::from_list(&[(
"Content-Type".to_string(),
"application/text".to_string().into_bytes(),
)]);
)])
.unwrap();
let response = http_types::OutgoingResponse::new(200, headers);
let outgoing_body = response.body().unwrap();
let response_body = outgoing_body.write().unwrap();
Expand Down
4 changes: 2 additions & 2 deletions crates/test-programs/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub fn request(
fn header_val(v: &str) -> Vec<u8> {
v.to_string().into_bytes()
}
let headers = http_types::Headers::new(
let headers = http_types::Headers::from_list(
&[
&[
("User-agent".to_string(), header_val("WASI-HTTP/0.0.1")),
Expand All @@ -51,7 +51,7 @@ pub fn request(
additional_headers.unwrap_or(&[]),
]
.concat(),
);
)?;

let request = http_types::OutgoingRequest::new(
&method,
Expand Down
5 changes: 5 additions & 0 deletions crates/wasi-http/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
};
use anyhow::Context;
use http_body_util::BodyExt;
use hyper::header::HeaderName;
use std::any::Any;
use std::time::Duration;
use tokio::net::TcpStream;
Expand Down Expand Up @@ -65,6 +66,10 @@ pub trait WasiHttpView: Send {
{
default_send_request(self, request)
}

fn is_forbidden_header(&mut self, _name: &HeaderName) -> bool {
false
}
}

pub fn default_send_request(
Expand Down
115 changes: 98 additions & 17 deletions crates/wasi-http/src/types_impl.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::bindings::http::types::{self, Error, Headers, Method, Scheme, StatusCode, Trailers};
use crate::bindings::http::types::{
self, Error, HeaderError, Headers, Method, Scheme, StatusCode, Trailers,
};
use crate::body::{FinishMessage, HostFutureTrailers, HostFutureTrailersState};
use crate::types::{HostIncomingRequest, HostOutgoingResponse};
use crate::WasiHttpView;
Expand All @@ -10,6 +12,7 @@ use crate::{
},
};
use anyhow::Context;
use hyper::header::HeaderName;
use std::any::Any;
use wasmtime::component::Resource;
use wasmtime_wasi::preview2::{
Expand Down Expand Up @@ -51,13 +54,55 @@ fn get_fields_mut<'a>(
}
}

fn is_forbidden_header<T: WasiHttpView>(view: &mut T, name: &HeaderName) -> bool {
static FORBIDDEN_HEADERS: [HeaderName; 9] = [
hyper::header::CONNECTION,
HeaderName::from_static("keep-alive"),
hyper::header::PROXY_AUTHENTICATE,
hyper::header::PROXY_AUTHORIZATION,
HeaderName::from_static("proxy-connection"),
hyper::header::TE,
hyper::header::TRANSFER_ENCODING,
hyper::header::UPGRADE,
HeaderName::from_static("http2-settings"),
];

FORBIDDEN_HEADERS.contains(name) || view.is_forbidden_header(name)
}

impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
fn new(&mut self, entries: Vec<(String, Vec<u8>)>) -> wasmtime::Result<Resource<HostFields>> {
fn new(&mut self) -> wasmtime::Result<Resource<HostFields>> {
let id = self
.table()
.push(HostFields::Owned {
fields: hyper::HeaderMap::new(),
})
.context("[new_fields] pushing fields")?;

Ok(id)
}

fn from_list(
&mut self,
entries: Vec<(String, Vec<u8>)>,
) -> wasmtime::Result<Result<Resource<HostFields>, HeaderError>> {
let mut map = hyper::HeaderMap::new();

for (header, value) in entries {
let header = hyper::header::HeaderName::from_bytes(header.as_bytes())?;
let value = hyper::header::HeaderValue::from_bytes(&value)?;
let header = match hyper::header::HeaderName::from_bytes(header.as_bytes()) {
Ok(header) => header,
Err(_) => return Ok(Err(HeaderError::InvalidSyntax)),
};

if is_forbidden_header(self, &header) {
return Ok(Err(HeaderError::Forbidden));
}

let value = match hyper::header::HeaderValue::from_bytes(&value) {
Ok(value) => value,
Err(_) => return Ok(Err(HeaderError::InvalidSyntax)),
};

map.append(header, value);
}

Expand All @@ -66,7 +111,7 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
.push(HostFields::Owned { fields: map })
.context("[new_fields] pushing fields")?;

Ok(id)
Ok(Ok(id))
}

fn drop(&mut self, fields: Resource<HostFields>) -> wasmtime::Result<()> {
Expand All @@ -81,9 +126,14 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
fields: Resource<HostFields>,
name: String,
) -> wasmtime::Result<Vec<Vec<u8>>> {
let header = match hyper::header::HeaderName::from_bytes(name.as_bytes()) {
Ok(header) => header,
Err(_) => return Ok(vec![]),
};

let res = get_fields_mut(self.table(), &fields)
.context("[fields_get] getting fields")?
.get_all(hyper::header::HeaderName::from_bytes(name.as_bytes())?)
.get_all(header)
.into_iter()
.map(|val| val.as_bytes().to_owned())
.collect();
Expand All @@ -94,24 +144,42 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
&mut self,
fields: Resource<HostFields>,
name: String,
values: Vec<Vec<u8>>,
) -> wasmtime::Result<()> {
let m = get_fields_mut(self.table(), &fields)?;
byte_values: Vec<Vec<u8>>,
) -> wasmtime::Result<Result<(), HeaderError>> {
let header = match hyper::header::HeaderName::from_bytes(name.as_bytes()) {
Ok(header) => header,
Err(_) => return Ok(Err(HeaderError::InvalidSyntax)),
};

if is_forbidden_header(self, &header) {
return Ok(Err(HeaderError::Forbidden));
}

let header = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
let mut values = Vec::with_capacity(byte_values.len());
for value in byte_values {
match hyper::header::HeaderValue::from_bytes(&value) {
Ok(value) => values.push(value),
Err(_) => return Ok(Err(HeaderError::InvalidSyntax)),
}
}

let m =
get_fields_mut(self.table(), &fields).context("[fields_set] getting mutable fields")?;
m.remove(&header);
for value in values {
let value = hyper::header::HeaderValue::from_bytes(&value)?;
m.append(&header, value);
}

Ok(())
Ok(Ok(()))
}

fn delete(&mut self, fields: Resource<HostFields>, name: String) -> wasmtime::Result<()> {
let header = match hyper::header::HeaderName::from_bytes(name.as_bytes()) {
Ok(header) => header,
Err(_) => return Ok(()),
};

let m = get_fields_mut(self.table(), &fields)?;
let header = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
m.remove(header);
Ok(())
}
Expand All @@ -121,13 +189,26 @@ impl<T: WasiHttpView> crate::bindings::http::types::HostFields for T {
fields: Resource<HostFields>,
name: String,
value: Vec<u8>,
) -> wasmtime::Result<()> {
) -> wasmtime::Result<Result<(), HeaderError>> {
let header = match hyper::header::HeaderName::from_bytes(name.as_bytes()) {
Ok(header) => header,
Err(_) => return Ok(Err(HeaderError::InvalidSyntax)),
};

if is_forbidden_header(self, &header) {
return Ok(Err(HeaderError::Forbidden));
}

let value = match hyper::header::HeaderValue::from_bytes(&value) {
Ok(value) => value,
Err(_) => return Ok(Err(HeaderError::InvalidSyntax)),
};

let m = get_fields_mut(self.table(), &fields)
.context("[fields_append] getting mutable fields")?;
let header = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
let value = hyper::header::HeaderValue::from_bytes(&value)?;

m.append(header, value);
Ok(())
Ok(Ok(()))
}

fn entries(
Expand Down
6 changes: 6 additions & 0 deletions crates/wasi-http/tests/all/async_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ async fn http_outbound_request_invalid_version() -> Result<()> {
run(HTTP_OUTBOUND_REQUEST_INVALID_VERSION_COMPONENT, &server).await
}

#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn http_outbound_request_invalid_header() -> Result<()> {
let server = Server::http2()?;
run(HTTP_OUTBOUND_REQUEST_INVALID_HEADER_COMPONENT, &server).await
}

#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn http_outbound_request_unknown_method() -> Result<()> {
let server = Server::http1()?;
Expand Down
4 changes: 4 additions & 0 deletions crates/wasi-http/tests/all/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ impl WasiHttpView for Ctx {
types::default_send_request(self, request)
}
}

fn is_forbidden_header(&mut self, name: &hyper::header::HeaderName) -> bool {
name.as_str() == "custom-forbidden-header"
}
}

fn store(engine: &Engine, server: &Server) -> Store<Ctx> {
Expand Down
6 changes: 6 additions & 0 deletions crates/wasi-http/tests/all/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ fn http_outbound_request_invalid_version() -> Result<()> {
run(HTTP_OUTBOUND_REQUEST_INVALID_VERSION_COMPONENT, &server)
}

#[test_log::test]
fn http_outbound_request_invalid_header() -> Result<()> {
let server = Server::http2()?;
run(HTTP_OUTBOUND_REQUEST_INVALID_HEADER_COMPONENT, &server)
}

#[test_log::test]
fn http_outbound_request_unknown_method() -> Result<()> {
let server = Server::http1()?;
Expand Down
Loading