Implement a middleware in Actix-web

by Gink, Wed, 29 Nov 2023

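1. A first, simple middleware for Actix-web

This may look intimidating, right? Believe me, it's the simplest standard middleware you can write for Actix. It has two parts: Authentication implements Transform, a factory that wraps the next service (another middleware or the handler) in an AuthenticationMiddleware, and AuthenticationMiddleware implements Service, whose call() runs for every request. Usually you only have to care about call(); the rest is there to satisfy the trait definitions.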

use std::future::{ready, Ready};

use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    Error,
};
use futures_util::future::LocalBoxFuture;

pub struct Authentication;

impl<S, B> Transform<S, ServiceRequest> for Authentication
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Transform = AuthenticationMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(AuthenticationMiddleware { service }))
    }
}

pub struct AuthenticationMiddleware<S> {
    service: S,
}

impl<S, B> Service<ServiceRequest> for AuthenticationMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Do something with the request here

        // Continue with the next middleware / handler
        let fut = self.service.call(req);

        Box::pin(async move {
            let res = fut.await?;
            Ok(res)
        })
    }
}
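If the two traits still feel abstract, here is a deliberately simplified, synchronous model of the same two-layer design using only the standard library. SimpleService, SimpleTransform, Logging, LoggingMiddleware and Handler are hypothetical names for this sketch, not actix-web types: the factory (like Authentication) builds the wrapper once, and the wrapper (like AuthenticationMiddleware<S>) sees every request before forwarding it.

```rust
// Simplified, synchronous model of Actix's middleware design.
// These traits are hypothetical stand-ins, NOT actix-web's Transform/Service.

trait SimpleService {
    fn call(&self, req: String) -> String;
}

trait SimpleTransform<S> {
    type Output: SimpleService;
    // Called once per wrapped service, like `new_transform`.
    fn new_transform(&self, service: S) -> Self::Output;
}

// The innermost "handler".
struct Handler;

impl SimpleService for Handler {
    fn call(&self, req: String) -> String {
        format!("handled: {req}")
    }
}

// The factory, playing the role of `Authentication`.
struct Logging;

// The per-request wrapper, playing the role of `AuthenticationMiddleware<S>`.
struct LoggingMiddleware<S> {
    service: S,
}

impl<S: SimpleService> SimpleTransform<S> for Logging {
    type Output = LoggingMiddleware<S>;

    fn new_transform(&self, service: S) -> Self::Output {
        LoggingMiddleware { service }
    }
}

impl<S: SimpleService> SimpleService for LoggingMiddleware<S> {
    fn call(&self, req: String) -> String {
        // Do something with the request here...
        let req = format!("[logged] {req}");
        // ...then continue with the next service.
        self.service.call(req)
    }
}
```

Calling Logging.new_transform(Handler).call("GET /".to_string()) returns "handled: [logged] GET /": the wrapper touches the request first, then the handler runs.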

2. Responding to the request directly from middleware

Now you think you've got the hang of it and try to add more logic. For example, checking the request headers and returning 401 Unauthorized if the Authorization header is missing. Let's add some code for that in the call() function:

        // Do something with the request here
        let auth = req.headers().get(AUTHORIZATION);
        if auth.is_none() {
            let http_res = HttpResponse::Unauthorized().finish();
            let (http_req, _) = req.into_parts();
            let res = ServiceResponse::new(http_req, http_res);
            return (async move { Ok(res) }).boxed_local();
        }

        // Continue with the next middleware / handler
        let fut = self.service.call(req);

And the compiler will punch you in the face with a return type error:

^^^^^^^ expected `Pin<Box<dyn Future<Output = Result<ServiceResponse<B>, Error>>>>`, found `Pin<Box<dyn Future<Output = Result<ServiceResponse, _>>>>`

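Why is that? Because call() must return one concrete type, ServiceResponse<B>, where B is whatever body type the inner service (the handler) produces. The Unauthorized response you build here has its own body type (ServiceResponse without a parameter defaults to BoxBody), which in general is not B.

To satisfy the compiler, change the Response associated type to ServiceResponse<EitherBody<B>> and map each return path into it: map_into_right_body() for the early 401 response and map_into_left_body() for the handler's response. Here it is: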

use std::future::{ready, Ready};

use actix_web::{
    body::EitherBody,
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    http::header::AUTHORIZATION,
    Error, HttpResponse,
};
use futures_util::{future::LocalBoxFuture, FutureExt};

pub struct Authentication;

impl<S, B> Transform<S, ServiceRequest> for Authentication
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>; // update here
    type Error = Error;
    type InitError = ();
    type Transform = AuthenticationMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(AuthenticationMiddleware { service }))
    }
}

pub struct AuthenticationMiddleware<S> {
    service: S,
}

impl<S, B> Service<ServiceRequest> for AuthenticationMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>; // update here
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Do something with the request here
        let auth = req.headers().get(AUTHORIZATION);
        if auth.is_none() {
            let http_res = HttpResponse::Unauthorized().finish();
            let (http_req, _) = req.into_parts();
            let res = ServiceResponse::new(http_req, http_res);
            // Map to R type
            return (async move { Ok(res.map_into_right_body()) }).boxed_local();
        }

        // Continue with the next middleware / handler
        let fut = self.service.call(req);

        Box::pin(async move {
            let res = fut.await?;
            // Map to L type
            Ok(res.map_into_left_body())
        })
    }
}
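Why does an enum fix this? Rust needs one concrete return type for call(), and EitherBody<L, R> is that single type, with one variant per body. The std-only sketch below mimics the idea; Either, HandlerBody, EmptyBody, BodyLen and respond are hypothetical names, and the real EitherBody forwards actix-web's MessageBody trait rather than a toy len() method.

```rust
// Hypothetical two-variant wrapper mirroring the idea behind EitherBody<L, R>:
// a single concrete type that can hold either of two body types.
enum Either<L, R> {
    Left(L),
    Right(R),
}

// Two unrelated "body" types: the handler's body and the middleware's 401 body.
struct HandlerBody(String);
struct EmptyBody;

// Toy stand-in for a body trait such as MessageBody.
trait BodyLen {
    fn len(&self) -> usize;
}

impl BodyLen for HandlerBody {
    fn len(&self) -> usize {
        self.0.len()
    }
}

impl BodyLen for EmptyBody {
    fn len(&self) -> usize {
        0
    }
}

// The wrapper implements the trait by forwarding to whichever variant it holds.
impl<L: BodyLen, R: BodyLen> BodyLen for Either<L, R> {
    fn len(&self) -> usize {
        match self {
            Either::Left(l) => l.len(),
            Either::Right(r) => r.len(),
        }
    }
}

// Both branches now return the same type: Either<HandlerBody, EmptyBody>.
fn respond(authorized: bool) -> (u16, Either<HandlerBody, EmptyBody>) {
    if !authorized {
        // Early return, like map_into_right_body().
        return (401, Either::Right(EmptyBody));
    }
    // Normal path, like map_into_left_body().
    (200, Either::Left(HandlerBody("hello".to_string())))
}
```

Without the enum, the two branches would return HandlerBody and EmptyBody respectively, which is exactly the mismatch the compiler complained about.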

3. Calling async functions before continuing to the next middleware / handler

Happy with the response types?

Now you need something more, like querying the database or Redis for some data and attaching it to the request before continuing to the real handler.

You will soon notice that, in the code so far, self.service.call(req) is invoked outside the async move {} block, so there is no way to await anything before the service's call() runs.

Your first idea is probably to move everything into the async block, something like:

    Box::pin(async move {
        // Getting some data here (just demo code)
        let user = get_some_data().await;
        req.extensions_mut().insert(user);

        // Continue with the next middleware / handler
        let res = self.service.call(req).await?;
        // Map to L type
        Ok(res.map_into_left_body())
    })

But no, the compiler now hits you even harder with another error:

error: lifetime may not live long enough
xx |       fn call(&self, req: ServiceRequest) -> Self::Future {
   |               - let's call the lifetime of this reference `'1`
   | |__________^ returning this value requires that `'1` must outlive `'static`

What the hell? What is a lifetime?

A lifetime is how the borrow checker tracks how long a reference stays valid. The async block now uses self.service, so it borrows &self, and the compiler infers that the returned future can only live as long as that reference, the lifetime '1. But Self::Future is LocalBoxFuture<'static, ...>, which must not hold any borrowed data, so '1 would have to outlive 'static, and it doesn't.

Then you think about cloning self.service before moving it into the async block. Unfortunately, S is not required to implement Clone. What now?

OK, let's use a smart pointer. Rc is a good candidate: cloning an Rc<S> only increments a reference count, and the clone is an owned value that can be moved into the async block. We'll store the service as Rc<S> and require S: 'static. Here it is:

use std::{
    future::{ready, Ready},
    rc::Rc,
};

use actix_web::{
    body::EitherBody,
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    http::header::AUTHORIZATION,
    Error, HttpMessage, HttpResponse,
};
use futures_util::{future::LocalBoxFuture, FutureExt};

pub struct Authentication;

impl<S, B> Transform<S, ServiceRequest> for Authentication
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static, // update here
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type InitError = ();
    type Transform = AuthenticationMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(AuthenticationMiddleware {
            service: Rc::new(service), // convert S to Rc<S>
        }))
    }
}

pub struct AuthenticationMiddleware<S> {
    // service: S,
    service: Rc<S>,
}

impl<S, B> Service<ServiceRequest> for AuthenticationMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static, // update here
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        // Do something with the request here
        let auth = req.headers().get(AUTHORIZATION);
        if auth.is_none() {
            let http_res = HttpResponse::Unauthorized().finish();
            let (http_req, _) = req.into_parts();
            let res = ServiceResponse::new(http_req, http_res);
            // Map to R type
            return (async move { Ok(res.map_into_right_body()) }).boxed_local();
        }

        // Clone the service to keep reference after moving into async block
        let service = Rc::clone(&self.service);

        Box::pin(async move {
            // Getting some data here (just demo code for async function)
            let user = get_some_data().await;
            req.extensions_mut().insert(user);

            // Continue with the next middleware / handler
            let res = service.call(req).await?;
            // Map to L type
            Ok(res.map_into_left_body())
        })
    }
}

async fn get_some_data() -> String {
    "Data".into()
}

Mission complete!
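To see the Rc trick in isolation, here is a std-only sketch. Inner, Middleware, get_some_data, BoxedLocalFuture and block_on_ready are hypothetical names, and block_on_ready is a toy executor that polls once with a no-op waker, which only works because none of these futures ever has to wait. The key line is the Rc::clone before async move: the boxed future then owns its handle, so it is 'static and can even outlive the Middleware that created it.

```rust
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// Similar in spirit to LocalBoxFuture<'static, T>: boxed, not Send,
// and (by default for Box<dyn ...>) 'static, so it may not borrow `self`.
type BoxedLocalFuture<T> = Pin<Box<dyn Future<Output = T>>>;

// Hypothetical inner service (stands in for the next middleware / handler).
struct Inner;

impl Inner {
    async fn call(&self, req: String) -> String {
        format!("handled: {req}")
    }
}

// Hypothetical async lookup, like get_some_data() in the article.
async fn get_some_data() -> String {
    "Data".into()
}

struct Middleware {
    service: Rc<Inner>,
}

impl Middleware {
    fn call(&self, req: String) -> BoxedLocalFuture<String> {
        // Cheap clone: bumps the reference count, gives the block an owned handle.
        let service = Rc::clone(&self.service);
        Box::pin(async move {
            let user = get_some_data().await;
            service.call(format!("{req} (user={user})")).await
        })
    }
}

// Toy single-poll executor built on a no-op waker.
unsafe fn noop_clone(_: *const ()) -> RawWaker {
    noop_raw_waker()
}
unsafe fn noop(_: *const ()) {}
static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

fn noop_raw_waker() -> RawWaker {
    RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
}

fn block_on_ready<F: Future>(fut: F) -> F::Output {
    // SAFETY: every vtable function is a no-op and the data pointer is never read.
    let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    match fut.as_mut().poll(&mut cx) {
        Poll::Ready(out) => out,
        Poll::Pending => panic!("future was not immediately ready"),
    }
}
```

For example, you can create the future with mw.call("GET /".to_string()), drop mw, and still run the future to get "handled: GET / (user=Data)". Replacing the Rc::clone with a direct use of self.service inside the block brings back the "lifetime may not live long enough" error.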

4. Is there a simpler way? This is too verbose.

Yes, there is.

Fortunately, there is another way to create middleware: from_fn, provided by the actix-web-lab crate (imported as actix_web_lab).

With from_fn() you write a plain async function that receives the request and a Next handle to the rest of the chain, and turn it into middleware, instead of defining a standard middleware type.

use actix_web::{
    body::MessageBody,
    dev::{ServiceRequest, ServiceResponse},
    Error,
};
use actix_web_lab::middleware::Next;

async fn check_auth_mw(
    req: ServiceRequest,
    next: Next<impl MessageBody>,
) -> Result<ServiceResponse<impl MessageBody>, Error> {
    // Do something with the request here
    next.call(req).await
}

Use it as a middleware:

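use actix_web::{App, HttpServer};
use actix_web_lab::middleware::from_fn;

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    // Wrap every route of the app with the function-based middleware
    HttpServer::new(|| App::new().wrap(from_fn(check_auth_mw)))
        .bind(("127.0.0.1", 8080))?
        .run()
        .await
}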

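It's much simpler and more elegant now. We can focus on the logic instead of wasting time on trait definitions.

Note: if you need to return different response types depending on a condition (for example, an early 401), EitherBody is still your savior: call map_into_left_body() or map_into_right_body() on each branch, just as in section 2.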
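The idea behind from_fn is easy to model: a middleware is just a function that receives the request plus a handle to the rest of the chain, and decides whether to call it. Here is a synchronous, std-only model of the auth check from section 2 written in that style; Request, Response, check_auth and handler are hypothetical names, and a plain FnOnce stands in for actix-web-lab's Next.

```rust
use std::collections::HashMap;

// Hypothetical, simplified request/response types (not actix-web's).
struct Request {
    headers: HashMap<String, String>,
    path: String,
}

#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// `next` stands in for the rest of the chain (other middleware + handler).
fn check_auth(req: Request, next: impl FnOnce(Request) -> Response) -> Response {
    if !req.headers.contains_key("authorization") {
        // Short-circuit: the handler never runs.
        return Response { status: 401, body: String::new() };
    }
    // Continue with the next middleware / handler.
    next(req)
}

// Hypothetical handler at the end of the chain.
fn handler(req: Request) -> Response {
    Response { status: 200, body: format!("hello from {}", req.path) }
}
```

All the logic lives in one function body, which is exactly what from_fn buys you over the Transform/Service pair.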


© 2016-2024  GinkCode.com