Skip to main content

trillium_api/
cancel_on_disconnect.rs

1use crate::TryFromConn;
2use std::{future::Future, marker::PhantomData, sync::Arc};
3use trillium::{Conn, Handler, Info, Status, Upgrade};
4
/// A [`Handler`](trillium::Handler) that runs an async function and cancels it if the client
/// disconnects before it completes.
///
/// The wrapped function is never given the [`Conn`](trillium::Conn) itself. Anything it needs
/// from the request must be extracted ahead of time through [`FromConn`](crate::FromConn) or
/// [`TryFromConn`](crate::TryFromConn) arguments.
///
/// Type parameters:
/// - `F`: the wrapped async function.
/// - `OH`: the handler that `F`'s future resolves to.
/// - `TFC`: the argument extracted from the conn and passed to `F`.
#[derive(Debug)]
pub struct CancelOnDisconnect<F, OH, TFC>(F, PhantomData<OH>, PhantomData<TFC>);
16
17impl<F, OH, TFC, Fut> CancelOnDisconnect<F, OH, TFC>
18where
19    F: Fn(TFC) -> Fut + Send + Sync + 'static,
20    Fut: Future<Output = OH> + Send + 'static,
21    OH: Handler,
22    TFC: TryFromConn,
23    TFC::Error: Handler,
24{
25    /// Construct a new CancelOnDisconnect handler
26    pub fn new(handler: F) -> Self {
27        CancelOnDisconnect(handler, PhantomData, PhantomData)
28    }
29}
30
31/// Construct a new [`CancelOnDisconnect`] handler.
32///
33/// Alias for [`CancelOnDisconnect::new`]
34pub fn cancel_on_disconnect<F, OH, TFC, Fut>(handler: F) -> CancelOnDisconnect<F, OH, TFC>
35where
36    F: Fn(TFC) -> Fut + Send + Sync + 'static,
37    Fut: Future<Output = OH> + Send + 'static,
38    OH: Handler,
39    TFC: TryFromConn,
40    TFC::Error: Handler,
41{
42    CancelOnDisconnect(handler, PhantomData, PhantomData)
43}
44
45impl<F, OutputHandler, TFC, Fut> Handler for CancelOnDisconnect<F, OutputHandler, TFC>
46where
47    F: Fn(TFC) -> Fut + Send + Sync + 'static,
48    Fut: Future<Output = OutputHandler> + Send + 'static,
49    OutputHandler: Handler,
50    TFC: TryFromConn,
51    TFC::Error: Handler,
52{
53    async fn run(&self, mut conn: Conn) -> Conn {
54        let mut output_handler: Result<OutputHandler, <TFC as TryFromConn>::Error> =
55            match TFC::try_from_conn(&mut conn).await {
56                Ok(extracted) => {
57                    let Some(ret) = conn.cancel_on_disconnect(self.0(extracted)).await else {
58                        log::info!("client disconnected");
59                        return conn;
60                    };
61                    Ok(ret)
62                }
63                Err(error_handler) => Err(error_handler),
64            };
65
66        if let Some(info) = conn.state_mut::<Info>() {
67            output_handler.init(info).await;
68        } else {
69            output_handler.init(&mut Info::default()).await;
70        }
71        let mut conn = output_handler.run(conn).await;
72        if conn.status().is_none() && conn.response_body().is_some() {
73            conn.set_status(Status::Ok);
74        }
75        conn.with_state(OutputHandlerWrapper(
76            Arc::new(output_handler),
77            PhantomData::<Self>,
78        ))
79    }
80
81    async fn before_send(&self, conn: Conn) -> Conn {
82        match conn
83            .state::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
84            .cloned()
85        {
86            Some(OutputHandlerWrapper(handler, _)) => handler.before_send(conn).await,
87            _ => conn,
88        }
89    }
90
91    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
92        upgrade
93            .state()
94            .get::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
95            .cloned()
96            .is_some_and(|OutputHandlerWrapper(handler, _)| handler.has_upgrade(upgrade))
97    }
98
99    async fn upgrade(&self, upgrade: Upgrade) {
100        if let Some(OutputHandlerWrapper(handler, _)) = upgrade
101            .state()
102            .get::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
103            .cloned()
104        {
105            handler.upgrade(upgrade).await
106        }
107    }
108}
109
/// Per-conn storage for the handler that a [`CancelOnDisconnect`] run produced.
///
/// - `Key` is a marker type (the owning handler type at the use site). It only exists so
///   that the type-keyed conn/upgrade state can tell different owners apart.
/// - The payload is either the output handler (`OH`) or the extraction error handler (`EH`).
///   It is shared behind an [`Arc`] so it can be cloned out of state cheaply.
struct OutputHandlerWrapper<Key, OH, EH>(Arc<Result<OH, EH>>, PhantomData<Key>);

// Hand-written instead of derived: `#[derive(Clone)]` would require `Key: Clone`,
// `OH: Clone` and `EH: Clone`, but a clone here only needs to bump the `Arc` refcount.
impl<Key, OH, EH> Clone for OutputHandlerWrapper<Key, OH, EH> {
    fn clone(&self) -> Self {
        let OutputHandlerWrapper(shared, marker) = self;
        OutputHandlerWrapper(Arc::clone(shared), *marker)
    }
}