trillium_method_override/
//! A trillium handler that lets a `POST` request stand in for another HTTP method named
//! in its querystring.
//!
//! Clients that can only issue `GET` and `POST` (plain HTML forms, for example) can send
//! `POST /resource?_method=DELETE`. [`MethodOverride`] then rewrites the conn's request
//! method to `DELETE`.
//!
//! An override is applied only when all of the following hold:
//! - the request method is `POST`;
//! - the configured querystring parameter is present and names a recognizable method;
//! - that method is in the allowed set.
//!
//! By default the parameter is `_method` and the allowed set is `PUT`, `PATCH`, and
//! `DELETE`. Only the querystring is consulted; a `_method` field in a request body is not
//! read. In every other case the conn passes through unchanged.
//!
//! Both settings can be changed with [`MethodOverride::with_param_name`] and
//! [`MethodOverride::with_allowed_methods`].
#![forbid(unsafe_code)]
#![deny(
    missing_copy_implementations,
    rustdoc::missing_crate_level_docs,
    missing_debug_implementations,
    missing_docs,
    nonstandard_style,
    unused_qualifications
)]
26
// Attach the README to an item only while rustdoc is collecting doctests, so that the
// README's code examples are compiled and run by `cargo test --doc`.
// rustdoc sets `cfg(doctest)` (not `cfg(test)`) during that collection. A `cfg(test)` gate
// hides the item from rustdoc, so the examples would silently never run.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
30
31use querystrong::{IndexPath, QueryStrong};
32use std::{collections::HashSet, fmt::Debug};
33use trillium::{Conn, Handler, Method, Transport};
34
35#[derive(Clone, Debug)]
39pub struct MethodOverride {
40 param: IndexPath<'static>,
41 allowed_methods: HashSet<Method>,
42}
43
44impl Default for MethodOverride {
45 fn default() -> Self {
46 Self {
47 param: IndexPath::parse("_method").unwrap(),
48 allowed_methods: HashSet::from_iter([Method::Put, Method::Patch, Method::Delete]),
49 }
50 }
51}
52
53impl MethodOverride {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn with_allowed_methods<M>(mut self, methods: impl IntoIterator<Item = M>) -> Self
68 where
69 M: TryInto<Method>,
70 <M as TryInto<Method>>::Error: Debug,
71 {
72 self.allowed_methods = methods.into_iter().map(|m| m.try_into().unwrap()).collect();
73 self
74 }
75
76 pub fn with_param_name(mut self, param_name: &'static str) -> Self {
84 self.param = IndexPath::parse(param_name).unwrap();
85 self
86 }
87}
88
89impl Handler for MethodOverride {
90 async fn run(&self, mut conn: Conn) -> Conn {
91 if conn.method() == Method::Post
92 && let Some(method_str) =
93 QueryStrong::parse(conn.querystring()).get_str(self.param.clone())
94 && let Ok(method) = Method::try_from(method_str)
95 && self.allowed_methods.contains(&method)
96 {
97 let mut_conn: &mut trillium_http::Conn<Box<dyn Transport>> = conn.as_mut();
98 mut_conn.set_method(method);
99 }
100
101 conn
102 }
103}
104
105pub fn method_override() -> MethodOverride {
107 MethodOverride::new()
108}