Skip to main content

trillium/
init.rs

1use crate::{Conn, Handler, Info};
2use std::{future::Future, mem};
3
/// Provides support for asynchronous initialization of a handler after
/// the server is started.
///
/// The wrapped async function is consumed the first time [`Handler::init`]
/// runs. It receives the server's [`Info`] by value and must resolve to the
/// [`Info`] the server continues with, for example after registering shared
/// state with `with_shared_state` as shown below. Calling `init` again only
/// logs a warning. When handling requests, `Init` returns the [`Conn`]
/// unchanged.
///
/// ```
/// use trillium::{Conn, Init, State};
/// use trillium_testing::TestServer;
///
/// #[derive(Debug, Clone, PartialEq)]
/// struct MyDatabaseConnection(String);
/// impl MyDatabaseConnection {
///     async fn connect(uri: &str) -> std::io::Result<Self> {
///         Ok(Self(uri.into()))
///     }
///
///     async fn query(&self, query: &str) -> String {
///         format!("you queried `{}` against {}", query, &self.0)
///     }
/// }
///
/// # trillium_testing::block_on(async {
/// let handler = (
///     Init::new(|mut info| async move {
///         let db = MyDatabaseConnection::connect("db://db").await.expect("1");
///         info.with_shared_state(db)
///     }),
///     |conn: Conn| async move {
///         let db = conn.shared_state::<MyDatabaseConnection>().expect("2");
///         let response = db.query("select * from users limit 1").await;
///         conn.ok(response)
///     },
/// );
///
/// let app = TestServer::new(handler).await;
/// app.assert_shared_state(MyDatabaseConnection("db://db".into()));
/// app.get("/")
///     .await
///     .assert_ok()
///     .assert_body("you queried `select * from users limit 1` against db://db");
/// # });
/// ```
pub struct Init<F>(Option<F>);

// Implemented by hand rather than derived: `#[derive(Debug)]` would require
// `F: Debug`, which closures (the intended `F`) never satisfy, leaving `Init`
// unprintable in practice.
impl<F> std::fmt::Debug for Init<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Init")
            .field("init_fn", &std::any::type_name::<F>())
            // `true` until `Handler::init` takes the function out.
            .field("pending", &self.0.is_some())
            .finish()
    }
}
46
47impl<F, Fut> Init<F>
48where
49    F: FnOnce(Info) -> Fut + Send + Sync + 'static,
50    Fut: Future<Output = Info> + Send + 'static,
51{
52    /// Constructs a new Init handler with an async function that receives and returns [`Info`].
53    #[must_use]
54    pub const fn new(init: F) -> Self {
55        Self(Some(init))
56    }
57}
58
59impl<F, Fut> Handler for Init<F>
60where
61    F: FnOnce(Info) -> Fut + Send + Sync + 'static,
62    Fut: Future<Output = Info> + Send + 'static,
63{
64    async fn run(&self, conn: Conn) -> Conn {
65        conn
66    }
67
68    async fn init(&mut self, info: &mut Info) {
69        match self.0.take() {
70            Some(init) => {
71                *info = init(mem::take(info)).await;
72            }
73            _ => {
74                log::warn!("called init more than once");
75            }
76        }
77    }
78}
79
80/// alias for [`Init::new`]
81pub const fn init<F, Fut>(init: F) -> Init<F>
82where
83    F: FnOnce(Info) -> Fut + Send + Sync + 'static,
84    Fut: Future<Output = Info> + Send + 'static,
85{
86    Init::new(init)
87}