Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions examples/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tide::{Middleware, Next, Request, Response, Result, StatusCode};
use tide::{After, Before, Middleware, Next, Request, Response, Result, StatusCode};

#[derive(Debug)]
struct User {
name: String,
}

#[derive(Default)]
#[derive(Default, Debug)]
struct UserDatabase;
impl UserDatabase {
async fn find_user(&self) -> Option<User> {
Expand Down Expand Up @@ -78,13 +78,47 @@ impl<State: Send + Sync + 'static> Middleware<State> for RequestCounterMiddlewar
}
}

const NOT_FOUND_HTML_PAGE: &str = "<html><body>
<h1>uh oh, we couldn't find that document</h1>
<p>
probably, this would be served from the file system or
included with `include_bytes!`
</p>
</body></html>";

const INTERNAL_SERVER_ERROR_HTML_PAGE: &str = "<html><body>
<h1>whoops! it's not you, it's us</h1>
<p>
we're very sorry, but something seems to have gone wrong on our end
</p>
</body></html>";

#[async_std::main]
async fn main() -> Result<()> {
tide::log::start();
let mut app = tide::with_state(UserDatabase::default());

app.middleware(After(|result: Result| async move {
let response = result.unwrap_or_else(|e| Response::new(e.status()));
match response.status() {
StatusCode::NotFound => Ok(response
.set_content_type(tide::http::mime::HTML)
.body_string(NOT_FOUND_HTML_PAGE.into())),

StatusCode::InternalServerError => Ok(response
.set_content_type(tide::http::mime::HTML)
.body_string(INTERNAL_SERVER_ERROR_HTML_PAGE.into())),

_ => Ok(response),
}
}));

app.middleware(user_loader);
app.middleware(RequestCounterMiddleware::new(0));
app.middleware(Before(|mut request: Request<UserDatabase>| async move {
request.set_ext(std::time::Instant::now());
request
}));

app.at("/").get(|req: Request<_>| async move {
let count: &RequestCount = req.ext().unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ pub mod security;
pub mod sse;

pub use endpoint::Endpoint;
pub use middleware::{Middleware, Next};
pub use middleware::{After, Before, Middleware, Next};
pub use redirect::Redirect;
pub use request::Request;
pub use response::Response;
Expand Down
77 changes: 77 additions & 0 deletions src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,83 @@ pub trait Middleware<State>: 'static + Send + Sync {
}
}

/// Define a middleware that operates on incoming requests.
///
/// This middleware is useful because it is not possible in Rust yet to use
/// closures to define inline middleware.
///
/// # Examples
///
/// ```rust
/// use tide::{Before, Request};
/// use std::time::Instant;
///
/// let mut app = tide::new();
/// app.middleware(Before(|mut request: Request<()>| async move {
/// request.set_ext(Instant::now());
/// request
/// }));
/// ```
#[derive(Debug)]
pub struct Before<F>(pub F);
impl<State, F, Fut> Middleware<State> for Before<F>
where
State: Send + Sync + 'static,
F: Fn(Request<State>) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = Request<State>> + Send + Sync,
{
fn handle<'a>(
&'a self,
request: Request<State>,
next: Next<'a, State>,
) -> BoxFuture<'a, crate::Result> {
Box::pin(async move {
let request = (self.0)(request).await;
next.run(request).await
})
}
}

/// Define a middleware that operates on outgoing responses.
///
/// This middleware is useful because it is not possible in Rust yet to use
/// closures to define inline middleware.
///
/// # Examples
///
/// ```rust
/// use tide::{After, Response, http};
///
/// let mut app = tide::new();
/// app.middleware(After(|res: tide::Result| async move {
/// let res = res.unwrap_or_else(|e| Response::new(e.status()));
/// match res.status() {
/// http::StatusCode::NotFound => Ok("Page not found".into()),
/// http::StatusCode::InternalServerError => Ok("Something went wrong".into()),
/// _ => Ok(res),
/// }
/// }));
/// ```
#[derive(Debug)]
pub struct After<F>(pub F);
impl<State, F, Fut> Middleware<State> for After<F>
where
State: Send + Sync + 'static,
F: Fn(crate::Result) -> Fut + Send + Sync + 'static,
Fut: std::future::Future<Output = crate::Result> + Send + Sync,
{
fn handle<'a>(
&'a self,
request: Request<State>,
next: Next<'a, State>,
) -> BoxFuture<'a, crate::Result> {
Box::pin(async move {
let result = next.run(request).await;
(self.0)(result).await
})
}
}

impl<State, F> Middleware<State> for F
where
F: Send
Expand Down