Skip to main content

trillium_redirect/
client.rs

1//! Client-side follow-redirects middleware for [`trillium-client`][trillium_client].
2//!
3//! This module is gated behind the `client` feature flag. It provides [`FollowRedirects`], a
4//! [`ClientHandler`] that automatically follows HTTP redirects (301, 302, 303, 307, 308) up to
5//! a configurable limit, with sensible defaults around security-sensitive cases.
6//!
7//! # Behavior
8//!
9//! On a redirect response, [`FollowRedirects`] resolves the `Location` header against the
10//! current request URL, applies the policy below, and re-issues the request through the same
11//! client, so the connector and connection pool are reused.
12//!
13//! ## Method handling
14//!
15//! The redirect status determines whether the method changes and whether the request body is
16//! replayed:
17//!
18//! | Status | Method change | Body |
19//! |--------|---------------|------|
20//! | 301 Moved Permanently | POST → GET, otherwise unchanged | dropped if method changed |
21//! | 302 Found | POST → GET, otherwise unchanged | dropped if method changed |
22//! | 303 See Other | always GET | always dropped |
23//! | 307 Temporary Redirect | unchanged | replayed if static, dropped if streaming |
24//! | 308 Permanent Redirect | unchanged | replayed if static, dropped if streaming |
25//!
26//! ## Body replay
27//!
28//! Static bodies (constructed via [`Body::new_static`] or any of the `From` conversions for
29//! `Vec<u8>`, `&'static [u8]`, `String`, `&'static str`, etc.) are cloned and replayed on
30//! redirect.
31//!
32//! Streaming bodies (constructed via [`Body::new_streaming`]) are one-shot. Once consumed by
33//! the original request they cannot be replayed, and the redirected request is sent without
34//! a body.
35//!
36//! [`Body::new_static`]: trillium_client::Body::new_static
37//! [`Body::new_streaming`]: trillium_client::Body::new_streaming
38//!
39//! ## Cross-origin header filtering
40//!
41//! When the redirect target's origin (scheme + host + port) differs from the original, the
42//! following headers are dropped from the redirected request to avoid credential leakage:
43//!
44//! - `Authorization`
45//! - `Cookie`
46//! - `Proxy-Authorization`
47//!
48//! ## Defaults
49//!
50//! - **Max redirects**: 10. Override with [`FollowRedirects::with_max_redirects`].
51//! - **HTTPS → HTTP downgrade**: blocked. Allow with [`FollowRedirects::with_allow_downgrade`].
52//! - **Cross-origin redirects**: allowed. Restrict with [`FollowRedirects::with_allowed_origins`].
53//!
54//! # Example
55//!
56//! ```no_run
57//! use trillium_client::Client;
58//! use trillium_redirect::client::FollowRedirects;
59//! use trillium_testing::client_config;
60//!
61//! let client =
62//!     Client::new(client_config()).with_handler(FollowRedirects::new().with_max_redirects(5));
63//! ```
64
65use std::{collections::HashSet, sync::Arc};
66use trillium_client::{
67    Body, ClientHandler, Conn,
68    KnownHeaderName::{
69        Authorization, Connection, ContentEncoding, ContentLength, ContentType, Cookie, Expect,
70        Host, Location, ProxyAuthorization, TransferEncoding,
71    },
72    Method, Result, Status, Url,
73    url::Origin,
74};
75
76/// A [`ClientHandler`] that automatically follows HTTP redirects.
77///
78/// See the [module-level documentation][self] for behavior and configuration.
79#[derive(Debug, Clone)]
80pub struct FollowRedirects {
81    max_redirects: u32,
82    allow_downgrade: bool,
83    allowed_origins: Option<Arc<HashSet<Origin>>>,
84}
85
86impl Default for FollowRedirects {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl FollowRedirects {
93    /// Construct a new [`FollowRedirects`] with default settings: max 10 redirects,
94    /// HTTPS-to-HTTP downgrade blocked, all origins allowed.
95    pub fn new() -> Self {
96        Self {
97            max_redirects: 10,
98            allow_downgrade: false,
99            allowed_origins: None,
100        }
101    }
102
103    /// Set the maximum number of redirects to follow before erroring with
104    /// [`RedirectError::TooMany`].
105    #[must_use]
106    pub fn with_max_redirects(mut self, max: u32) -> Self {
107        self.max_redirects = max;
108        self
109    }
110
111    /// Allow or block redirects from `https://` to `http://`. Default is blocked.
112    #[must_use]
113    pub fn with_allow_downgrade(mut self, allow: bool) -> Self {
114        self.allow_downgrade = allow;
115        self
116    }
117
118    /// Restrict redirects to the given allowlist of origins. Each `Url`'s
119    /// [origin][trillium_client::url::Url::origin] (scheme + host + port) is added to the
120    /// allowlist; the path/query/fragment of the input URLs is ignored.
121    ///
122    /// When set, redirects to any other origin error with [`RedirectError::OriginNotAllowed`].
123    /// When unset (the default), all origins are permitted.
124    #[must_use]
125    pub fn with_allowed_origins<I: IntoIterator<Item = Url>>(mut self, origins: I) -> Self {
126        let set: HashSet<Origin> = origins.into_iter().map(|u| u.origin()).collect();
127        self.allowed_origins = Some(Arc::new(set));
128        self
129    }
130
131    fn is_origin_allowed(&self, url: &Url) -> bool {
132        match &self.allowed_origins {
133            Some(allowed) => allowed.contains(&url.origin()),
134            None => true,
135        }
136    }
137}
138
/// Number of redirects followed so far for a request chain. [`FollowRedirects`] keeps this
/// in conn state and carries it forward onto each redirected conn.
#[derive(Debug, Clone, Copy)]
struct RedirectCount(u32);
142
143/// Snapshot of a replayable request body, stashed in conn state by [`FollowRedirects::run`]
144/// so that [`FollowRedirects::after_response`] can rebuild the body for the redirected
145/// request after the original was consumed by the network call.
146#[derive(Debug)]
147struct SavedBody(Body);
148
149/// Errors produced by [`FollowRedirects`] when a redirect cannot be followed.
150#[derive(thiserror::Error, Debug)]
151pub enum RedirectError {
152    /// The redirect chain exceeded the configured maximum.
153    #[error("redirect chain exceeded {0} redirects")]
154    TooMany(u32),
155
156    /// The redirect target's origin is not in the configured allowlist.
157    #[error("redirect to {0} not in allowed-origins list")]
158    OriginNotAllowed(String),
159
160    /// The original request was HTTPS and the redirect target is HTTP, but downgrade is not
161    /// allowed.
162    #[error("redirect from https to http blocked (call with_allow_downgrade(true) to permit)")]
163    DowngradeBlocked,
164
165    /// The redirect response had no `Location` header.
166    #[error("3xx redirect response had no Location header")]
167    MissingLocation,
168
169    /// The `Location` header could not be parsed as a valid URL relative to the request URL.
170    #[error("invalid Location header {value:?}: {error}")]
171    InvalidLocation {
172        /// The raw `Location` header value.
173        value: String,
174        /// The underlying URL parse error message.
175        error: String,
176    },
177}
178
179impl From<RedirectError> for trillium_client::Error {
180    fn from(err: RedirectError) -> Self {
181        trillium_client::Error::other(err)
182    }
183}
184
185impl ClientHandler for FollowRedirects {
186    async fn run(&self, conn: &mut Conn) -> Result<()> {
187        // Snapshot replayable bodies into conn state so we can replay across a redirect.
188        // Streaming bodies return None and are left alone — they're one-shot.
189        let snapshot = conn.request_body().and_then(Body::try_clone);
190        if let Some(snapshot) = snapshot {
191            conn.insert_state(SavedBody(snapshot));
192        }
193        Ok(())
194    }
195
196    async fn after_response(&self, conn: &mut Conn) -> Result<()> {
197        let Some(status) = conn.status() else {
198            return Ok(());
199        };
200        let Some(redirect_kind) = classify_redirect(status) else {
201            return Ok(());
202        };
203
204        // Resolve Location relative to the current request URL.
205        let location = conn
206            .response_headers()
207            .get_str(Location)
208            .ok_or(RedirectError::MissingLocation)?
209            .to_string();
210        let new_url = conn
211            .url()
212            .join(&location)
213            .map_err(|e| RedirectError::InvalidLocation {
214                value: location.clone(),
215                error: e.to_string(),
216            })?;
217
218        // Apply policy.
219        if !self.allow_downgrade && conn.url().scheme() == "https" && new_url.scheme() == "http" {
220            return Err(RedirectError::DowngradeBlocked.into());
221        }
222        if !self.is_origin_allowed(&new_url) {
223            return Err(RedirectError::OriginNotAllowed(new_url.to_string()).into());
224        }
225
226        // Bump count, error if over limit.
227        let count = conn.state::<RedirectCount>().map_or(0, |c| c.0);
228        if count >= self.max_redirects {
229            return Err(RedirectError::TooMany(self.max_redirects).into());
230        }
231
232        // Decide method + whether to keep body.
233        let original_method = conn.method();
234        let (new_method, keep_body) = match redirect_kind {
235            RedirectKind::SeeOther => (Method::Get, false),
236            RedirectKind::PreserveMethod => (original_method, true),
237            RedirectKind::PostToGet => {
238                if original_method == Method::Post {
239                    (Method::Get, false)
240                } else {
241                    (original_method, true)
242                }
243            }
244        };
245
246        if let Some(body) = conn.take_response_body() {
247            body.recycle().await;
248        }
249
250        // Build a fresh sibling conn from the same client.
251        let mut new_conn = conn.client().build_conn(new_method, new_url.clone());
252
253        std::mem::swap(new_conn.as_mut(), conn.as_mut());
254
255        // Copy request headers with a few categories stripped:
256        // - protocol/transport-managed headers — `finalize_headers` will re-derive them for the new
257        //   conn's body and url
258        // - body-description headers — only meaningful when a body is being sent
259        // - cross-origin credential headers — must not leak across origin boundaries
260        let same_origin = conn.url().origin() == new_url.origin();
261        let mut new_headers = conn.request_headers().clone();
262        new_headers.remove_all([Host, ContentLength, TransferEncoding, Expect, Connection]);
263        if !keep_body {
264            new_headers.remove_all([ContentType, ContentEncoding]);
265        }
266        if !same_origin {
267            new_headers.remove_all([Authorization, Cookie, ProxyAuthorization]);
268        }
269        *new_conn.request_headers_mut() = new_headers;
270
271        // Replay the body if the redirect kind preserves it. Static bodies were snapshotted
272        // in `run`; streaming bodies were one-shot and aren't replayable.
273        if keep_body
274            && let Some(saved) = new_conn.state::<SavedBody>()
275            && let Some(replayed) = saved.0.try_clone()
276        {
277            new_conn.set_request_body(replayed);
278        }
279
280        // Stash count + execute the redirected request.
281        new_conn.insert_state(RedirectCount(count + 1));
282        (&mut new_conn).await?;
283
284        // Swap so the user sees the final response on their original conn handle.
285        std::mem::swap(conn, &mut new_conn);
286        Ok(())
287    }
288
289    fn name(&self) -> std::borrow::Cow<'static, str> {
290        "FollowRedirects".into()
291    }
292}
293
/// How a redirect status affects the method and body of the follow-up request.
#[derive(Debug, Clone, Copy)]
enum RedirectKind {
    /// 303: the follow-up is always a GET and the body is dropped.
    SeeOther,
    /// 307/308: the follow-up keeps both the method and the body.
    PreserveMethod,
    /// 301/302: the follow-up keeps the method, except that POST becomes GET.
    PostToGet,
}
303
304fn classify_redirect(status: Status) -> Option<RedirectKind> {
305    match status {
306        Status::MovedPermanently | Status::Found => Some(RedirectKind::PostToGet),
307        Status::SeeOther => Some(RedirectKind::SeeOther),
308        Status::TemporaryRedirect | Status::PermanentRedirect => Some(RedirectKind::PreserveMethod),
309        _ => None,
310    }
311}