Skip to main content

trillium_conn_id/
lib.rs

//! Trillium crate to add identifiers to conns.
//!
//! This crate provides the following utilities:
//! - [`ConnId`]: a handler, which must be run for the rest of this crate to function
//! - [`log_formatter::conn_id`]: a formatter to use with trillium_logger
//!   (note that this does not depend on the trillium_logger crate and is very lightweight
//!   if you do not use that crate)
//! - [`ConnIdExt`]: an extension trait for retrieving the id from a conn
9#![forbid(unsafe_code)]
10#![deny(
11    missing_copy_implementations,
12    rustdoc::missing_crate_level_docs,
13    missing_debug_implementations,
14    missing_docs,
15    nonstandard_style,
16    unused_qualifications
17)]
18
// Attach README.md as the docs of an empty module so that its code examples
// are collected and run as doctests. rustdoc sets `cfg(doctest)` while
// collecting doctests; it does not set `cfg(test)`, so gating this module on
// `test` meant the README examples were never exercised.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
22
23use fastrand::Rng;
24use std::{
25    fmt::{Debug, Formatter, Result},
26    iter::repeat_with,
27    ops::Deref,
28    sync::{Arc, Mutex},
29};
30use trillium::{Conn, Handler, HeaderName, KnownHeaderName, TypeSet};
31
32#[derive(Default)]
33enum IdGenerator {
34    #[default]
35    Default,
36    SeededFastrand(Arc<Mutex<Rng>>),
37    Fn(Box<dyn Fn() -> String + Send + Sync + 'static>),
38}
39
40impl IdGenerator {
41    fn generate(&self) -> Id {
42        match self {
43            IdGenerator::Default => Id::default(),
44            IdGenerator::SeededFastrand(rng) => Id::with_rng(&mut rng.lock().unwrap()),
45            IdGenerator::Fn(gen_fun) => Id(gen_fun()),
46        }
47    }
48}
49
50/// Trillium handler to set a identifier for every Conn.
51///
52/// By default, it will use an inbound `x-request-id` request header or if
53/// that is missing, populate a ten character random id. This handler will
54/// set an outbound `x-request-id` header as well by default. All of this
55/// behavior can be customized with [`ConnId::with_request_header`],
56/// [`ConnId::with_response_header`] and [`ConnId::with_id_generator`]
57pub struct ConnId {
58    request_header: Option<HeaderName<'static>>,
59    response_header: Option<HeaderName<'static>>,
60    id_generator: IdGenerator,
61}
62
63impl Debug for ConnId {
64    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
65        f.debug_struct("ConnId")
66            .field("request_header", &self.request_header)
67            .field("response_header", &self.response_header)
68            .field("id_generator", &self.id_generator)
69            .finish()
70    }
71}
72
73impl Default for ConnId {
74    fn default() -> Self {
75        Self {
76            request_header: Some(KnownHeaderName::XrequestId.into()),
77            response_header: Some(KnownHeaderName::XrequestId.into()),
78            id_generator: Default::default(),
79        }
80    }
81}
82
83impl ConnId {
84    /// Constructs a new ConnId handler
85    /// ```
86    /// # use trillium_testing::TestServer;
87    /// # use trillium_conn_id::ConnId;
88    /// # trillium_testing::block_on(async {
89    /// let app = TestServer::new((ConnId::new().with_seed(1000), "ok")).await;
90    /// app.get("/")
91    ///     .await
92    ///     .assert_ok()
93    ///     .assert_body("ok")
94    ///     .assert_header("x-request-id", "J4lzoPXcT5");
95    ///
96    /// app.get("/")
97    ///     .with_request_header("x-request-id", "inbound")
98    ///     .await
99    ///     .assert_ok()
100    ///     .assert_header("x-request-id", "inbound");
101    /// # });
102    /// ```
103    pub fn new() -> Self {
104        Self::default()
105    }
106
107    /// Specifies a request header to use. If this header is provided on
108    /// the inbound request, the id will be used unmodified. To disable
109    /// this behavior, see [`ConnId::without_request_header`]
110    ///
111    /// ```
112    /// # use trillium_testing::TestServer;
113    /// # use trillium_conn_id::ConnId;
114    /// # trillium_testing::block_on(async {
115    /// let app = TestServer::new((ConnId::new().with_request_header("x-custom-id"), "ok")).await;
116    ///
117    /// app.get("/")
118    ///     .with_request_header("x-custom-id", "inbound")
119    ///     .await
120    ///     .assert_ok()
121    ///     .assert_header("x-request-id", "inbound");
122    /// # });
123    /// ```
124    pub fn with_request_header(mut self, request_header: impl Into<HeaderName<'static>>) -> Self {
125        self.request_header = Some(request_header.into());
126        self
127    }
128
129    /// disables the default behavior of reusing an inbound header for
130    /// the request id. If a ConnId is configured
131    /// `without_request_header`, a new id will always be generated
132    pub fn without_request_header(mut self) -> Self {
133        self.request_header = None;
134        self
135    }
136
137    /// Specifies a response header to set. To disable this behavior, see
138    /// [`ConnId::without_response_header`]
139    ///
140    /// ```
141    /// # use trillium_testing::TestServer;
142    /// # use trillium_conn_id::ConnId;
143    /// # trillium_testing::block_on(async {
144    /// let app = TestServer::new((
145    ///     ConnId::new()
146    ///         .with_seed(1000)
147    ///         .with_response_header("x-custom-header"),
148    ///     "ok",
149    /// ))
150    /// .await;
151    ///
152    /// app.get("/")
153    ///     .await
154    ///     .assert_ok()
155    ///     .assert_header("x-custom-header", "J4lzoPXcT5");
156    /// # });
157    /// ```
158    pub fn with_response_header(mut self, response_header: impl Into<HeaderName<'static>>) -> Self {
159        self.response_header = Some(response_header.into());
160        self
161    }
162
163    /// Disables the default behavior of sending the conn id as a response
164    /// header. A request id will be available within the application
165    /// through use of [`ConnIdExt`] but will not be sent as part of the
166    /// response.
167    pub fn without_response_header(mut self) -> Self {
168        self.response_header = None;
169        self
170    }
171
172    /// Provide an alternative generator function for ids. The default
173    /// is a ten-character alphanumeric random sequence.
174    ///
175    /// ```
176    /// # use trillium_testing::TestServer;
177    /// # use trillium_conn_id::ConnId;
178    /// # use uuid::Uuid;
179    /// # trillium_testing::block_on(async {
180    /// let app = TestServer::new((
181    ///     ConnId::new().with_id_generator(|| Uuid::new_v4().to_string()),
182    ///     "ok",
183    /// ))
184    /// .await;
185    ///
186    /// app.get("/")
187    ///     .await
188    ///     .assert_header_with("x-request-id", |header| {
189    ///         assert!(Uuid::parse_str(header.as_str().unwrap()).is_ok());
190    ///     });
191    /// # });
192    /// ```
193    pub fn with_id_generator<F>(mut self, id_generator: F) -> Self
194    where
195        F: Fn() -> String + Send + Sync + 'static,
196    {
197        self.id_generator = IdGenerator::Fn(Box::new(id_generator));
198        self
199    }
200
201    /// seed a shared rng
202    ///
203    /// this is primarily useful for tests
204    pub fn with_seed(mut self, seed: u64) -> Self {
205        self.id_generator = IdGenerator::SeededFastrand(Arc::new(Mutex::new(Rng::with_seed(seed))));
206        self
207    }
208
209    fn generate_id(&self) -> Id {
210        self.id_generator.generate()
211    }
212}
213
/// Newtype holding the id that the ConnId handler assigns to a conn.
#[derive(Clone, Debug, PartialEq)]
struct Id(String);

impl Deref for Id {
    type Target = str;

    /// Borrows the underlying id text.
    fn deref(&self) -> &Self::Target {
        self.0.as_str()
    }
}

impl std::fmt::Display for Id {
    /// Writes the raw id text. `write_str` is used directly, so width and
    /// alignment flags are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        let text: &str = &self.0;
        f.write_str(text)
    }
}
231
232impl Default for Id {
233    fn default() -> Self {
234        Self(repeat_with(fastrand::alphanumeric).take(10).collect())
235    }
236}
237
238impl Id {
239    fn with_rng(rng: &mut Rng) -> Self {
240        Self(repeat_with(|| rng.alphanumeric()).take(10).collect())
241    }
242}
243
244impl Handler for ConnId {
245    async fn run(&self, mut conn: Conn) -> Conn {
246        let id = self
247            .request_header
248            .as_ref()
249            .and_then(|request_header| conn.request_headers().get_str(request_header.clone()))
250            .map(|request_header| Id(request_header.to_string()))
251            .unwrap_or_else(|| self.generate_id());
252
253        if let Some(ref response_header) = self.response_header {
254            conn.response_headers_mut()
255                .insert(response_header.clone(), id.to_string());
256        }
257
258        conn.with_state(id)
259    }
260}
261
/// Extension trait to retrieve an id generated by the [`ConnId`] handler.
pub trait ConnIdExt {
    /// Retrieves the id for this conn.
    ///
    /// # Panics
    ///
    /// Panics if called before the [`ConnId`] handler has run on this conn.
    fn id(&self) -> &str;
}
268
269impl<ConnLike> ConnIdExt for ConnLike
270where
271    ConnLike: AsRef<TypeSet>,
272{
273    fn id(&self) -> &str {
274        self.as_ref()
275            .get::<Id>()
276            .expect("ConnId handler must be run before calling IdConnExt::id")
277    }
278}
279
280/// Formatter for the trillium_log crate
281pub mod log_formatter {
282    use super::*;
283    use std::borrow::Cow;
284    /// Formatter for the trillium_log crate. This will be `-` if
285    /// there is no id on the conn.
286    pub fn conn_id(conn: &Conn, _color: bool) -> Cow<'static, str> {
287        conn.state::<Id>()
288            .map(|id| Cow::Owned(id.0.clone()))
289            .unwrap_or_else(|| Cow::Borrowed("-"))
290    }
291}
292
293/// Alias for ConnId::new()
294pub fn conn_id() -> ConnId {
295    ConnId::new()
296}
297
298impl Debug for IdGenerator {
299    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
300        f.write_str(match self {
301            IdGenerator::Default => "IdGenerator::Default",
302            IdGenerator::SeededFastrand(_) => "IdGenerator::SeededFastrand",
303            IdGenerator::Fn(_) => "IdGenerator::Fn",
304        })
305    }
306}