trillium_redirect/client.rs
1//! Client-side follow-redirects middleware for [`trillium-client`][trillium_client].
2//!
3//! This module is gated behind the `client` feature flag. It provides [`FollowRedirects`], a
4//! [`ClientHandler`] that automatically follows HTTP redirects (301, 302, 303, 307, 308) up to
5//! a configurable limit, with sensible defaults around security-sensitive cases.
6//!
7//! # Behavior
8//!
9//! On a redirect response, [`FollowRedirects`] resolves the `Location` header against the
10//! current request URL, applies the policy below, and re-issues the request through the same
11//! client, so the connector and connection pool are reused.
12//!
13//! ## Method handling
14//!
15//! The redirect status determines whether the method changes and whether the request body is
16//! replayed:
17//!
18//! | Status | Method change | Body |
19//! |--------|---------------|------|
20//! | 301 Moved Permanently | POST → GET, otherwise unchanged | dropped if method changed |
21//! | 302 Found | POST → GET, otherwise unchanged | dropped if method changed |
22//! | 303 See Other | always GET | always dropped |
23//! | 307 Temporary Redirect | unchanged | replayed if static, dropped if streaming |
24//! | 308 Permanent Redirect | unchanged | replayed if static, dropped if streaming |
25//!
26//! ## Body replay
27//!
28//! Static bodies (constructed via [`Body::new_static`] or any of the `From` conversions for
29//! `Vec<u8>`, `&'static [u8]`, `String`, `&'static str`, etc.) are cloned and replayed on
30//! redirect.
31//!
32//! Streaming bodies (constructed via [`Body::new_streaming`]) are one-shot. Once consumed by
33//! the original request they cannot be replayed, and the redirected request is sent without
34//! a body.
35//!
36//! [`Body::new_static`]: trillium_client::Body::new_static
37//! [`Body::new_streaming`]: trillium_client::Body::new_streaming
38//!
39//! ## Cross-origin header filtering
40//!
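//! When the redirect target's origin (scheme + host + port) differs from the origin of the
//! request being redirected, the following headers are dropped from the redirected request to
//! avoid credential leakage. Once dropped, they are not restored on later hops, even if a
//! subsequent redirect leads back to the first request's origin: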
43//!
44//! - `Authorization`
45//! - `Cookie`
46//! - `Proxy-Authorization`
47//!
48//! ## Defaults
49//!
50//! - **Max redirects**: 10. Override with [`FollowRedirects::with_max_redirects`].
51//! - **HTTPS → HTTP downgrade**: blocked. Allow with [`FollowRedirects::with_allow_downgrade`].
52//! - **Cross-origin redirects**: allowed. Restrict with [`FollowRedirects::with_allowed_origins`].
53//!
54//! # Example
55//!
56//! ```no_run
57//! use trillium_client::Client;
58//! use trillium_redirect::client::FollowRedirects;
59//! use trillium_testing::client_config;
60//!
61//! let client =
62//! Client::new(client_config()).with_handler(FollowRedirects::new().with_max_redirects(5));
63//! ```
64
65use std::{collections::HashSet, sync::Arc};
66use trillium_client::{
67 Body, ClientHandler, Conn,
68 KnownHeaderName::{
69 Authorization, Connection, ContentEncoding, ContentLength, ContentType, Cookie, Expect,
70 Host, Location, ProxyAuthorization, TransferEncoding,
71 },
72 Method, Result, Status, Url,
73 url::Origin,
74};
75
76/// A [`ClientHandler`] that automatically follows HTTP redirects.
77///
78/// See the [module-level documentation][self] for behavior and configuration.
79#[derive(Debug, Clone)]
80pub struct FollowRedirects {
81 max_redirects: u32,
82 allow_downgrade: bool,
83 allowed_origins: Option<Arc<HashSet<Origin>>>,
84}
85
86impl Default for FollowRedirects {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl FollowRedirects {
93 /// Construct a new [`FollowRedirects`] with default settings: max 10 redirects,
94 /// HTTPS-to-HTTP downgrade blocked, all origins allowed.
95 pub fn new() -> Self {
96 Self {
97 max_redirects: 10,
98 allow_downgrade: false,
99 allowed_origins: None,
100 }
101 }
102
103 /// Set the maximum number of redirects to follow before erroring with
104 /// [`RedirectError::TooMany`].
105 #[must_use]
106 pub fn with_max_redirects(mut self, max: u32) -> Self {
107 self.max_redirects = max;
108 self
109 }
110
111 /// Allow or block redirects from `https://` to `http://`. Default is blocked.
112 #[must_use]
113 pub fn with_allow_downgrade(mut self, allow: bool) -> Self {
114 self.allow_downgrade = allow;
115 self
116 }
117
118 /// Restrict redirects to the given allowlist of origins. Each `Url`'s
119 /// [origin][trillium_client::url::Url::origin] (scheme + host + port) is added to the
120 /// allowlist; the path/query/fragment of the input URLs is ignored.
121 ///
122 /// When set, redirects to any other origin error with [`RedirectError::OriginNotAllowed`].
123 /// When unset (the default), all origins are permitted.
124 #[must_use]
125 pub fn with_allowed_origins<I: IntoIterator<Item = Url>>(mut self, origins: I) -> Self {
126 let set: HashSet<Origin> = origins.into_iter().map(|u| u.origin()).collect();
127 self.allowed_origins = Some(Arc::new(set));
128 self
129 }
130
131 fn is_origin_allowed(&self, url: &Url) -> bool {
132 match &self.allowed_origins {
133 Some(allowed) => allowed.contains(&url.origin()),
134 None => true,
135 }
136 }
137}
138
/// How many redirects the current request chain has already followed; kept in conn state by
/// [`FollowRedirects`].
#[derive(Debug, Clone, Copy)]
struct RedirectCount(u32);
142
143/// Snapshot of a replayable request body, stashed in conn state by [`FollowRedirects::run`]
144/// so that [`FollowRedirects::after_response`] can rebuild the body for the redirected
145/// request after the original was consumed by the network call.
146#[derive(Debug)]
147struct SavedBody(Body);
148
149/// Errors produced by [`FollowRedirects`] when a redirect cannot be followed.
150#[derive(thiserror::Error, Debug)]
151pub enum RedirectError {
152 /// The redirect chain exceeded the configured maximum.
153 #[error("redirect chain exceeded {0} redirects")]
154 TooMany(u32),
155
156 /// The redirect target's origin is not in the configured allowlist.
157 #[error("redirect to {0} not in allowed-origins list")]
158 OriginNotAllowed(String),
159
160 /// The original request was HTTPS and the redirect target is HTTP, but downgrade is not
161 /// allowed.
162 #[error("redirect from https to http blocked (call with_allow_downgrade(true) to permit)")]
163 DowngradeBlocked,
164
165 /// The redirect response had no `Location` header.
166 #[error("3xx redirect response had no Location header")]
167 MissingLocation,
168
169 /// The `Location` header could not be parsed as a valid URL relative to the request URL.
170 #[error("invalid Location header {value:?}: {error}")]
171 InvalidLocation {
172 /// The raw `Location` header value.
173 value: String,
174 /// The underlying URL parse error message.
175 error: String,
176 },
177}
178
179impl From<RedirectError> for trillium_client::Error {
180 fn from(err: RedirectError) -> Self {
181 trillium_client::Error::other(err)
182 }
183}
184
185impl ClientHandler for FollowRedirects {
186 async fn run(&self, conn: &mut Conn) -> Result<()> {
187 // Snapshot replayable bodies into conn state so we can replay across a redirect.
188 // Streaming bodies return None and are left alone — they're one-shot.
189 let snapshot = conn.request_body().and_then(Body::try_clone);
190 if let Some(snapshot) = snapshot {
191 conn.insert_state(SavedBody(snapshot));
192 }
193 Ok(())
194 }
195
196 async fn after_response(&self, conn: &mut Conn) -> Result<()> {
197 let Some(status) = conn.status() else {
198 return Ok(());
199 };
200 let Some(redirect_kind) = classify_redirect(status) else {
201 return Ok(());
202 };
203
204 // Resolve Location relative to the current request URL.
205 let location = conn
206 .response_headers()
207 .get_str(Location)
208 .ok_or(RedirectError::MissingLocation)?
209 .to_string();
210 let new_url = conn
211 .url()
212 .join(&location)
213 .map_err(|e| RedirectError::InvalidLocation {
214 value: location.clone(),
215 error: e.to_string(),
216 })?;
217
218 // Apply policy.
219 if !self.allow_downgrade && conn.url().scheme() == "https" && new_url.scheme() == "http" {
220 return Err(RedirectError::DowngradeBlocked.into());
221 }
222 if !self.is_origin_allowed(&new_url) {
223 return Err(RedirectError::OriginNotAllowed(new_url.to_string()).into());
224 }
225
226 // Bump count, error if over limit.
227 let count = conn.state::<RedirectCount>().map_or(0, |c| c.0);
228 if count >= self.max_redirects {
229 return Err(RedirectError::TooMany(self.max_redirects).into());
230 }
231
232 // Decide method + whether to keep body.
233 let original_method = conn.method();
234 let (new_method, keep_body) = match redirect_kind {
235 RedirectKind::SeeOther => (Method::Get, false),
236 RedirectKind::PreserveMethod => (original_method, true),
237 RedirectKind::PostToGet => {
238 if original_method == Method::Post {
239 (Method::Get, false)
240 } else {
241 (original_method, true)
242 }
243 }
244 };
245
246 if let Some(body) = conn.take_response_body() {
247 body.recycle().await;
248 }
249
250 // Build a fresh sibling conn from the same client.
251 let mut new_conn = conn.client().build_conn(new_method, new_url.clone());
252
253 std::mem::swap(new_conn.as_mut(), conn.as_mut());
254
255 // Copy request headers with a few categories stripped:
256 // - protocol/transport-managed headers — `finalize_headers` will re-derive them for the new
257 // conn's body and url
258 // - body-description headers — only meaningful when a body is being sent
259 // - cross-origin credential headers — must not leak across origin boundaries
260 let same_origin = conn.url().origin() == new_url.origin();
261 let mut new_headers = conn.request_headers().clone();
262 new_headers.remove_all([Host, ContentLength, TransferEncoding, Expect, Connection]);
263 if !keep_body {
264 new_headers.remove_all([ContentType, ContentEncoding]);
265 }
266 if !same_origin {
267 new_headers.remove_all([Authorization, Cookie, ProxyAuthorization]);
268 }
269 *new_conn.request_headers_mut() = new_headers;
270
271 // Replay the body if the redirect kind preserves it. Static bodies were snapshotted
272 // in `run`; streaming bodies were one-shot and aren't replayable.
273 if keep_body
274 && let Some(saved) = new_conn.state::<SavedBody>()
275 && let Some(replayed) = saved.0.try_clone()
276 {
277 new_conn.set_request_body(replayed);
278 }
279
280 // Stash count + execute the redirected request.
281 new_conn.insert_state(RedirectCount(count + 1));
282 (&mut new_conn).await?;
283
284 // Swap so the user sees the final response on their original conn handle.
285 std::mem::swap(conn, &mut new_conn);
286 Ok(())
287 }
288
289 fn name(&self) -> std::borrow::Cow<'static, str> {
290 "FollowRedirects".into()
291 }
292}
293
/// How a given redirect status rewrites the method and body of the follow-up request.
#[derive(Debug, Clone, Copy)]
enum RedirectKind {
    /// 303 See Other: switch to GET and drop the body.
    SeeOther,
    /// 307 Temporary Redirect / 308 Permanent Redirect: keep both the method and the body.
    PreserveMethod,
    /// 301 Moved Permanently / 302 Found: keep the method, except that POST becomes GET.
    PostToGet,
}
303
304fn classify_redirect(status: Status) -> Option<RedirectKind> {
305 match status {
306 Status::MovedPermanently | Status::Found => Some(RedirectKind::PostToGet),
307 Status::SeeOther => Some(RedirectKind::SeeOther),
308 Status::TemporaryRedirect | Status::PermanentRedirect => Some(RedirectKind::PreserveMethod),
309 _ => None,
310 }
311}