Skip to main content

trillium_router/
router.rs

1use crate::{CapturesNewType, RouteSpecNewType, RouterRef};
2use routefinder::{Match, RouteSpec, Router as Routefinder};
3use std::{
4    borrow::Cow,
5    collections::BTreeSet,
6    fmt::{self, Debug, Display, Formatter},
7    mem,
8};
9use trillium::{BoxedHandler, Conn, Handler, Info, KnownHeaderName, Method, Upgrade};
10
11const ALL_METHODS: [Method; 5] = [
12    Method::Delete,
13    Method::Get,
14    Method::Patch,
15    Method::Post,
16    Method::Put,
17];
18
19#[derive(Debug)]
20enum MethodSelection {
21    Just(Method),
22    All,
23    Any(Vec<Method>),
24}
25
26impl Display for MethodSelection {
27    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
28        match self {
29            MethodSelection::Just(m) => Display::fmt(m, f),
30            MethodSelection::All => f.write_str("*"),
31            MethodSelection::Any(v) => {
32                f.write_str(&v.iter().map(|m| m.as_ref()).collect::<Vec<_>>().join(", "))
33            }
34        }
35    }
36}
37
38impl PartialEq<Method> for MethodSelection {
39    fn eq(&self, other: &Method) -> bool {
40        match self {
41            MethodSelection::Just(m) => m == other,
42            MethodSelection::All => true,
43            MethodSelection::Any(v) => v.contains(other),
44        }
45    }
46}
47
48impl From<()> for MethodSelection {
49    fn from(_: ()) -> MethodSelection {
50        Self::All
51    }
52}
53
54impl From<Method> for MethodSelection {
55    fn from(method: Method) -> Self {
56        Self::Just(method)
57    }
58}
59
60impl From<&[Method]> for MethodSelection {
61    fn from(methods: &[Method]) -> Self {
62        Self::Any(methods.to_vec())
63    }
64}
65impl From<Vec<Method>> for MethodSelection {
66    fn from(methods: Vec<Method>) -> Self {
67        Self::Any(methods)
68    }
69}
70
71#[derive(Debug, Default)]
72struct MethodRoutefinder(Routefinder<(MethodSelection, BoxedHandler)>);
73impl MethodRoutefinder {
74    fn add<R>(
75        &mut self,
76        method_selection: impl Into<MethodSelection>,
77        path: R,
78        handler: impl Handler,
79    ) where
80        R: TryInto<RouteSpec>,
81        R::Error: Debug,
82    {
83        self.0
84            .add(path, (method_selection.into(), BoxedHandler::new(handler)))
85            .expect("could not add route")
86    }
87
88    fn methods_matching(&self, path: &str) -> BTreeSet<Method> {
89        let mut set = BTreeSet::new();
90
91        fn extend(ms: &MethodSelection, set: &mut BTreeSet<Method>) {
92            match ms {
93                MethodSelection::All => {
94                    set.extend(ALL_METHODS);
95                }
96                MethodSelection::Just(method) => {
97                    set.insert(*method);
98                }
99                MethodSelection::Any(methods) => {
100                    set.extend(methods);
101                }
102            }
103        }
104
105        if path == "*" {
106            for ms in self.0.iter().map(|(_, (m, _))| m) {
107                extend(ms, &mut set);
108            }
109        } else {
110            for m in self.0.match_iter(path) {
111                extend(&m.0, &mut set);
112            }
113        };
114
115        set.remove(&Method::Options);
116        set
117    }
118
119    fn best_match<'a, 'b>(
120        &'a self,
121        method: Method,
122        path: &'b str,
123    ) -> Option<Match<'a, 'b, (MethodSelection, BoxedHandler)>> {
124        self.0.match_iter(path).find(|m| m.0 == method)
125    }
126}
127
128/// # The Router handler
129///
130/// See crate level docs for more, as this is the primary type in this crate.
131pub struct Router {
132    routefinder: MethodRoutefinder,
133    handle_options: bool,
134}
135
136impl Default for Router {
137    fn default() -> Self {
138        Self {
139            routefinder: MethodRoutefinder::default(),
140            handle_options: true,
141        }
142    }
143}
144
macro_rules! method {
    // Two-argument form: builds the standard doc comment for `$name`, including a doctest that
    // registers a route and exercises it through `TestServer`. It then forwards to the
    // three-argument form below.
    ($name:ident, $variant:ident) => {
        method!(
            $name,
            $variant,
            concat!(
                // yep, macro-generated doctests
                "Registers a handler for the ",
                stringify!($name),
                " http method.

```
# use trillium::Conn;
# use trillium_router::Router;
# use trillium_testing::TestServer;
# trillium_testing::block_on(async {
let router = Router::new().",
                stringify!($name),
                "(\"/some/route\", |conn: Conn| async move {
  conn.ok(\"success\")
});

let app = TestServer::new(router).await;
app.",
                stringify!($name),
                "(\"/some/route\").await
    .assert_ok()
    .assert_body(\"success\");
app.",
                stringify!($name),
                "(\"/other/route\").await
    .assert_status(404);
# });
```
"
            )
        );
    };

    // Three-argument form: defines a chainable method named `$name` that registers `handler`
    // at `path` for `Method::$variant` via `Router::add`, documented with `$doc`.
    ($name:ident, $variant:ident, $doc:expr_2021) => {
        #[doc = $doc]
        pub fn $name<R>(mut self, path: R, handler: impl Handler) -> Self
        where
            R: TryInto<RouteSpec>,
            R::Error: Debug,
        {
            self.add(path, Method::$variant, handler);
            self
        }
    };
}
196
197impl Router {
198    method!(get, Get);
199
200    method!(post, Post);
201
202    method!(put, Put);
203
204    method!(delete, Delete);
205
206    method!(patch, Patch);
207
208    /// Constructs a new Router. This is often used with [`Router::get`],
209    /// [`Router::post`], [`Router::put`], [`Router::delete`], and
210    /// [`Router::patch`] chainable methods to build up an application.
211    ///
212    /// For an alternative way of constructing a Router, see [`Router::build`]
213    ///
214    /// ```
215    /// # use trillium::Conn;
216    /// # use trillium_router::Router;
217    /// # use trillium_testing::TestServer;
218    ///
219    /// # trillium_testing::block_on(async {
220    /// let router = Router::new()
221    ///     .get("/", |conn: Conn| async move {
222    ///         conn.ok("you have reached the index")
223    ///     })
224    ///     .get("/some/:param", |conn: Conn| async move {
225    ///         conn.ok("you have reached /some/:param")
226    ///     })
227    ///     .post("/", |conn: Conn| async move { conn.ok("post!") });
228    ///
229    /// let app = TestServer::new(router).await;
230    /// app.get("/")
231    ///     .await
232    ///     .assert_ok()
233    ///     .assert_body("you have reached the index");
234    /// app.get("/some/route")
235    ///     .await
236    ///     .assert_ok()
237    ///     .assert_body("you have reached /some/:param");
238    /// app.post("/").await.assert_ok().assert_body("post!");
239    /// # });
240    /// ```
241    pub fn new() -> Self {
242        Self::default()
243    }
244
245    /// Disable the default behavior of responding to OPTIONS requests
246    /// with the supported methods at a given path
247    pub fn without_options_handling(mut self) -> Self {
248        self.set_options_handling(false);
249        self
250    }
251
252    /// enable or disable the router's behavior of responding to OPTIONS requests with the supported
253    /// methods at given path.
254    ///
255    /// default: enabled
256    pub(crate) fn set_options_handling(&mut self, options_enabled: bool) -> &mut Self {
257        self.handle_options = options_enabled;
258        self
259    }
260
261    /// Another way to build a router, if you don't like the chainable
262    /// interface described in [`Router::new`]. Note that the argument to
263    /// the closure is a [`RouterRef`].
264    ///
265    /// ```
266    /// # use trillium::Conn;
267    /// # use trillium_router::Router;
268    /// # use trillium_testing::TestServer;
269    /// # trillium_testing::block_on(async {
270    /// let router = Router::build(|mut router| {
271    ///     router.get("/", |conn: Conn| async move {
272    ///         conn.ok("you have reached the index")
273    ///     });
274    ///
275    ///     router.get("/some/:paramroute", |conn: Conn| async move {
276    ///         conn.ok("you have reached /some/:param")
277    ///     });
278    ///
279    ///     router.post("/", |conn: Conn| async move { conn.ok("post!") });
280    /// });
281    ///
282    /// let app = TestServer::new(router).await;
283    /// app.get("/")
284    ///     .await
285    ///     .assert_ok()
286    ///     .assert_body("you have reached the index");
287    /// app.get("/some/route")
288    ///     .await
289    ///     .assert_ok()
290    ///     .assert_body("you have reached /some/:param");
291    /// app.post("/").await.assert_ok().assert_body("post!");
292    /// # });
293    /// ```
294    pub fn build(builder: impl Fn(RouterRef)) -> Router {
295        let mut router = Router::new();
296        builder(RouterRef::new(&mut router));
297        router
298    }
299
300    fn best_match<'a, 'b>(
301        &'a self,
302        method: Method,
303        path: &'b str,
304    ) -> Option<Match<'a, 'b, (MethodSelection, BoxedHandler)>> {
305        self.routefinder.best_match(method, path)
306    }
307
308    /// Registers a handler for a method other than get, put, post, patch, or delete.
309    ///
310    /// ```
311    /// # use trillium::{Conn, Method};
312    /// # use trillium_router::Router;
313    /// # use trillium_testing::TestServer;
314    /// # trillium_testing::block_on(async {
315    /// let router = Router::new()
316    ///     .with_route("OPTIONS", "/some/route", |conn: Conn| async move {
317    ///         conn.ok("directly handling options")
318    ///     })
319    ///     .with_route(Method::Checkin, "/some/route", |conn: Conn| async move {
320    ///         conn.ok("checkin??")
321    ///     });
322    ///
323    /// let app = TestServer::new(router).await;
324    /// app.build(Method::Options, "/some/route")
325    ///     .await
326    ///     .assert_ok()
327    ///     .assert_body("directly handling options");
328    /// app.build(Method::Checkin, "/some/route")
329    ///     .await
330    ///     .assert_ok()
331    ///     .assert_body("checkin??");
332    /// # });
333    /// ```
334    pub fn with_route<M, R>(mut self, method: M, path: R, handler: impl Handler) -> Self
335    where
336        M: TryInto<Method>,
337        <M as TryInto<Method>>::Error: Debug,
338        R: TryInto<RouteSpec>,
339        R::Error: Debug,
340    {
341        self.add(path, method.try_into().unwrap(), handler);
342        self
343    }
344
345    pub(crate) fn add<R>(&mut self, path: R, method: Method, handler: impl Handler)
346    where
347        R: TryInto<RouteSpec>,
348        R::Error: Debug,
349    {
350        self.routefinder.add(method, path, handler);
351    }
352
353    pub(crate) fn add_any<R>(&mut self, methods: &[Method], path: R, handler: impl Handler)
354    where
355        R: TryInto<RouteSpec>,
356        R::Error: Debug,
357    {
358        self.routefinder.add(methods, path, handler)
359    }
360
361    pub(crate) fn add_all<R>(&mut self, path: R, handler: impl Handler)
362    where
363        R: TryInto<RouteSpec>,
364        R::Error: Debug,
365    {
366        self.routefinder.add((), path, handler);
367    }
368
369    /// Appends the handler to all (get, post, put, delete, and patch) methods.
370    /// ```
371    /// # use trillium::Conn;
372    /// # use trillium_router::Router;
373    /// # use trillium_testing::TestServer;
374    /// let router = Router::new().all("/any", |conn: Conn| async move {
375    ///     let response = format!("you made a {} request to /any", conn.method());
376    ///     conn.ok(response)
377    /// });
378    ///
379    /// # trillium_testing::block_on(async {
380    /// let app = TestServer::new(router).await;
381    ///
382    /// app.get("/any")
383    ///     .await
384    ///     .assert_ok()
385    ///     .assert_body("you made a GET request to /any");
386    ///
387    /// app.post("/any")
388    ///     .await
389    ///     .assert_ok()
390    ///     .assert_body("you made a POST request to /any");
391    ///
392    /// app.delete("/any")
393    ///     .await
394    ///     .assert_ok()
395    ///     .assert_body("you made a DELETE request to /any");
396    ///
397    /// app.patch("/any")
398    ///     .await
399    ///     .assert_ok()
400    ///     .assert_body("you made a PATCH request to /any");
401    ///
402    /// app.put("/any")
403    ///     .await
404    ///     .assert_ok()
405    ///     .assert_body("you made a PUT request to /any");
406    ///
407    /// app.get("/").await.assert_status(404);
408    /// # });
409    /// ```
410    pub fn all<R>(mut self, path: R, handler: impl Handler) -> Self
411    where
412        R: TryInto<RouteSpec>,
413        R::Error: Debug,
414    {
415        self.add_all(path, handler);
416        self
417    }
418
419    /// Appends the handler to each of the provided http methods.
420    /// ```
421    /// # use trillium::Conn;
422    /// # use trillium_router::Router;
423    /// # use trillium_testing::TestServer;
424    /// let router = Router::new().any(&["get", "post"], "/get_or_post", |conn: Conn| async move {
425    ///     let response = format!("you made a {} request to /get_or_post", conn.method());
426    ///     conn.ok(response)
427    /// });
428    ///
429    /// # trillium_testing::block_on(async {
430    /// let app = TestServer::new(router).await;
431    ///
432    /// app.get("/get_or_post")
433    ///     .await
434    ///     .assert_ok()
435    ///     .assert_body("you made a GET request to /get_or_post");
436    ///
437    /// app.post("/get_or_post")
438    ///     .await
439    ///     .assert_ok()
440    ///     .assert_body("you made a POST request to /get_or_post");
441    ///
442    /// app.delete("/any").await.assert_status(404);
443    /// app.patch("/any").await.assert_status(404);
444    /// app.put("/any").await.assert_status(404);
445    /// app.get("/").await.assert_status(404);
446    /// # });
447    /// ```
448    pub fn any<IntoMethod, R>(
449        mut self,
450        methods: &[IntoMethod],
451        path: R,
452        handler: impl Handler,
453    ) -> Self
454    where
455        IntoMethod: TryInto<Method> + Clone,
456        <IntoMethod as TryInto<Method>>::Error: Debug,
457        R: TryInto<RouteSpec>,
458        R::Error: Debug,
459    {
460        let methods = methods
461            .iter()
462            .cloned()
463            .map(|m| m.try_into().unwrap())
464            .collect::<Vec<_>>();
465        self.add_any(&methods, path, handler);
466        self
467    }
468}
469
470impl Handler for Router {
471    async fn run(&self, mut conn: Conn) -> Conn {
472        let method = conn.method();
473        let original_captures = conn.take_state();
474        let path = conn.path();
475        let mut has_path = false;
476
477        if let Some(m) = self.best_match(conn.method(), path) {
478            let mut captures = m.captures().into_owned();
479
480            let route = m.route().clone();
481
482            if let Some(CapturesNewType(mut original_captures)) = original_captures {
483                original_captures.append(captures);
484                captures = original_captures;
485            }
486
487            log::debug!("running {}: {}", m.route(), m.1.name());
488            let mut new_conn = m
489                .handler()
490                .1
491                .run({
492                    if let Some(wildcard) = captures.wildcard() {
493                        conn.push_path(String::from(wildcard));
494                        has_path = true;
495                    }
496
497                    conn.with_state(CapturesNewType(captures))
498                        .with_state(RouteSpecNewType(route))
499                })
500                .await;
501
502            if has_path {
503                new_conn.pop_path();
504            }
505            new_conn
506        } else if method == Method::Options && self.handle_options {
507            let allow = self
508                .routefinder
509                .methods_matching(path)
510                .iter()
511                .map(|m| m.as_ref())
512                .collect::<Vec<_>>()
513                .join(", ");
514
515            conn.with_response_header(KnownHeaderName::Allow, allow)
516                .with_status(200)
517                .halt()
518        } else {
519            log::debug!("{} did not match any route", conn.path());
520            conn
521        }
522    }
523
524    async fn before_send(&self, conn: Conn) -> Conn {
525        let path = conn.path();
526        if let Some(m) = self.best_match(conn.method(), path) {
527            m.handler().1.before_send(conn).await
528        } else {
529            conn
530        }
531    }
532
533    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
534        if let Some(m) = self.best_match(upgrade.method(), upgrade.path()) {
535            m.1.has_upgrade(upgrade)
536        } else {
537            false
538        }
539    }
540
541    async fn upgrade(&self, upgrade: Upgrade) {
542        self.best_match(upgrade.method(), upgrade.path())
543            .unwrap()
544            .handler()
545            .1
546            .upgrade(upgrade)
547            .await
548    }
549
550    fn name(&self) -> Cow<'static, str> {
551        "Router".into()
552    }
553
554    async fn init(&mut self, info: &mut Info) {
555        // This code is not what a reader would expect, so here's a
556        // brief explanation:
557        //
558        // Mutable map iterators are not Send, and because we need to hold that data across await
559        // boundaries in a Send future, we cannot mutate in place.
560        //
561        // However, because this is only called once at app boot, and because we have &mut self, it
562        // is safe to move the router contents into this future and then replace it, and the
563        // performance impacts of doing so are unimportant as it is part of app boot.
564        let routefinder = mem::take(&mut self.routefinder);
565        for (route, (methods, mut handler)) in routefinder.0 {
566            handler.init(info).await;
567            self.routefinder.add(methods, route, handler);
568        }
569    }
570}
571
572impl Debug for Router {
573    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
574        f.write_str("Router ")?;
575        let mut set = f.debug_set();
576
577        for (route, (methods, handler)) in &self.routefinder.0 {
578            set.entry(&format_args!("{} {} -> {}", methods, route, handler.name()));
579        }
580        set.finish()
581    }
582}