trillium_router/router.rs
1use crate::{CapturesNewType, RouteSpecNewType, RouterRef};
2use routefinder::{Match, RouteSpec, Router as Routefinder};
3use std::{
4 borrow::Cow,
5 collections::BTreeSet,
6 fmt::{self, Debug, Display, Formatter},
7 mem,
8};
9use trillium::{BoxedHandler, Conn, Handler, Info, KnownHeaderName, Method, Upgrade};
10
11const ALL_METHODS: [Method; 5] = [
12 Method::Delete,
13 Method::Get,
14 Method::Patch,
15 Method::Post,
16 Method::Put,
17];
18
19#[derive(Debug)]
20enum MethodSelection {
21 Just(Method),
22 All,
23 Any(Vec<Method>),
24}
25
26impl Display for MethodSelection {
27 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
28 match self {
29 MethodSelection::Just(m) => Display::fmt(m, f),
30 MethodSelection::All => f.write_str("*"),
31 MethodSelection::Any(v) => {
32 f.write_str(&v.iter().map(|m| m.as_ref()).collect::<Vec<_>>().join(", "))
33 }
34 }
35 }
36}
37
38impl PartialEq<Method> for MethodSelection {
39 fn eq(&self, other: &Method) -> bool {
40 match self {
41 MethodSelection::Just(m) => m == other,
42 MethodSelection::All => true,
43 MethodSelection::Any(v) => v.contains(other),
44 }
45 }
46}
47
48impl From<()> for MethodSelection {
49 fn from(_: ()) -> MethodSelection {
50 Self::All
51 }
52}
53
54impl From<Method> for MethodSelection {
55 fn from(method: Method) -> Self {
56 Self::Just(method)
57 }
58}
59
60impl From<&[Method]> for MethodSelection {
61 fn from(methods: &[Method]) -> Self {
62 Self::Any(methods.to_vec())
63 }
64}
65impl From<Vec<Method>> for MethodSelection {
66 fn from(methods: Vec<Method>) -> Self {
67 Self::Any(methods)
68 }
69}
70
71#[derive(Debug, Default)]
72struct MethodRoutefinder(Routefinder<(MethodSelection, BoxedHandler)>);
73impl MethodRoutefinder {
74 fn add<R>(
75 &mut self,
76 method_selection: impl Into<MethodSelection>,
77 path: R,
78 handler: impl Handler,
79 ) where
80 R: TryInto<RouteSpec>,
81 R::Error: Debug,
82 {
83 self.0
84 .add(path, (method_selection.into(), BoxedHandler::new(handler)))
85 .expect("could not add route")
86 }
87
88 fn methods_matching(&self, path: &str) -> BTreeSet<Method> {
89 let mut set = BTreeSet::new();
90
91 fn extend(ms: &MethodSelection, set: &mut BTreeSet<Method>) {
92 match ms {
93 MethodSelection::All => {
94 set.extend(ALL_METHODS);
95 }
96 MethodSelection::Just(method) => {
97 set.insert(*method);
98 }
99 MethodSelection::Any(methods) => {
100 set.extend(methods);
101 }
102 }
103 }
104
105 if path == "*" {
106 for ms in self.0.iter().map(|(_, (m, _))| m) {
107 extend(ms, &mut set);
108 }
109 } else {
110 for m in self.0.match_iter(path) {
111 extend(&m.0, &mut set);
112 }
113 };
114
115 set.remove(&Method::Options);
116 set
117 }
118
119 fn best_match<'a, 'b>(
120 &'a self,
121 method: Method,
122 path: &'b str,
123 ) -> Option<Match<'a, 'b, (MethodSelection, BoxedHandler)>> {
124 self.0.match_iter(path).find(|m| m.0 == method)
125 }
126}
127
128/// # The Router handler
129///
130/// See crate level docs for more, as this is the primary type in this crate.
131pub struct Router {
132 routefinder: MethodRoutefinder,
133 handle_options: bool,
134}
135
136impl Default for Router {
137 fn default() -> Self {
138 Self {
139 routefinder: MethodRoutefinder::default(),
140 handle_options: true,
141 }
142 }
143}
144
// Generates one chainable `Router` builder method for an HTTP method.
//
// `method!(get, Get)` expands to `pub fn get<R>(mut self, path: R, handler: impl Handler) -> Self`,
// which calls `self.add(path, Method::Get, handler)` and returns the router.
macro_rules! method {
    // Two-argument form: synthesize a doc string (including a runnable doctest) from the
    // function name, then delegate to the three-argument form below.
    ($fn_name:ident, $method:ident) => {
        method!(
            $fn_name,
            $method,
            concat!(
                // yep, macro-generated doctests
                "Registers a handler for the ",
                stringify!($fn_name),
                " http method.

```
# use trillium::Conn;
# use trillium_router::Router;
# use trillium_testing::TestServer;
# trillium_testing::block_on(async {
let router = Router::new().",
                stringify!($fn_name),
                "(\"/some/route\", |conn: Conn| async move {
    conn.ok(\"success\")
});

let app = TestServer::new(router).await;
app.",
                stringify!($fn_name),
                "(\"/some/route\").await
    .assert_ok()
    .assert_body(\"success\");
app.",
                stringify!($fn_name),
                "(\"/other/route\").await
    .assert_status(404);
# });
```
"
            )
        );
    };

    // Three-argument form: emits the builder method itself, documented with `$doc_comment`.
    ($fn_name:ident, $method:ident, $doc_comment:expr_2021) => {
        #[doc = $doc_comment]
        pub fn $fn_name<R>(mut self, path: R, handler: impl Handler) -> Self
        where
            R: TryInto<RouteSpec>,
            R::Error: Debug,
        {
            self.add(path, Method::$method, handler);
            self
        }
    };
}
196
197impl Router {
198 method!(get, Get);
199
200 method!(post, Post);
201
202 method!(put, Put);
203
204 method!(delete, Delete);
205
206 method!(patch, Patch);
207
208 /// Constructs a new Router. This is often used with [`Router::get`],
209 /// [`Router::post`], [`Router::put`], [`Router::delete`], and
210 /// [`Router::patch`] chainable methods to build up an application.
211 ///
212 /// For an alternative way of constructing a Router, see [`Router::build`]
213 ///
214 /// ```
215 /// # use trillium::Conn;
216 /// # use trillium_router::Router;
217 /// # use trillium_testing::TestServer;
218 ///
219 /// # trillium_testing::block_on(async {
220 /// let router = Router::new()
221 /// .get("/", |conn: Conn| async move {
222 /// conn.ok("you have reached the index")
223 /// })
224 /// .get("/some/:param", |conn: Conn| async move {
225 /// conn.ok("you have reached /some/:param")
226 /// })
227 /// .post("/", |conn: Conn| async move { conn.ok("post!") });
228 ///
229 /// let app = TestServer::new(router).await;
230 /// app.get("/")
231 /// .await
232 /// .assert_ok()
233 /// .assert_body("you have reached the index");
234 /// app.get("/some/route")
235 /// .await
236 /// .assert_ok()
237 /// .assert_body("you have reached /some/:param");
238 /// app.post("/").await.assert_ok().assert_body("post!");
239 /// # });
240 /// ```
241 pub fn new() -> Self {
242 Self::default()
243 }
244
245 /// Disable the default behavior of responding to OPTIONS requests
246 /// with the supported methods at a given path
247 pub fn without_options_handling(mut self) -> Self {
248 self.set_options_handling(false);
249 self
250 }
251
252 /// enable or disable the router's behavior of responding to OPTIONS requests with the supported
253 /// methods at given path.
254 ///
255 /// default: enabled
256 pub(crate) fn set_options_handling(&mut self, options_enabled: bool) -> &mut Self {
257 self.handle_options = options_enabled;
258 self
259 }
260
261 /// Another way to build a router, if you don't like the chainable
262 /// interface described in [`Router::new`]. Note that the argument to
263 /// the closure is a [`RouterRef`].
264 ///
265 /// ```
266 /// # use trillium::Conn;
267 /// # use trillium_router::Router;
268 /// # use trillium_testing::TestServer;
269 /// # trillium_testing::block_on(async {
270 /// let router = Router::build(|mut router| {
271 /// router.get("/", |conn: Conn| async move {
272 /// conn.ok("you have reached the index")
273 /// });
274 ///
275 /// router.get("/some/:paramroute", |conn: Conn| async move {
276 /// conn.ok("you have reached /some/:param")
277 /// });
278 ///
279 /// router.post("/", |conn: Conn| async move { conn.ok("post!") });
280 /// });
281 ///
282 /// let app = TestServer::new(router).await;
283 /// app.get("/")
284 /// .await
285 /// .assert_ok()
286 /// .assert_body("you have reached the index");
287 /// app.get("/some/route")
288 /// .await
289 /// .assert_ok()
290 /// .assert_body("you have reached /some/:param");
291 /// app.post("/").await.assert_ok().assert_body("post!");
292 /// # });
293 /// ```
294 pub fn build(builder: impl Fn(RouterRef)) -> Router {
295 let mut router = Router::new();
296 builder(RouterRef::new(&mut router));
297 router
298 }
299
300 fn best_match<'a, 'b>(
301 &'a self,
302 method: Method,
303 path: &'b str,
304 ) -> Option<Match<'a, 'b, (MethodSelection, BoxedHandler)>> {
305 self.routefinder.best_match(method, path)
306 }
307
308 /// Registers a handler for a method other than get, put, post, patch, or delete.
309 ///
310 /// ```
311 /// # use trillium::{Conn, Method};
312 /// # use trillium_router::Router;
313 /// # use trillium_testing::TestServer;
314 /// # trillium_testing::block_on(async {
315 /// let router = Router::new()
316 /// .with_route("OPTIONS", "/some/route", |conn: Conn| async move {
317 /// conn.ok("directly handling options")
318 /// })
319 /// .with_route(Method::Checkin, "/some/route", |conn: Conn| async move {
320 /// conn.ok("checkin??")
321 /// });
322 ///
323 /// let app = TestServer::new(router).await;
324 /// app.build(Method::Options, "/some/route")
325 /// .await
326 /// .assert_ok()
327 /// .assert_body("directly handling options");
328 /// app.build(Method::Checkin, "/some/route")
329 /// .await
330 /// .assert_ok()
331 /// .assert_body("checkin??");
332 /// # });
333 /// ```
334 pub fn with_route<M, R>(mut self, method: M, path: R, handler: impl Handler) -> Self
335 where
336 M: TryInto<Method>,
337 <M as TryInto<Method>>::Error: Debug,
338 R: TryInto<RouteSpec>,
339 R::Error: Debug,
340 {
341 self.add(path, method.try_into().unwrap(), handler);
342 self
343 }
344
345 pub(crate) fn add<R>(&mut self, path: R, method: Method, handler: impl Handler)
346 where
347 R: TryInto<RouteSpec>,
348 R::Error: Debug,
349 {
350 self.routefinder.add(method, path, handler);
351 }
352
353 pub(crate) fn add_any<R>(&mut self, methods: &[Method], path: R, handler: impl Handler)
354 where
355 R: TryInto<RouteSpec>,
356 R::Error: Debug,
357 {
358 self.routefinder.add(methods, path, handler)
359 }
360
361 pub(crate) fn add_all<R>(&mut self, path: R, handler: impl Handler)
362 where
363 R: TryInto<RouteSpec>,
364 R::Error: Debug,
365 {
366 self.routefinder.add((), path, handler);
367 }
368
369 /// Appends the handler to all (get, post, put, delete, and patch) methods.
370 /// ```
371 /// # use trillium::Conn;
372 /// # use trillium_router::Router;
373 /// # use trillium_testing::TestServer;
374 /// let router = Router::new().all("/any", |conn: Conn| async move {
375 /// let response = format!("you made a {} request to /any", conn.method());
376 /// conn.ok(response)
377 /// });
378 ///
379 /// # trillium_testing::block_on(async {
380 /// let app = TestServer::new(router).await;
381 ///
382 /// app.get("/any")
383 /// .await
384 /// .assert_ok()
385 /// .assert_body("you made a GET request to /any");
386 ///
387 /// app.post("/any")
388 /// .await
389 /// .assert_ok()
390 /// .assert_body("you made a POST request to /any");
391 ///
392 /// app.delete("/any")
393 /// .await
394 /// .assert_ok()
395 /// .assert_body("you made a DELETE request to /any");
396 ///
397 /// app.patch("/any")
398 /// .await
399 /// .assert_ok()
400 /// .assert_body("you made a PATCH request to /any");
401 ///
402 /// app.put("/any")
403 /// .await
404 /// .assert_ok()
405 /// .assert_body("you made a PUT request to /any");
406 ///
407 /// app.get("/").await.assert_status(404);
408 /// # });
409 /// ```
410 pub fn all<R>(mut self, path: R, handler: impl Handler) -> Self
411 where
412 R: TryInto<RouteSpec>,
413 R::Error: Debug,
414 {
415 self.add_all(path, handler);
416 self
417 }
418
419 /// Appends the handler to each of the provided http methods.
420 /// ```
421 /// # use trillium::Conn;
422 /// # use trillium_router::Router;
423 /// # use trillium_testing::TestServer;
424 /// let router = Router::new().any(&["get", "post"], "/get_or_post", |conn: Conn| async move {
425 /// let response = format!("you made a {} request to /get_or_post", conn.method());
426 /// conn.ok(response)
427 /// });
428 ///
429 /// # trillium_testing::block_on(async {
430 /// let app = TestServer::new(router).await;
431 ///
432 /// app.get("/get_or_post")
433 /// .await
434 /// .assert_ok()
435 /// .assert_body("you made a GET request to /get_or_post");
436 ///
437 /// app.post("/get_or_post")
438 /// .await
439 /// .assert_ok()
440 /// .assert_body("you made a POST request to /get_or_post");
441 ///
442 /// app.delete("/any").await.assert_status(404);
443 /// app.patch("/any").await.assert_status(404);
444 /// app.put("/any").await.assert_status(404);
445 /// app.get("/").await.assert_status(404);
446 /// # });
447 /// ```
448 pub fn any<IntoMethod, R>(
449 mut self,
450 methods: &[IntoMethod],
451 path: R,
452 handler: impl Handler,
453 ) -> Self
454 where
455 IntoMethod: TryInto<Method> + Clone,
456 <IntoMethod as TryInto<Method>>::Error: Debug,
457 R: TryInto<RouteSpec>,
458 R::Error: Debug,
459 {
460 let methods = methods
461 .iter()
462 .cloned()
463 .map(|m| m.try_into().unwrap())
464 .collect::<Vec<_>>();
465 self.add_any(&methods, path, handler);
466 self
467 }
468}
469
470impl Handler for Router {
471 async fn run(&self, mut conn: Conn) -> Conn {
472 let method = conn.method();
473 let original_captures = conn.take_state();
474 let path = conn.path();
475 let mut has_path = false;
476
477 if let Some(m) = self.best_match(conn.method(), path) {
478 let mut captures = m.captures().into_owned();
479
480 let route = m.route().clone();
481
482 if let Some(CapturesNewType(mut original_captures)) = original_captures {
483 original_captures.append(captures);
484 captures = original_captures;
485 }
486
487 log::debug!("running {}: {}", m.route(), m.1.name());
488 let mut new_conn = m
489 .handler()
490 .1
491 .run({
492 if let Some(wildcard) = captures.wildcard() {
493 conn.push_path(String::from(wildcard));
494 has_path = true;
495 }
496
497 conn.with_state(CapturesNewType(captures))
498 .with_state(RouteSpecNewType(route))
499 })
500 .await;
501
502 if has_path {
503 new_conn.pop_path();
504 }
505 new_conn
506 } else if method == Method::Options && self.handle_options {
507 let allow = self
508 .routefinder
509 .methods_matching(path)
510 .iter()
511 .map(|m| m.as_ref())
512 .collect::<Vec<_>>()
513 .join(", ");
514
515 conn.with_response_header(KnownHeaderName::Allow, allow)
516 .with_status(200)
517 .halt()
518 } else {
519 log::debug!("{} did not match any route", conn.path());
520 conn
521 }
522 }
523
524 async fn before_send(&self, conn: Conn) -> Conn {
525 let path = conn.path();
526 if let Some(m) = self.best_match(conn.method(), path) {
527 m.handler().1.before_send(conn).await
528 } else {
529 conn
530 }
531 }
532
533 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
534 if let Some(m) = self.best_match(upgrade.method(), upgrade.path()) {
535 m.1.has_upgrade(upgrade)
536 } else {
537 false
538 }
539 }
540
541 async fn upgrade(&self, upgrade: Upgrade) {
542 self.best_match(upgrade.method(), upgrade.path())
543 .unwrap()
544 .handler()
545 .1
546 .upgrade(upgrade)
547 .await
548 }
549
550 fn name(&self) -> Cow<'static, str> {
551 "Router".into()
552 }
553
554 async fn init(&mut self, info: &mut Info) {
555 // This code is not what a reader would expect, so here's a
556 // brief explanation:
557 //
558 // Mutable map iterators are not Send, and because we need to hold that data across await
559 // boundaries in a Send future, we cannot mutate in place.
560 //
561 // However, because this is only called once at app boot, and because we have &mut self, it
562 // is safe to move the router contents into this future and then replace it, and the
563 // performance impacts of doing so are unimportant as it is part of app boot.
564 let routefinder = mem::take(&mut self.routefinder);
565 for (route, (methods, mut handler)) in routefinder.0 {
566 handler.init(info).await;
567 self.routefinder.add(methods, route, handler);
568 }
569 }
570}
571
572impl Debug for Router {
573 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
574 f.write_str("Router ")?;
575 let mut set = f.debug_set();
576
577 for (route, (methods, handler)) in &self.routefinder.0 {
578 set.entry(&format_args!("{} {} -> {}", methods, route, handler.name()));
579 }
580 set.finish()
581 }
582}