trillium_forwarding/
lib.rs1#![forbid(unsafe_code)]
19#![deny(
20 missing_copy_implementations,
21 rustdoc::missing_crate_level_docs,
22 missing_debug_implementations,
23 missing_docs,
24 nonstandard_style,
25 unused_qualifications
26)]
27
// Attaches README.md as the doc comment of an empty module that is only compiled under
// `cfg(test)`, presumably so the README's code examples are exercised by the test suite.
// NOTE(review): rustdoc collects doctests with `cfg(doctest)` set; confirm that `cfg(test)` is
// also active at that point, otherwise the README examples are silently never run. The rustdoc
// book's documented pattern for this is `#[cfg(doctest)]`.
#[cfg(test)]
#[doc = include_str!("../README.md")]
mod readme {}
31
32mod forwarded;
33pub use forwarded::Forwarded;
34
35mod parse_utils;
36
37use std::{fmt::Debug, net::IpAddr, ops::Deref};
38use trillium::{Conn, Handler, Status, Transport};
39
40#[derive(Debug, Default)]
41#[non_exhaustive]
42enum TrustProxy {
43 Always,
44
45 #[default]
46 Never,
47
48 Cidr(Vec<cidr::AnyIpCidr>),
49
50 Function(TrustFn),
51}
52
/// Boxed, `Send + Sync` predicate that decides whether a peer IP belongs to a trusted proxy.
///
/// Closures do not implement `Debug`, so this newtype supplies a placeholder `Debug` impl,
/// letting `TrustProxy` derive `Debug`. It derefs to the underlying `dyn Fn` so it can be
/// called directly.
struct TrustFn(Box<dyn Fn(&IpAddr) -> bool + Send + Sync + 'static>);

impl<F: Fn(&IpAddr) -> bool + Send + Sync + 'static> From<F> for TrustFn {
    fn from(predicate: F) -> Self {
        // Unsize the concrete closure into the trait object stored by the newtype.
        let boxed: Box<dyn Fn(&IpAddr) -> bool + Send + Sync + 'static> = Box::new(predicate);
        TrustFn(boxed)
    }
}

impl Debug for TrustFn {
    /// Renders as `TrustPredicate(..)`, because the wrapped closure has no printable form.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter
            .debug_tuple("TrustPredicate")
            .field(&format_args!(".."))
            .finish()
    }
}

impl Deref for TrustFn {
    type Target = dyn Fn(&IpAddr) -> bool + Send + Sync + 'static;

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
77
78impl TrustProxy {
79 fn is_trusted(&self, ip: Option<IpAddr>) -> bool {
80 match (self, ip) {
81 (TrustProxy::Always, _) => true,
82 (TrustProxy::Cidr(cidrs), Some(ip)) => cidrs.iter().any(|c| c.contains(&ip)),
83 (TrustProxy::Function(trust_predicate), Some(ip)) => trust_predicate(&ip),
84 _ => false,
85 }
86 }
87}
88
89#[derive(Default, Debug)]
93pub struct Forwarding(TrustProxy);
94
95impl From<TrustProxy> for Forwarding {
96 fn from(tp: TrustProxy) -> Self {
97 Self(tp)
98 }
99}
100
101impl Forwarding {
102 pub fn trust_ips<'a>(ips: impl IntoIterator<Item = &'a str>) -> Self {
111 Self(TrustProxy::Cidr(
112 ips.into_iter().map(|ip| ip.parse().unwrap()).collect(),
113 ))
114 }
115
116 pub fn trust_fn<F>(trust_predicate: F) -> Self
128 where
129 F: Fn(&IpAddr) -> bool + Send + Sync + 'static,
130 {
131 Self(TrustProxy::Function(TrustFn::from(trust_predicate)))
132 }
133
134 pub fn trust_always() -> Self {
142 Self(TrustProxy::Always)
143 }
144}
145
146impl Handler for Forwarding {
147 async fn run(&self, mut conn: Conn) -> Conn {
148 if !self.0.is_trusted(conn.peer_ip()) {
149 return conn;
150 }
151
152 let forwarded = match Forwarded::from_headers(conn.request_headers()) {
153 Ok(Some(forwarded)) => forwarded.into_owned(),
154 Err(error) => {
155 log::error!("{error}");
156 return conn
157 .halt()
158 .with_state(error)
159 .with_status(Status::BadRequest);
160 }
161 Ok(None) => return conn,
162 };
163
164 log::debug!("received trusted forwarded {:?}", &forwarded);
165
166 let inner_mut: &mut trillium_http::Conn<Box<dyn Transport>> = conn.as_mut();
167
168 if let Some(host) = forwarded.host() {
169 inner_mut.set_host(String::from(host));
170 }
171
172 if let Some(proto) = forwarded.proto() {
173 inner_mut.set_secure(proto == "https");
174 }
175
176 if let Some(ip) = forwarded.forwarded_for().first()
177 && let Ok(ip_addr) = ip
178 .trim_start_matches('[')
179 .trim_end_matches(']')
180 .parse::<IpAddr>()
181 {
182 inner_mut.set_peer_ip(Some(ip_addr));
183 }
184
185 conn.with_state(forwarded)
186 }
187}