// trillium_router/router.rs

1use crate::{CapturesNewType, RouteSpecNewType, RouterRef};
2use routefinder::{Match, RouteSpec, Router as Routefinder};
3use std::{
4    borrow::Cow,
5    collections::BTreeSet,
6    fmt::{self, Debug, Display, Formatter},
7    mem,
8};
9use trillium::{BoxedHandler, Conn, Handler, Info, KnownHeaderName, Method, Upgrade};
10
11const ALL_METHODS: [Method; 5] = [
12    Method::Delete,
13    Method::Get,
14    Method::Patch,
15    Method::Post,
16    Method::Put,
17];
18
19#[derive(Debug)]
20enum MethodSelection {
21    Just(Method),
22    All,
23    Any(Vec<Method>),
24}
25
26impl Display for MethodSelection {
27    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
28        match self {
29            MethodSelection::Just(m) => Display::fmt(m, f),
30            MethodSelection::All => f.write_str("*"),
31            MethodSelection::Any(v) => {
32                f.write_str(&v.iter().map(|m| m.as_ref()).collect::<Vec<_>>().join(", "))
33            }
34        }
35    }
36}
37
38impl PartialEq<Method> for MethodSelection {
39    fn eq(&self, other: &Method) -> bool {
40        match self {
41            MethodSelection::Just(m) => m == other,
42            MethodSelection::All => true,
43            MethodSelection::Any(v) => v.contains(other),
44        }
45    }
46}
47
48impl From<()> for MethodSelection {
49    fn from(_: ()) -> MethodSelection {
50        Self::All
51    }
52}
53
54impl From<Method> for MethodSelection {
55    fn from(method: Method) -> Self {
56        Self::Just(method)
57    }
58}
59
60impl From<&[Method]> for MethodSelection {
61    fn from(methods: &[Method]) -> Self {
62        Self::Any(methods.to_vec())
63    }
64}
65impl From<Vec<Method>> for MethodSelection {
66    fn from(methods: Vec<Method>) -> Self {
67        Self::Any(methods)
68    }
69}
70
71#[derive(Debug, Default)]
72struct MethodRoutefinder(Routefinder<(MethodSelection, BoxedHandler)>);
73impl MethodRoutefinder {
74    fn add<R>(
75        &mut self,
76        method_selection: impl Into<MethodSelection>,
77        path: R,
78        handler: impl Handler,
79    ) where
80        R: TryInto<RouteSpec>,
81        R::Error: Debug,
82    {
83        self.0
84            .add(path, (method_selection.into(), BoxedHandler::new(handler)))
85            .expect("could not add route")
86    }
87
88    fn methods_matching(&self, path: &str) -> BTreeSet<Method> {
89        let mut set = BTreeSet::new();
90
91        fn extend(ms: &MethodSelection, set: &mut BTreeSet<Method>) {
92            match ms {
93                MethodSelection::All => {
94                    set.extend(ALL_METHODS);
95                }
96                MethodSelection::Just(method) => {
97                    set.insert(*method);
98                }
99                MethodSelection::Any(methods) => {
100                    set.extend(methods);
101                }
102            }
103        }
104
105        if path == "*" {
106            for ms in self.0.iter().map(|(_, (m, _))| m) {
107                extend(ms, &mut set);
108            }
109        } else {
110            for m in self.0.match_iter(path) {
111                extend(&m.0, &mut set);
112            }
113        };
114
115        set.remove(&Method::Options);
116        set
117    }
118
119    fn best_match<'a, 'b>(
120        &'a self,
121        method: Method,
122        path: &'b str,
123    ) -> Option<Match<'a, 'b, (MethodSelection, BoxedHandler)>> {
124        self.0.match_iter(path).find(|m| m.0 == method)
125    }
126}
127
128/// # The Router handler
129///
130/// See crate level docs for more, as this is the primary type in this crate.
131pub struct Router {
132    routefinder: MethodRoutefinder,
133    handle_options: bool,
134    handle_method_not_allowed: bool,
135}
136
137impl Default for Router {
138    fn default() -> Self {
139        Self {
140            routefinder: MethodRoutefinder::default(),
141            handle_options: true,
142            handle_method_not_allowed: false,
143        }
144    }
145}
146
// Generates a chainable `Router` builder method named `$name` that registers a handler for
// `Method::$variant`. The two-argument form supplies default docs containing a doctest that
// exercises the generated method; the three-argument form takes the doc string explicitly.
macro_rules! method {
    ($name:ident, $variant:ident) => {
        method!(
            $name,
            $variant,
            // The generated docs double as a doctest for each http method.
            concat!(
                "Registers a handler for the ",
                stringify!($name),
                " http method.

```
# use trillium::Conn;
# use trillium_router::Router;
# use trillium_testing::TestServer;
# trillium_testing::block_on(async {
let router = Router::new().",
                stringify!($name),
                "(\"/some/route\", |conn: Conn| async move {
  conn.ok(\"success\")
});

let app = TestServer::new(router).await;
app.",
                stringify!($name),
                "(\"/some/route\").await
    .assert_ok()
    .assert_body(\"success\");
app.",
                stringify!($name),
                "(\"/other/route\").await
    .assert_status(404);
# });
```
"
            )
        );
    };

    ($name:ident, $variant:ident, $doc:expr_2021) => {
        #[doc = $doc]
        pub fn $name<R>(mut self, path: R, handler: impl Handler) -> Self
        where
            R: TryInto<RouteSpec>,
            R::Error: Debug,
        {
            self.add(path, Method::$variant, handler);
            self
        }
    };
}
198
199impl Router {
200    method!(get, Get);
201
202    method!(post, Post);
203
204    method!(put, Put);
205
206    method!(delete, Delete);
207
208    method!(patch, Patch);
209
210    /// Constructs a new Router. This is often used with [`Router::get`],
211    /// [`Router::post`], [`Router::put`], [`Router::delete`], and
212    /// [`Router::patch`] chainable methods to build up an application.
213    ///
214    /// For an alternative way of constructing a Router, see [`Router::build`]
215    ///
216    /// ```
217    /// # use trillium::Conn;
218    /// # use trillium_router::Router;
219    /// # use trillium_testing::TestServer;
220    ///
221    /// # trillium_testing::block_on(async {
222    /// let router = Router::new()
223    ///     .get("/", |conn: Conn| async move {
224    ///         conn.ok("you have reached the index")
225    ///     })
226    ///     .get("/some/:param", |conn: Conn| async move {
227    ///         conn.ok("you have reached /some/:param")
228    ///     })
229    ///     .post("/", |conn: Conn| async move { conn.ok("post!") });
230    ///
231    /// let app = TestServer::new(router).await;
232    /// app.get("/")
233    ///     .await
234    ///     .assert_ok()
235    ///     .assert_body("you have reached the index");
236    /// app.get("/some/route")
237    ///     .await
238    ///     .assert_ok()
239    ///     .assert_body("you have reached /some/:param");
240    /// app.post("/").await.assert_ok().assert_body("post!");
241    /// # });
242    /// ```
243    pub fn new() -> Self {
244        Self::default()
245    }
246
247    /// Disable the default behavior of responding to OPTIONS requests
248    /// with the supported methods at a given path
249    pub fn without_options_handling(mut self) -> Self {
250        self.set_options_handling(false);
251        self
252    }
253
254    /// enable or disable the router's behavior of responding to OPTIONS requests with the supported
255    /// methods at given path.
256    ///
257    /// default: enabled
258    pub(crate) fn set_options_handling(&mut self, options_enabled: bool) -> &mut Self {
259        self.handle_options = options_enabled;
260        self
261    }
262
263    /// Enable responding to a request whose path matches a route but whose method does not with
264    /// `405 Method Not Allowed` and an `Allow` header listing the methods that path does support.
265    ///
266    /// The status is set *without halting the conn*, so a subsequent handler is free to replace it;
267    /// the `405` only stands if nothing else handles the request. Without this enabled (the
268    /// default), a method mismatch falls through unchanged, which typically results in a `404`.
269    ///
270    /// This is opt-in because enabling it changes the response to requests that previously fell
271    /// through the router, and because advertising the supported methods reveals that the path
272    /// exists.
273    pub fn with_method_not_allowed(mut self) -> Self {
274        self.set_method_not_allowed(true);
275        self
276    }
277
278    /// enable or disable the router's behavior of responding to a method mismatch on a known path
279    /// with `405 Method Not Allowed` plus an `Allow` header.
280    ///
281    /// default: disabled
282    pub(crate) fn set_method_not_allowed(&mut self, enabled: bool) -> &mut Self {
283        self.handle_method_not_allowed = enabled;
284        self
285    }
286
287    /// Another way to build a router, if you don't like the chainable
288    /// interface described in [`Router::new`]. Note that the argument to
289    /// the closure is a [`RouterRef`].
290    ///
291    /// ```
292    /// # use trillium::Conn;
293    /// # use trillium_router::Router;
294    /// # use trillium_testing::TestServer;
295    /// # trillium_testing::block_on(async {
296    /// let router = Router::build(|mut router| {
297    ///     router.get("/", |conn: Conn| async move {
298    ///         conn.ok("you have reached the index")
299    ///     });
300    ///
301    ///     router.get("/some/:paramroute", |conn: Conn| async move {
302    ///         conn.ok("you have reached /some/:param")
303    ///     });
304    ///
305    ///     router.post("/", |conn: Conn| async move { conn.ok("post!") });
306    /// });
307    ///
308    /// let app = TestServer::new(router).await;
309    /// app.get("/")
310    ///     .await
311    ///     .assert_ok()
312    ///     .assert_body("you have reached the index");
313    /// app.get("/some/route")
314    ///     .await
315    ///     .assert_ok()
316    ///     .assert_body("you have reached /some/:param");
317    /// app.post("/").await.assert_ok().assert_body("post!");
318    /// # });
319    /// ```
320    pub fn build(builder: impl Fn(RouterRef)) -> Router {
321        let mut router = Router::new();
322        builder(RouterRef::new(&mut router));
323        router
324    }
325
326    fn best_match<'a, 'b>(
327        &'a self,
328        method: Method,
329        path: &'b str,
330    ) -> Option<Match<'a, 'b, (MethodSelection, BoxedHandler)>> {
331        self.routefinder.best_match(method, path)
332    }
333
334    /// Registers a handler for a method other than get, put, post, patch, or delete.
335    ///
336    /// ```
337    /// # use trillium::{Conn, Method};
338    /// # use trillium_router::Router;
339    /// # use trillium_testing::TestServer;
340    /// # trillium_testing::block_on(async {
341    /// let router = Router::new()
342    ///     .with_route("OPTIONS", "/some/route", |conn: Conn| async move {
343    ///         conn.ok("directly handling options")
344    ///     })
345    ///     .with_route(Method::Checkin, "/some/route", |conn: Conn| async move {
346    ///         conn.ok("checkin??")
347    ///     });
348    ///
349    /// let app = TestServer::new(router).await;
350    /// app.build(Method::Options, "/some/route")
351    ///     .await
352    ///     .assert_ok()
353    ///     .assert_body("directly handling options");
354    /// app.build(Method::Checkin, "/some/route")
355    ///     .await
356    ///     .assert_ok()
357    ///     .assert_body("checkin??");
358    /// # });
359    /// ```
360    pub fn with_route<M, R>(mut self, method: M, path: R, handler: impl Handler) -> Self
361    where
362        M: TryInto<Method>,
363        <M as TryInto<Method>>::Error: Debug,
364        R: TryInto<RouteSpec>,
365        R::Error: Debug,
366    {
367        self.add(path, method.try_into().unwrap(), handler);
368        self
369    }
370
371    pub(crate) fn add<R>(&mut self, path: R, method: Method, handler: impl Handler)
372    where
373        R: TryInto<RouteSpec>,
374        R::Error: Debug,
375    {
376        self.routefinder.add(method, path, handler);
377    }
378
379    pub(crate) fn add_any<R>(&mut self, methods: &[Method], path: R, handler: impl Handler)
380    where
381        R: TryInto<RouteSpec>,
382        R::Error: Debug,
383    {
384        self.routefinder.add(methods, path, handler)
385    }
386
387    pub(crate) fn add_all<R>(&mut self, path: R, handler: impl Handler)
388    where
389        R: TryInto<RouteSpec>,
390        R::Error: Debug,
391    {
392        self.routefinder.add((), path, handler);
393    }
394
395    /// Appends the handler to all (get, post, put, delete, and patch) methods.
396    /// ```
397    /// # use trillium::Conn;
398    /// # use trillium_router::Router;
399    /// # use trillium_testing::TestServer;
400    /// let router = Router::new().all("/any", |conn: Conn| async move {
401    ///     let response = format!("you made a {} request to /any", conn.method());
402    ///     conn.ok(response)
403    /// });
404    ///
405    /// # trillium_testing::block_on(async {
406    /// let app = TestServer::new(router).await;
407    ///
408    /// app.get("/any")
409    ///     .await
410    ///     .assert_ok()
411    ///     .assert_body("you made a GET request to /any");
412    ///
413    /// app.post("/any")
414    ///     .await
415    ///     .assert_ok()
416    ///     .assert_body("you made a POST request to /any");
417    ///
418    /// app.delete("/any")
419    ///     .await
420    ///     .assert_ok()
421    ///     .assert_body("you made a DELETE request to /any");
422    ///
423    /// app.patch("/any")
424    ///     .await
425    ///     .assert_ok()
426    ///     .assert_body("you made a PATCH request to /any");
427    ///
428    /// app.put("/any")
429    ///     .await
430    ///     .assert_ok()
431    ///     .assert_body("you made a PUT request to /any");
432    ///
433    /// app.get("/").await.assert_status(404);
434    /// # });
435    /// ```
436    pub fn all<R>(mut self, path: R, handler: impl Handler) -> Self
437    where
438        R: TryInto<RouteSpec>,
439        R::Error: Debug,
440    {
441        self.add_all(path, handler);
442        self
443    }
444
445    /// Appends the handler to each of the provided http methods.
446    /// ```
447    /// # use trillium::Conn;
448    /// # use trillium_router::Router;
449    /// # use trillium_testing::TestServer;
450    /// let router = Router::new().any(&["get", "post"], "/get_or_post", |conn: Conn| async move {
451    ///     let response = format!("you made a {} request to /get_or_post", conn.method());
452    ///     conn.ok(response)
453    /// });
454    ///
455    /// # trillium_testing::block_on(async {
456    /// let app = TestServer::new(router).await;
457    ///
458    /// app.get("/get_or_post")
459    ///     .await
460    ///     .assert_ok()
461    ///     .assert_body("you made a GET request to /get_or_post");
462    ///
463    /// app.post("/get_or_post")
464    ///     .await
465    ///     .assert_ok()
466    ///     .assert_body("you made a POST request to /get_or_post");
467    ///
468    /// app.delete("/any").await.assert_status(404);
469    /// app.patch("/any").await.assert_status(404);
470    /// app.put("/any").await.assert_status(404);
471    /// app.get("/").await.assert_status(404);
472    /// # });
473    /// ```
474    pub fn any<IntoMethod, R>(
475        mut self,
476        methods: &[IntoMethod],
477        path: R,
478        handler: impl Handler,
479    ) -> Self
480    where
481        IntoMethod: TryInto<Method> + Clone,
482        <IntoMethod as TryInto<Method>>::Error: Debug,
483        R: TryInto<RouteSpec>,
484        R::Error: Debug,
485    {
486        let methods = methods
487            .iter()
488            .cloned()
489            .map(|m| m.try_into().unwrap())
490            .collect::<Vec<_>>();
491        self.add_any(&methods, path, handler);
492        self
493    }
494}
495
496impl Handler for Router {
497    async fn run(&self, mut conn: Conn) -> Conn {
498        let method = conn.method();
499        let original_captures = conn.take_state();
500        let path = conn.path();
501        let mut has_path = false;
502
503        if let Some(m) = self.best_match(conn.method(), path) {
504            let mut captures = m.captures().into_owned();
505
506            let route = m.route().clone();
507
508            if let Some(CapturesNewType(mut original_captures)) = original_captures {
509                original_captures.append(captures);
510                captures = original_captures;
511            }
512
513            log::debug!("running {}: {}", m.route(), m.1.name());
514            let mut new_conn = m
515                .handler()
516                .1
517                .run({
518                    if let Some(wildcard) = captures.wildcard() {
519                        conn.push_path(String::from(wildcard));
520                        has_path = true;
521                    }
522
523                    conn.with_state(CapturesNewType(captures))
524                        .with_state(RouteSpecNewType(route))
525                })
526                .await;
527
528            if has_path {
529                new_conn.pop_path();
530            }
531            new_conn
532        } else if method == Method::Options && self.handle_options {
533            let allow = self
534                .routefinder
535                .methods_matching(path)
536                .iter()
537                .map(|m| m.as_ref())
538                .collect::<Vec<_>>()
539                .join(", ");
540
541            conn.with_response_header(KnownHeaderName::Allow, allow)
542                .with_status(200)
543                .halt()
544        } else if let Some(allow) = self
545            .handle_method_not_allowed
546            .then(|| self.routefinder.methods_matching(path))
547            .filter(|methods| !methods.is_empty())
548            .map(|methods| {
549                methods
550                    .iter()
551                    .map(|m| m.as_ref())
552                    .collect::<Vec<_>>()
553                    .join(", ")
554            })
555        {
556            // Soft default: set the status without halting, so a later handler can replace the
557            // 405. If nothing does, it stands as the fall-through response.
558            conn.with_response_header(KnownHeaderName::Allow, allow)
559                .with_status(405)
560        } else {
561            log::debug!("{} did not match any route", conn.path());
562            conn
563        }
564    }
565
566    async fn before_send(&self, conn: Conn) -> Conn {
567        let path = conn.path();
568        if let Some(m) = self.best_match(conn.method(), path) {
569            m.handler().1.before_send(conn).await
570        } else {
571            conn
572        }
573    }
574
575    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
576        if let Some(m) = self.best_match(upgrade.method(), upgrade.path()) {
577            m.1.has_upgrade(upgrade)
578        } else {
579            false
580        }
581    }
582
583    async fn upgrade(&self, upgrade: Upgrade) {
584        self.best_match(upgrade.method(), upgrade.path())
585            .unwrap()
586            .handler()
587            .1
588            .upgrade(upgrade)
589            .await
590    }
591
592    fn name(&self) -> Cow<'static, str> {
593        "Router".into()
594    }
595
596    async fn init(&mut self, info: &mut Info) {
597        // This code is not what a reader would expect, so here's a
598        // brief explanation:
599        //
600        // Mutable map iterators are not Send, and because we need to hold that data across await
601        // boundaries in a Send future, we cannot mutate in place.
602        //
603        // However, because this is only called once at app boot, and because we have &mut self, it
604        // is safe to move the router contents into this future and then replace it, and the
605        // performance impacts of doing so are unimportant as it is part of app boot.
606        let routefinder = mem::take(&mut self.routefinder);
607        for (route, (methods, mut handler)) in routefinder.0 {
608            handler.init(info).await;
609            self.routefinder.add(methods, route, handler);
610        }
611    }
612}
613
614impl Debug for Router {
615    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
616        f.write_str("Router ")?;
617        let mut set = f.debug_set();
618
619        for (route, (methods, handler)) in &self.routefinder.0 {
620            set.entry(&format_args!("{} {} -> {}", methods, route, handler.name()));
621        }
622        set.finish()
623    }
624}