// trillium_api/cancel_on_disconnect.rs
use crate::TryFromConn;
use std::{fmt, future::Future, marker::PhantomData, sync::Arc};
use trillium::{Conn, Handler, Info, Status, Upgrade};

/// A [`Handler`] that builds its real handler from an async function and gives up
/// if the client disconnects while that function is still running.
///
/// For each conn:
/// - A value is extracted with [`TryFromConn`]. If extraction fails, the extraction
///   error (itself a [`Handler`]) runs instead.
/// - Otherwise the wrapped function is called with the extracted value, and its
///   future is awaited through [`Conn::cancel_on_disconnect`].
/// - If that await yields nothing, the disconnect is logged and the conn is returned
///   without running any output handler.
///
/// Construct with [`CancelOnDisconnect::new`] or [`cancel_on_disconnect`].
pub struct CancelOnDisconnect<F, OH, TFC>(F, PhantomData<OH>, PhantomData<TFC>);

// Hand-written rather than derived. `#[derive(Debug)]` would require `F`, `OH` and
// `TFC` to be `Debug`. Closures never are, so the derived impl would almost never
// apply. This impl exists for every instantiation and prints the function's type name.
impl<F, OH, TFC> fmt::Debug for CancelOnDisconnect<F, OH, TFC> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("CancelOnDisconnect")
            .field(&format_args!("{}", std::any::type_name::<F>()))
            .finish()
    }
}

17impl<F, OH, TFC, Fut> CancelOnDisconnect<F, OH, TFC>
18where
19 F: Fn(TFC) -> Fut + Send + Sync + 'static,
20 Fut: Future<Output = OH> + Send + 'static,
21 OH: Handler,
22 TFC: TryFromConn,
23 TFC::Error: Handler,
24{
25 pub fn new(handler: F) -> Self {
27 CancelOnDisconnect(handler, PhantomData, PhantomData)
28 }
29}
30
31pub fn cancel_on_disconnect<F, OH, TFC, Fut>(handler: F) -> CancelOnDisconnect<F, OH, TFC>
35where
36 F: Fn(TFC) -> Fut + Send + Sync + 'static,
37 Fut: Future<Output = OH> + Send + 'static,
38 OH: Handler,
39 TFC: TryFromConn,
40 TFC::Error: Handler,
41{
42 CancelOnDisconnect(handler, PhantomData, PhantomData)
43}
44
45impl<F, OutputHandler, TFC, Fut> Handler for CancelOnDisconnect<F, OutputHandler, TFC>
46where
47 F: Fn(TFC) -> Fut + Send + Sync + 'static,
48 Fut: Future<Output = OutputHandler> + Send + 'static,
49 OutputHandler: Handler,
50 TFC: TryFromConn,
51 TFC::Error: Handler,
52{
53 async fn run(&self, mut conn: Conn) -> Conn {
54 let mut output_handler: Result<OutputHandler, <TFC as TryFromConn>::Error> =
55 match TFC::try_from_conn(&mut conn).await {
56 Ok(extracted) => {
57 let Some(ret) = conn.cancel_on_disconnect(self.0(extracted)).await else {
58 log::info!("client disconnected");
59 return conn;
60 };
61 Ok(ret)
62 }
63 Err(error_handler) => Err(error_handler),
64 };
65
66 if let Some(info) = conn.state_mut::<Info>() {
67 output_handler.init(info).await;
68 } else {
69 output_handler.init(&mut Info::default()).await;
70 }
71 let mut conn = output_handler.run(conn).await;
72 if conn.status().is_none() && conn.response_body().is_some() {
73 conn.set_status(Status::Ok);
74 }
75 conn.with_state(OutputHandlerWrapper(
76 Arc::new(output_handler),
77 PhantomData::<Self>,
78 ))
79 }
80
81 async fn before_send(&self, conn: Conn) -> Conn {
82 match conn
83 .state::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
84 .cloned()
85 {
86 Some(OutputHandlerWrapper(handler, _)) => handler.before_send(conn).await,
87 _ => conn,
88 }
89 }
90
91 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
92 upgrade
93 .state()
94 .get::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
95 .cloned()
96 .is_some_and(|OutputHandlerWrapper(handler, _)| handler.has_upgrade(upgrade))
97 }
98
99 async fn upgrade(&self, upgrade: Upgrade) {
100 if let Some(OutputHandlerWrapper(handler, _)) = upgrade
101 .state()
102 .get::<OutputHandlerWrapper<Self, OutputHandler, <TFC as TryFromConn>::Error>>()
103 .cloned()
104 {
105 handler.upgrade(upgrade).await
106 }
107 }
108}
109
/// Conn-state slot holding the handler a [`CancelOnDisconnect`] produced for one request.
///
/// The handler (or the extraction error that replaced it) sits behind an `Arc`, so
/// the lifecycle hooks can take cheap copies of it. `Key` is a type-level tag only.
/// The `Handler` impl uses its own type there, so separate handler types get
/// separate slots.
struct OutputHandlerWrapper<Key, OH, EH>(Arc<Result<OH, EH>>, PhantomData<Key>);

// Written by hand instead of derived. `#[derive(Clone)]` would demand `Key: Clone`,
// `OH: Clone` and `EH: Clone`, but cloning only needs to share the `Arc`.
impl<Key, OH, EH> Clone for OutputHandlerWrapper<Key, OH, EH> {
    fn clone(&self) -> Self {
        let Self(shared, _) = self;
        Self(Arc::clone(shared), PhantomData)
    }
}