1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/*!
# Trillium method override handler

This allows HTTP clients that are unable to generate HTTP methods
other than `GET` and `POST` to use `POST` requests that are
interpreted as other methods such as `PUT`, `PATCH`, or `DELETE`.

The override is currently read from a querystring parameter named
`_method`. To change the querystring parameter's name, use
[`MethodOverride::with_param_name`].

By default, the only methods allowed are `PUT`, `PATCH`, and
`DELETE`. To override this, use
[`MethodOverride::with_allowed_methods`].

Subsequent handlers see the requested method on the conn instead of
`POST`. Requests that are not `POST`, or whose parameter is missing,
unrecognized, or not in the allowed list, are passed through unchanged.
*/
#![forbid(unsafe_code)]
#![deny(
    missing_copy_implementations,
    rustdoc::missing_crate_level_docs,
    missing_debug_implementations,
    missing_docs,
    nonstandard_style,
    unused_qualifications
)]

use querystrong::QueryStrong;
use std::{collections::HashSet, convert::TryInto, fmt::Debug, iter::FromIterator};
use trillium::{async_trait, conn_unwrap, Conn, Handler, Method};

/**
Trillium handler that lets a `POST` request stand in for another HTTP method.

The replacement method is read from a querystring parameter (`_method`
unless configured otherwise) and is only applied when it is one of the
allowed methods.

See crate-level docs for an explanation
*/
#[derive(Clone, Debug)]
pub struct MethodOverride {
    /// Name of the querystring parameter that carries the requested method.
    param: &'static str,
    /// Methods that a `POST` conn may be rewritten to.
    allowed_methods: HashSet<Method>,
}

impl Default for MethodOverride {
    fn default() -> Self {
        Self {
            param: "_method",
            allowed_methods: HashSet::from_iter([Method::Put, Method::Patch, Method::Delete]),
        }
    }
}

impl MethodOverride {
    /// Constructs a new `MethodOverride` handler with the default allowed
    /// methods (`PUT`, `PATCH`, `DELETE`) and param name (`_method`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /**
    Replaces the default allowed methods with the provided list of methods.

    Each item is converted into a [`Method`] via [`TryInto`]. Only `POST`
    requests whose override parameter names one of these methods will be
    rewritten.

    default: `put`, `patch`, `delete`

    # Panics

    Panics if any item in `methods` cannot be converted into a [`Method`].
    The panic message includes the `Debug` form of the conversion error.

    ```
    # use trillium_method_override::MethodOverride;
    let handler = MethodOverride::new().with_allowed_methods(["put", "patch"]);
    ```
    */
    #[must_use]
    pub fn with_allowed_methods<M>(mut self, methods: impl IntoIterator<Item = M>) -> Self
    where
        M: TryInto<Method>,
        <M as TryInto<Method>>::Error: Debug,
    {
        self.allowed_methods = methods
            .into_iter()
            .map(|method| {
                // The allow list is supplied by the caller of this builder, so a
                // bad entry is a configuration error; say so instead of emitting
                // a bare `Result::unwrap` message.
                method.try_into().unwrap_or_else(|error| {
                    panic!(
                        "MethodOverride::with_allowed_methods: invalid http method ({:?})",
                        error
                    )
                })
            })
            .collect();
        self
    }

    /**
    Replaces the default querystring param name with the provided param name.

    default: `_method`

    ```
    # use trillium_method_override::MethodOverride;
    let handler = MethodOverride::new().with_param_name("_http_method");
    ```
    */
    #[must_use]
    pub fn with_param_name(mut self, param_name: &'static str) -> Self {
        self.param = param_name;
        self
    }
}

#[async_trait]
impl Handler for MethodOverride {
    async fn run(&self, mut conn: Conn) -> Conn {
        if conn.method() != Method::Post {
            return conn;
        }
        let qs = conn_unwrap!(QueryStrong::parse(conn.querystring()).ok(), conn);
        let method_str = conn_unwrap!(qs.get_str(self.param), conn);
        let method: Method = conn_unwrap!(method_str.try_into().ok(), conn);
        if self.allowed_methods.contains(&method) {
            conn.inner_mut().set_method(method);
        }
        conn
    }
}

/// Shorthand for [`MethodOverride::new`].
///
/// Returns a handler with the default param name (`_method`) and the default
/// allowed methods (`PUT`, `PATCH`, `DELETE`).
pub fn method_override() -> MethodOverride {
    MethodOverride::default()
}