1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#![forbid(unsafe_code)]
#![deny(
    clippy::dbg_macro,
    missing_copy_implementations,
    rustdoc::missing_crate_level_docs,
    missing_debug_implementations,
    missing_docs,
    nonstandard_style,
    unused_qualifications
)]
/*!
Basic authentication for trillium.rs

```rust,no_run
use trillium_basic_auth::BasicAuth;
trillium_smol::run((
    BasicAuth::new("trillium", "7r1ll1um").with_realm("rust"),
    |conn: trillium::Conn| async move { conn.ok("authenticated") },
));
```
*/
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use trillium::{
    async_trait, Conn, Handler,
    KnownHeaderName::{Authorization, WwwAuthenticate},
    Status,
};

/// basic auth handler
///
/// Holds a single username/password pair together with two strings derived
/// from the configuration, so they do not have to be rebuilt per request: the
/// `Authorization` header value built from the credentials and the
/// `WWW-Authenticate` challenge sent when a request is denied.
#[derive(Clone, Debug)]
pub struct BasicAuth {
    /// the one username/password pair this handler accepts
    credentials: Credentials,
    /// the realm given to `with_realm`, unescaped; `None` until it is called
    realm: Option<String>,

    // precomputed/derived data fields:
    /// `Authorization` header value generated from `credentials`
    expected_header: String,
    /// `WWW-Authenticate` header value attached to denied responses
    www_authenticate: String,
}

/// basic auth username-password credentials
///
/// [`BasicAuth`] stores a clone of its configured credentials in the conn's
/// state for every request it lets through.
// NOTE(review): the derived `Debug` output includes the password in clear
// text; be careful when logging values of this type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Credentials {
    /// the user-id half of the pair
    username: String,
    /// the password half of the pair
    password: String,
}

impl Credentials {
    /// Builds credentials by copying the provided username and password.
    fn new(username: &str, password: &str) -> Self {
        Self {
            username: String::from(username),
            password: String::from(password),
        }
    }

    /// the username half of these credentials
    pub fn username(&self) -> &str {
        &self.username
    }

    /// the password half of these credentials
    pub fn password(&self) -> &str {
        &self.password
    }

    /// Renders the `Authorization` header value for these credentials:
    /// the literal `Basic ` followed by the standard-alphabet base64 encoding
    /// of `username:password`.
    fn expected_header(&self) -> String {
        format!(
            "Basic {}",
            BASE64.encode(format!("{}:{}", self.username, self.password))
        )
    }
}

impl BasicAuth {
    /// build a new basic auth handler with the provided username and password
    ///
    /// No realm is configured by this constructor; until
    /// [`BasicAuth::with_realm`] is called, the challenge sent on denial is a
    /// bare `Basic`.
    pub fn new(username: &str, password: &str) -> Self {
        let credentials = Credentials::new(username, password);
        // Encode once up front so the per-request check never repeats the
        // base64 work.
        let expected_header = credentials.expected_header();
        Self {
            expected_header,
            credentials,
            realm: None,
            www_authenticate: String::from("Basic"),
        }
    }

    /// provide a realm for the www-authenticate response sent by this handler
    ///
    /// The realm is emitted as a quoted-string, so any backslash or double
    /// quote it contains is backslash-escaped in the header value. The
    /// unescaped realm is stored as provided.
    pub fn with_realm(mut self, realm: &str) -> Self {
        // Backslashes must be escaped before quotes: doing it afterwards would
        // double the backslashes that the quote escaping just introduced.
        let escaped = realm.replace('\\', "\\\\").replace('"', "\\\"");
        self.www_authenticate = format!("Basic realm=\"{}\"", escaped);
        self.realm = Some(String::from(realm));
        self
    }

    /// Compares two byte strings by visiting every byte pair and accumulating
    /// the differences, instead of returning at the first mismatch.
    ///
    /// This is meant to keep the comparison time from revealing how long a
    /// matching prefix an attacker has guessed. A length mismatch still
    /// returns early, so only the length of the expected value is observable.
    fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
        if left.len() != right.len() {
            return false;
        }

        left.iter()
            .zip(right)
            .fold(0u8, |difference, (l, r)| difference | (l ^ r))
            == 0
    }

    /// Returns true when the conn carries an `Authorization` header whose
    /// scheme is `Basic` (matched case-insensitively) and whose credentials
    /// token equals the one precomputed for this handler.
    ///
    /// A missing header, a header that is not readable as a string, or a
    /// header without a space between scheme and token is rejected.
    fn is_allowed(&self, conn: &Conn) -> bool {
        // `expected_header` is stored as `Basic <token>`; only the token part
        // is compared, because the scheme is checked separately below.
        let expected_token = self
            .expected_header
            .strip_prefix("Basic ")
            .unwrap_or(&self.expected_header);

        let parts = conn
            .headers()
            .get_str(Authorization)
            .and_then(|header| header.split_once(' '));

        match parts {
            Some((scheme, token)) => {
                // The scheme is not secret, so short-circuiting on it is fine.
                // More than one space may separate the scheme from the token.
                scheme.eq_ignore_ascii_case("Basic")
                    && Self::constant_time_eq(
                        token.trim_start_matches(' ').as_bytes(),
                        expected_token.as_bytes(),
                    )
            }
            None => false,
        }
    }

    /// Turns the conn into a halted `401 Unauthorized` response carrying this
    /// handler's `WWW-Authenticate` challenge.
    fn deny(&self, conn: Conn) -> Conn {
        conn.with_status(Status::Unauthorized)
            .with_header(WwwAuthenticate, self.www_authenticate.clone())
            .halt()
    }
}

#[async_trait]
impl Handler for BasicAuth {
    /// Gates the conn on its `Authorization` header.
    ///
    /// A request that fails the check gets the denial response produced by
    /// `deny`. A request that passes is returned with a clone of the
    /// configured [`Credentials`] added to its state.
    async fn run(&self, conn: Conn) -> Conn {
        // Guard clause: reject first, so the success path reads straight through.
        if !self.is_allowed(&conn) {
            return self.deny(conn);
        }

        let credentials = self.credentials.clone();
        conn.with_state(credentials)
    }
}