trillium_router/router.rs
1use crate::{CapturesNewType, RouteSpecNewType, RouterRef};
2use routefinder::{Match, RouteSpec, Router as Routefinder};
3use std::{
4 borrow::Cow,
5 collections::BTreeSet,
6 fmt::{self, Debug, Display, Formatter},
7 mem,
8};
9use trillium::{BoxedHandler, Conn, Handler, Info, KnownHeaderName, Method, Upgrade};
10
11const ALL_METHODS: [Method; 5] = [
12 Method::Delete,
13 Method::Get,
14 Method::Patch,
15 Method::Post,
16 Method::Put,
17];
18
19#[derive(Debug)]
20enum MethodSelection {
21 Just(Method),
22 All,
23 Any(Vec<Method>),
24}
25
26impl Display for MethodSelection {
27 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
28 match self {
29 MethodSelection::Just(m) => Display::fmt(m, f),
30 MethodSelection::All => f.write_str("*"),
31 MethodSelection::Any(v) => {
32 f.write_str(&v.iter().map(|m| m.as_ref()).collect::<Vec<_>>().join(", "))
33 }
34 }
35 }
36}
37
38impl PartialEq<Method> for MethodSelection {
39 fn eq(&self, other: &Method) -> bool {
40 match self {
41 MethodSelection::Just(m) => m == other,
42 MethodSelection::All => true,
43 MethodSelection::Any(v) => v.contains(other),
44 }
45 }
46}
47
48impl From<()> for MethodSelection {
49 fn from(_: ()) -> MethodSelection {
50 Self::All
51 }
52}
53
54impl From<Method> for MethodSelection {
55 fn from(method: Method) -> Self {
56 Self::Just(method)
57 }
58}
59
60impl From<&[Method]> for MethodSelection {
61 fn from(methods: &[Method]) -> Self {
62 Self::Any(methods.to_vec())
63 }
64}
65impl From<Vec<Method>> for MethodSelection {
66 fn from(methods: Vec<Method>) -> Self {
67 Self::Any(methods)
68 }
69}
70
71#[derive(Debug, Default)]
72struct MethodRoutefinder(Routefinder<(MethodSelection, BoxedHandler)>);
73impl MethodRoutefinder {
74 fn add<R>(
75 &mut self,
76 method_selection: impl Into<MethodSelection>,
77 path: R,
78 handler: impl Handler,
79 ) where
80 R: TryInto<RouteSpec>,
81 R::Error: Debug,
82 {
83 self.0
84 .add(path, (method_selection.into(), BoxedHandler::new(handler)))
85 .expect("could not add route")
86 }
87
88 fn methods_matching(&self, path: &str) -> BTreeSet<Method> {
89 let mut set = BTreeSet::new();
90
91 fn extend(ms: &MethodSelection, set: &mut BTreeSet<Method>) {
92 match ms {
93 MethodSelection::All => {
94 set.extend(ALL_METHODS);
95 }
96 MethodSelection::Just(method) => {
97 set.insert(*method);
98 }
99 MethodSelection::Any(methods) => {
100 set.extend(methods);
101 }
102 }
103 }
104
105 if path == "*" {
106 for ms in self.0.iter().map(|(_, (m, _))| m) {
107 extend(ms, &mut set);
108 }
109 } else {
110 for m in self.0.match_iter(path) {
111 extend(&m.0, &mut set);
112 }
113 };
114
115 set.remove(&Method::Options);
116 set
117 }
118
119 fn best_match<'a, 'b>(
120 &'a self,
121 method: Method,
122 path: &'b str,
123 ) -> Option<Match<'a, 'b, (MethodSelection, BoxedHandler)>> {
124 self.0.match_iter(path).find(|m| m.0 == method)
125 }
126}
127
128/// # The Router handler
129///
130/// See crate level docs for more, as this is the primary type in this crate.
131pub struct Router {
132 routefinder: MethodRoutefinder,
133 handle_options: bool,
134 handle_method_not_allowed: bool,
135}
136
137impl Default for Router {
138 fn default() -> Self {
139 Self {
140 routefinder: MethodRoutefinder::default(),
141 handle_options: true,
142 handle_method_not_allowed: false,
143 }
144 }
145}
146
/// Generates a chainable `Router` builder method that registers a handler for one [`Method`]
/// variant.
///
/// `method!(name, Variant)` produces `pub fn name(self, path, handler) -> Self` with a generated
/// doc comment containing a doctest. `method!(name, Variant, doc)` uses the supplied doc string
/// instead.
macro_rules! method {
    ($name:ident, $variant:ident) => {
        method!(
            $name,
            $variant,
            concat!(
                // yep, macro-generated doctests
                "Registers a handler for the ",
                stringify!($name),
                " http method.

```
# use trillium::Conn;
# use trillium_router::Router;
# use trillium_testing::TestServer;
# trillium_testing::block_on(async {
let router = Router::new().",
                stringify!($name),
                "(\"/some/route\", |conn: Conn| async move {
    conn.ok(\"success\")
});

let app = TestServer::new(router).await;
app.",
                stringify!($name),
                "(\"/some/route\").await
    .assert_ok()
    .assert_body(\"success\");
app.",
                stringify!($name),
                "(\"/other/route\").await
    .assert_status(404);
# });
```
"
            )
        );
    };

    ($name:ident, $variant:ident, $docs:expr_2021) => {
        #[doc = $docs]
        pub fn $name<R>(mut self, path: R, handler: impl Handler) -> Self
        where
            R: TryInto<RouteSpec>,
            R::Error: Debug,
        {
            self.add(path, Method::$variant, handler);
            self
        }
    };
}
198
199impl Router {
200 method!(get, Get);
201
202 method!(post, Post);
203
204 method!(put, Put);
205
206 method!(delete, Delete);
207
208 method!(patch, Patch);
209
210 /// Constructs a new Router. This is often used with [`Router::get`],
211 /// [`Router::post`], [`Router::put`], [`Router::delete`], and
212 /// [`Router::patch`] chainable methods to build up an application.
213 ///
214 /// For an alternative way of constructing a Router, see [`Router::build`]
215 ///
216 /// ```
217 /// # use trillium::Conn;
218 /// # use trillium_router::Router;
219 /// # use trillium_testing::TestServer;
220 ///
221 /// # trillium_testing::block_on(async {
222 /// let router = Router::new()
223 /// .get("/", |conn: Conn| async move {
224 /// conn.ok("you have reached the index")
225 /// })
226 /// .get("/some/:param", |conn: Conn| async move {
227 /// conn.ok("you have reached /some/:param")
228 /// })
229 /// .post("/", |conn: Conn| async move { conn.ok("post!") });
230 ///
231 /// let app = TestServer::new(router).await;
232 /// app.get("/")
233 /// .await
234 /// .assert_ok()
235 /// .assert_body("you have reached the index");
236 /// app.get("/some/route")
237 /// .await
238 /// .assert_ok()
239 /// .assert_body("you have reached /some/:param");
240 /// app.post("/").await.assert_ok().assert_body("post!");
241 /// # });
242 /// ```
243 pub fn new() -> Self {
244 Self::default()
245 }
246
247 /// Disable the default behavior of responding to OPTIONS requests
248 /// with the supported methods at a given path
249 pub fn without_options_handling(mut self) -> Self {
250 self.set_options_handling(false);
251 self
252 }
253
254 /// enable or disable the router's behavior of responding to OPTIONS requests with the supported
255 /// methods at given path.
256 ///
257 /// default: enabled
258 pub(crate) fn set_options_handling(&mut self, options_enabled: bool) -> &mut Self {
259 self.handle_options = options_enabled;
260 self
261 }
262
263 /// Enable responding to a request whose path matches a route but whose method does not with
264 /// `405 Method Not Allowed` and an `Allow` header listing the methods that path does support.
265 ///
266 /// The status is set *without halting the conn*, so a subsequent handler is free to replace it;
267 /// the `405` only stands if nothing else handles the request. Without this enabled (the
268 /// default), a method mismatch falls through unchanged, which typically results in a `404`.
269 ///
270 /// This is opt-in because enabling it changes the response to requests that previously fell
271 /// through the router, and because advertising the supported methods reveals that the path
272 /// exists.
273 pub fn with_method_not_allowed(mut self) -> Self {
274 self.set_method_not_allowed(true);
275 self
276 }
277
278 /// enable or disable the router's behavior of responding to a method mismatch on a known path
279 /// with `405 Method Not Allowed` plus an `Allow` header.
280 ///
281 /// default: disabled
282 pub(crate) fn set_method_not_allowed(&mut self, enabled: bool) -> &mut Self {
283 self.handle_method_not_allowed = enabled;
284 self
285 }
286
287 /// Another way to build a router, if you don't like the chainable
288 /// interface described in [`Router::new`]. Note that the argument to
289 /// the closure is a [`RouterRef`].
290 ///
291 /// ```
292 /// # use trillium::Conn;
293 /// # use trillium_router::Router;
294 /// # use trillium_testing::TestServer;
295 /// # trillium_testing::block_on(async {
296 /// let router = Router::build(|mut router| {
297 /// router.get("/", |conn: Conn| async move {
298 /// conn.ok("you have reached the index")
299 /// });
300 ///
301 /// router.get("/some/:paramroute", |conn: Conn| async move {
302 /// conn.ok("you have reached /some/:param")
303 /// });
304 ///
305 /// router.post("/", |conn: Conn| async move { conn.ok("post!") });
306 /// });
307 ///
308 /// let app = TestServer::new(router).await;
309 /// app.get("/")
310 /// .await
311 /// .assert_ok()
312 /// .assert_body("you have reached the index");
313 /// app.get("/some/route")
314 /// .await
315 /// .assert_ok()
316 /// .assert_body("you have reached /some/:param");
317 /// app.post("/").await.assert_ok().assert_body("post!");
318 /// # });
319 /// ```
320 pub fn build(builder: impl Fn(RouterRef)) -> Router {
321 let mut router = Router::new();
322 builder(RouterRef::new(&mut router));
323 router
324 }
325
326 fn best_match<'a, 'b>(
327 &'a self,
328 method: Method,
329 path: &'b str,
330 ) -> Option<Match<'a, 'b, (MethodSelection, BoxedHandler)>> {
331 self.routefinder.best_match(method, path)
332 }
333
334 /// Registers a handler for a method other than get, put, post, patch, or delete.
335 ///
336 /// ```
337 /// # use trillium::{Conn, Method};
338 /// # use trillium_router::Router;
339 /// # use trillium_testing::TestServer;
340 /// # trillium_testing::block_on(async {
341 /// let router = Router::new()
342 /// .with_route("OPTIONS", "/some/route", |conn: Conn| async move {
343 /// conn.ok("directly handling options")
344 /// })
345 /// .with_route(Method::Checkin, "/some/route", |conn: Conn| async move {
346 /// conn.ok("checkin??")
347 /// });
348 ///
349 /// let app = TestServer::new(router).await;
350 /// app.build(Method::Options, "/some/route")
351 /// .await
352 /// .assert_ok()
353 /// .assert_body("directly handling options");
354 /// app.build(Method::Checkin, "/some/route")
355 /// .await
356 /// .assert_ok()
357 /// .assert_body("checkin??");
358 /// # });
359 /// ```
360 pub fn with_route<M, R>(mut self, method: M, path: R, handler: impl Handler) -> Self
361 where
362 M: TryInto<Method>,
363 <M as TryInto<Method>>::Error: Debug,
364 R: TryInto<RouteSpec>,
365 R::Error: Debug,
366 {
367 self.add(path, method.try_into().unwrap(), handler);
368 self
369 }
370
371 pub(crate) fn add<R>(&mut self, path: R, method: Method, handler: impl Handler)
372 where
373 R: TryInto<RouteSpec>,
374 R::Error: Debug,
375 {
376 self.routefinder.add(method, path, handler);
377 }
378
379 pub(crate) fn add_any<R>(&mut self, methods: &[Method], path: R, handler: impl Handler)
380 where
381 R: TryInto<RouteSpec>,
382 R::Error: Debug,
383 {
384 self.routefinder.add(methods, path, handler)
385 }
386
387 pub(crate) fn add_all<R>(&mut self, path: R, handler: impl Handler)
388 where
389 R: TryInto<RouteSpec>,
390 R::Error: Debug,
391 {
392 self.routefinder.add((), path, handler);
393 }
394
395 /// Appends the handler to all (get, post, put, delete, and patch) methods.
396 /// ```
397 /// # use trillium::Conn;
398 /// # use trillium_router::Router;
399 /// # use trillium_testing::TestServer;
400 /// let router = Router::new().all("/any", |conn: Conn| async move {
401 /// let response = format!("you made a {} request to /any", conn.method());
402 /// conn.ok(response)
403 /// });
404 ///
405 /// # trillium_testing::block_on(async {
406 /// let app = TestServer::new(router).await;
407 ///
408 /// app.get("/any")
409 /// .await
410 /// .assert_ok()
411 /// .assert_body("you made a GET request to /any");
412 ///
413 /// app.post("/any")
414 /// .await
415 /// .assert_ok()
416 /// .assert_body("you made a POST request to /any");
417 ///
418 /// app.delete("/any")
419 /// .await
420 /// .assert_ok()
421 /// .assert_body("you made a DELETE request to /any");
422 ///
423 /// app.patch("/any")
424 /// .await
425 /// .assert_ok()
426 /// .assert_body("you made a PATCH request to /any");
427 ///
428 /// app.put("/any")
429 /// .await
430 /// .assert_ok()
431 /// .assert_body("you made a PUT request to /any");
432 ///
433 /// app.get("/").await.assert_status(404);
434 /// # });
435 /// ```
436 pub fn all<R>(mut self, path: R, handler: impl Handler) -> Self
437 where
438 R: TryInto<RouteSpec>,
439 R::Error: Debug,
440 {
441 self.add_all(path, handler);
442 self
443 }
444
445 /// Appends the handler to each of the provided http methods.
446 /// ```
447 /// # use trillium::Conn;
448 /// # use trillium_router::Router;
449 /// # use trillium_testing::TestServer;
450 /// let router = Router::new().any(&["get", "post"], "/get_or_post", |conn: Conn| async move {
451 /// let response = format!("you made a {} request to /get_or_post", conn.method());
452 /// conn.ok(response)
453 /// });
454 ///
455 /// # trillium_testing::block_on(async {
456 /// let app = TestServer::new(router).await;
457 ///
458 /// app.get("/get_or_post")
459 /// .await
460 /// .assert_ok()
461 /// .assert_body("you made a GET request to /get_or_post");
462 ///
463 /// app.post("/get_or_post")
464 /// .await
465 /// .assert_ok()
466 /// .assert_body("you made a POST request to /get_or_post");
467 ///
468 /// app.delete("/any").await.assert_status(404);
469 /// app.patch("/any").await.assert_status(404);
470 /// app.put("/any").await.assert_status(404);
471 /// app.get("/").await.assert_status(404);
472 /// # });
473 /// ```
474 pub fn any<IntoMethod, R>(
475 mut self,
476 methods: &[IntoMethod],
477 path: R,
478 handler: impl Handler,
479 ) -> Self
480 where
481 IntoMethod: TryInto<Method> + Clone,
482 <IntoMethod as TryInto<Method>>::Error: Debug,
483 R: TryInto<RouteSpec>,
484 R::Error: Debug,
485 {
486 let methods = methods
487 .iter()
488 .cloned()
489 .map(|m| m.try_into().unwrap())
490 .collect::<Vec<_>>();
491 self.add_any(&methods, path, handler);
492 self
493 }
494}
495
496impl Handler for Router {
497 async fn run(&self, mut conn: Conn) -> Conn {
498 let method = conn.method();
499 let original_captures = conn.take_state();
500 let path = conn.path();
501 let mut has_path = false;
502
503 if let Some(m) = self.best_match(conn.method(), path) {
504 let mut captures = m.captures().into_owned();
505
506 let route = m.route().clone();
507
508 if let Some(CapturesNewType(mut original_captures)) = original_captures {
509 original_captures.append(captures);
510 captures = original_captures;
511 }
512
513 log::debug!("running {}: {}", m.route(), m.1.name());
514 let mut new_conn = m
515 .handler()
516 .1
517 .run({
518 if let Some(wildcard) = captures.wildcard() {
519 conn.push_path(String::from(wildcard));
520 has_path = true;
521 }
522
523 conn.with_state(CapturesNewType(captures))
524 .with_state(RouteSpecNewType(route))
525 })
526 .await;
527
528 if has_path {
529 new_conn.pop_path();
530 }
531 new_conn
532 } else if method == Method::Options && self.handle_options {
533 let allow = self
534 .routefinder
535 .methods_matching(path)
536 .iter()
537 .map(|m| m.as_ref())
538 .collect::<Vec<_>>()
539 .join(", ");
540
541 conn.with_response_header(KnownHeaderName::Allow, allow)
542 .with_status(200)
543 .halt()
544 } else if let Some(allow) = self
545 .handle_method_not_allowed
546 .then(|| self.routefinder.methods_matching(path))
547 .filter(|methods| !methods.is_empty())
548 .map(|methods| {
549 methods
550 .iter()
551 .map(|m| m.as_ref())
552 .collect::<Vec<_>>()
553 .join(", ")
554 })
555 {
556 // Soft default: set the status without halting, so a later handler can replace the
557 // 405. If nothing does, it stands as the fall-through response.
558 conn.with_response_header(KnownHeaderName::Allow, allow)
559 .with_status(405)
560 } else {
561 log::debug!("{} did not match any route", conn.path());
562 conn
563 }
564 }
565
566 async fn before_send(&self, conn: Conn) -> Conn {
567 let path = conn.path();
568 if let Some(m) = self.best_match(conn.method(), path) {
569 m.handler().1.before_send(conn).await
570 } else {
571 conn
572 }
573 }
574
575 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
576 if let Some(m) = self.best_match(upgrade.method(), upgrade.path()) {
577 m.1.has_upgrade(upgrade)
578 } else {
579 false
580 }
581 }
582
583 async fn upgrade(&self, upgrade: Upgrade) {
584 self.best_match(upgrade.method(), upgrade.path())
585 .unwrap()
586 .handler()
587 .1
588 .upgrade(upgrade)
589 .await
590 }
591
592 fn name(&self) -> Cow<'static, str> {
593 "Router".into()
594 }
595
596 async fn init(&mut self, info: &mut Info) {
597 // This code is not what a reader would expect, so here's a
598 // brief explanation:
599 //
600 // Mutable map iterators are not Send, and because we need to hold that data across await
601 // boundaries in a Send future, we cannot mutate in place.
602 //
603 // However, because this is only called once at app boot, and because we have &mut self, it
604 // is safe to move the router contents into this future and then replace it, and the
605 // performance impacts of doing so are unimportant as it is part of app boot.
606 let routefinder = mem::take(&mut self.routefinder);
607 for (route, (methods, mut handler)) in routefinder.0 {
608 handler.init(info).await;
609 self.routefinder.add(methods, route, handler);
610 }
611 }
612}
613
614impl Debug for Router {
615 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
616 f.write_str("Router ")?;
617 let mut set = f.debug_set();
618
619 for (route, (methods, handler)) in &self.routefinder.0 {
620 set.entry(&format_args!("{} {} -> {}", methods, route, handler.name()));
621 }
622 set.finish()
623 }
624}