Skip to main content

trillium_compression/
lib.rs

1//! Body compression for trillium.rs
2//!
3//! Currently, this crate only supports compressing outbound bodies with
4//! the zstd, brotli, and gzip algorithms (in order of preference),
5//! although more algorithms may be added in the future. The correct
6//! algorithm will be selected based on the Accept-Encoding header sent by
7//! the client, if one exists.
8//!
9//! Defaults are tuned for HTTP transport: brotli at quality 4 (matching
10//! nginx/caddy/Cloudflare). To opt into stronger or weaker compression,
11//! see [`Compression::with_brotli_level`], [`Compression::with_gzip_level`],
12//! and [`Compression::with_zstd_level`].
13//!
14//! Responses with `Content-Encoding` already set (e.g. precompressed
15//! sidecars) are passed through unchanged. Responses with already-
16//! compressed `Content-Type` (images, video, audio, fonts, archives) are
17//! skipped by default.
18#![forbid(unsafe_code)]
19#![deny(
20    missing_copy_implementations,
21    rustdoc::missing_crate_level_docs,
22    missing_debug_implementations,
23    nonstandard_style,
24    unused_qualifications
25)]
26#![warn(missing_docs)]
27
28#[cfg(test)]
29#[doc = include_str!("../README.md")]
30mod readme {}
31
32#[cfg(feature = "client")]
33pub mod client;
34
35pub use async_compression::Level;
36#[cfg(feature = "client")]
37use async_compression::futures::bufread::{BrotliDecoder, GzipDecoder, ZstdDecoder};
38use async_compression::futures::bufread::{BrotliEncoder, GzipEncoder, ZstdEncoder};
39use futures_lite::{
40    AsyncBufRead, AsyncReadExt,
41    io::{BufReader, Cursor},
42};
43use std::{
44    collections::BTreeSet,
45    fmt::{self, Display, Formatter},
46    str::FromStr,
47};
48use trillium::{
49    Body, Conn, Handler, HeaderValues,
50    KnownHeaderName::{AcceptEncoding, ContentEncoding, ContentType, Vary},
51    conn_unwrap,
52};
53
/// Content-codings this crate knows how to negotiate and apply.
///
/// Variant declaration order is significant: the derived `Ord` is used by
/// `parse_accept_encoding` to break ties between equally weighted codings, so
/// reordering variants changes which coding wins a tie.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompressionAlgorithm {
    /// Brotli, written on the wire as `br`
    Brotli,

    /// Gzip, written on the wire as `gzip` (`x-gzip` is also accepted when parsing)
    Gzip,

    /// Zstandard, written on the wire as `zstd`
    Zstd,

    /// The identity content-coding: no transformation. Set this on a client conn's state to opt
    /// a single request out of a configured default request encoding.
    Identity,
}
71
72impl CompressionAlgorithm {
73    fn as_str(&self) -> &'static str {
74        match self {
75            CompressionAlgorithm::Brotli => "br",
76            CompressionAlgorithm::Gzip => "gzip",
77            CompressionAlgorithm::Zstd => "zstd",
78            CompressionAlgorithm::Identity => "identity",
79        }
80    }
81
82    fn from_str_exact(s: &str) -> Option<Self> {
83        match s {
84            "br" => Some(CompressionAlgorithm::Brotli),
85            "gzip" => Some(CompressionAlgorithm::Gzip),
86            "x-gzip" => Some(CompressionAlgorithm::Gzip),
87            "zstd" => Some(CompressionAlgorithm::Zstd),
88            "identity" => Some(CompressionAlgorithm::Identity),
89            _ => None,
90        }
91    }
92}
93
94impl AsRef<str> for CompressionAlgorithm {
95    fn as_ref(&self) -> &str {
96        self.as_str()
97    }
98}
99
100impl Display for CompressionAlgorithm {
101    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
102        f.write_str(self.as_str())
103    }
104}
105
106impl FromStr for CompressionAlgorithm {
107    type Err = String;
108
109    fn from_str(s: &str) -> Result<Self, Self::Err> {
110        Self::from_str_exact(s)
111            .or_else(|| Self::from_str_exact(&s.to_ascii_lowercase()))
112            .ok_or_else(|| format!("unrecognized coding {s}"))
113    }
114}
115
116/// Trillium handler for compression
117#[derive(Clone, Debug)]
118pub struct Compression {
119    algorithms: BTreeSet<CompressionAlgorithm>,
120    brotli_level: Level,
121    gzip_level: Level,
122    zstd_level: Level,
123}
124
125impl Default for Compression {
126    fn default() -> Self {
127        use CompressionAlgorithm::*;
128        Self {
129            algorithms: [Zstd, Brotli, Gzip].into_iter().collect(),
130            // q11 (async-compression default) is ~10x slower than q4 with
131            // only a few percent better ratio — bad fit for the response
132            // hot path. Match nginx/caddy transport defaults.
133            brotli_level: Level::Precise(4),
134            gzip_level: Level::Default,
135            zstd_level: Level::Default,
136        }
137    }
138}
139
140impl Compression {
141    /// constructs a new compression handler
142    pub fn new() -> Self {
143        Self::default()
144    }
145
146    fn set_algorithms(&mut self, algos: &[CompressionAlgorithm]) {
147        self.algorithms = algos.iter().copied().collect();
148    }
149
150    /// sets the compression algorithms that this handler will
151    /// use. the default of Zstd, Brotli, Gzip is recommended. Note that the
152    /// order is ignored.
153    pub fn with_algorithms(mut self, algorithms: &[CompressionAlgorithm]) -> Self {
154        self.set_algorithms(algorithms);
155        self
156    }
157
158    /// sets the brotli compression level. The default is `Level::Precise(4)`,
159    /// matching common reverse proxy transport defaults. `Level::Default`
160    /// resolves to brotli quality 11, which is much slower for marginal
161    /// size gains.
162    pub fn with_brotli_level(mut self, level: Level) -> Self {
163        self.brotli_level = level;
164        self
165    }
166
167    /// sets the gzip compression level. The default is `Level::Default`,
168    /// which resolves to gzip level 6.
169    pub fn with_gzip_level(mut self, level: Level) -> Self {
170        self.gzip_level = level;
171        self
172    }
173
174    /// sets the zstd compression level. The default is `Level::Default`,
175    /// which resolves to zstd level 3.
176    pub fn with_zstd_level(mut self, level: Level) -> Self {
177        self.zstd_level = level;
178        self
179    }
180
181    fn levels(&self) -> Levels {
182        Levels {
183            brotli: self.brotli_level,
184            gzip: self.gzip_level,
185            zstd: self.zstd_level,
186        }
187    }
188
189    fn negotiate(&self, header: &str) -> Option<CompressionAlgorithm> {
190        parse_accept_encoding(header)
191            .into_iter()
192            .find_map(|(algo, _)| {
193                if self.algorithms.contains(&algo) {
194                    Some(algo)
195                } else {
196                    None
197                }
198            })
199    }
200}
201
202fn parse_accept_encoding(header: &str) -> Vec<(CompressionAlgorithm, u8)> {
203    let mut vec = header
204        .split(',')
205        .filter_map(|s| {
206            let mut iter = s.trim().split(';');
207            let (algo, q) = (iter.next()?, iter.next());
208            let algo = algo.trim().parse().ok()?;
209            let q = q
210                .and_then(|q| {
211                    q.trim()
212                        .strip_prefix("q=")
213                        .and_then(|q| q.parse::<f32>().map(|f| (f * 100.0) as u8).ok())
214                })
215                .unwrap_or(100u8);
216            Some((algo, q))
217        })
218        .collect::<Vec<(CompressionAlgorithm, u8)>>();
219
220    vec.sort_by(|(algo_a, a), (algo_b, b)| match b.cmp(a) {
221        std::cmp::Ordering::Equal => algo_a.cmp(algo_b),
222        other => other,
223    });
224
225    vec
226}
227
/// Returns true if the content-type identifies a payload that is already
/// compressed and should be passed through.
///
/// Only the media type is considered: anything from the first `;` onward
/// (e.g. `; charset=utf-8`) is ignored, as is surrounding whitespace. Media
/// type names are case-insensitive, so the comparison ignores ASCII case.
///
/// The skip list covers common binary image formats, web fonts, and archive
/// formats, plus every `video/*` and `audio/*` type. Compressible types such
/// as `image/svg+xml` and `application/wasm` are intentionally not skipped.
fn is_already_compressed(content_type: &str) -> bool {
    /// Exact media types to pass through. `video/*` and `audio/*` are handled
    /// by prefix below, so individual entries for them are unnecessary.
    const SKIPPED_TYPES: &[&str] = &[
        "image/png",
        "image/jpeg",
        "image/jpg",
        "image/gif",
        "image/webp",
        "image/avif",
        "image/heic",
        "image/heif",
        "image/apng",
        "image/x-icon",
        "font/woff",
        "font/woff2",
        "application/zip",
        "application/gzip",
        "application/x-gzip",
        "application/x-bzip2",
        "application/x-xz",
        "application/x-7z-compressed",
        "application/x-rar-compressed",
        "application/zstd",
    ];

    /// Top-level types whose every subtype is passed through.
    const SKIPPED_PREFIXES: &[&str] = &["video/", "audio/"];

    let media_type = content_type
        .split(';')
        .next()
        .unwrap_or(content_type)
        .trim();

    SKIPPED_TYPES
        .iter()
        .any(|skipped| media_type.eq_ignore_ascii_case(skipped))
        || SKIPPED_PREFIXES.iter().any(|prefix| {
            // `get` returns `None` when the value is shorter than the prefix or the cut would
            // split a multi-byte character, so this never panics.
            media_type
                .get(..prefix.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
        })
}
275
276/// Per-algorithm compression levels, threaded into the encode helpers.
277#[derive(Clone, Copy, Debug)]
278pub(crate) struct Levels {
279    brotli: Level,
280    gzip: Level,
281    zstd: Level,
282}
283
284impl Default for Levels {
285    fn default() -> Self {
286        Self {
287            brotli: Level::Precise(4),
288            gzip: Level::Default,
289            zstd: Level::Default,
290        }
291    }
292}
293
294impl CompressionAlgorithm {
295    /// Apply this content-coding to `body`, returning the (possibly unchanged) body and whether
296    /// encoding was actually applied. The identity coding, an empty body, and any static body that
297    /// fails to shrink are returned untouched with `false`.
298    pub(crate) async fn encode(self, body: Body, levels: Levels) -> (Body, bool) {
299        if self == Self::Identity {
300            return (body, false);
301        }
302
303        if body.is_static() {
304            let bytes = body.static_bytes().unwrap();
305            match self.encode_static(bytes, levels).await {
306                Some(data) if data.len() < bytes.len() => {
307                    log::trace!(
308                        "compressed {} body {} → {} bytes",
309                        self.as_str(),
310                        bytes.len(),
311                        data.len()
312                    );
313                    (Body::new_static(data), true)
314                }
315                _ => (body, false),
316            }
317        } else if body.is_streaming() {
318            (
319                self.encode_streaming(BufReader::new(body.into_reader()), levels),
320                true,
321            )
322        } else {
323            (body, false)
324        }
325    }
326
327    /// Compress in-memory `bytes`, or `None` for the identity coding or on encoder error.
328    async fn encode_static(self, bytes: &[u8], levels: Levels) -> Option<Vec<u8>> {
329        let mut data = vec![];
330        let result = match self {
331            Self::Identity => return None,
332            Self::Zstd => {
333                ZstdEncoder::with_quality(Cursor::new(bytes), levels.zstd)
334                    .read_to_end(&mut data)
335                    .await
336            }
337            Self::Brotli => {
338                BrotliEncoder::with_quality(Cursor::new(bytes), levels.brotli)
339                    .read_to_end(&mut data)
340                    .await
341            }
342            Self::Gzip => {
343                GzipEncoder::with_quality(Cursor::new(bytes), levels.gzip)
344                    .read_to_end(&mut data)
345                    .await
346            }
347        };
348        result.ok().map(|_| data)
349    }
350
351    /// Wrap `reader` in a streaming encoder for this content-coding; identity passes through.
352    fn encode_streaming(self, reader: impl AsyncBufRead + Send + 'static, levels: Levels) -> Body {
353        match self {
354            Self::Identity => Body::new_streaming(reader, None),
355            Self::Zstd => Body::new_streaming(ZstdEncoder::with_quality(reader, levels.zstd), None),
356            Self::Brotli => {
357                Body::new_streaming(BrotliEncoder::with_quality(reader, levels.brotli), None)
358            }
359            Self::Gzip => Body::new_streaming(GzipEncoder::with_quality(reader, levels.gzip), None),
360        }
361    }
362
363    /// Wrap `reader` in a streaming decoder for this content-coding; identity passes through.
364    #[cfg(feature = "client")]
365    pub(crate) fn decode_streaming(self, reader: impl AsyncBufRead + Send + 'static) -> Body {
366        match self {
367            Self::Identity => Body::new_streaming(reader, None),
368            Self::Zstd => Body::new_streaming(ZstdDecoder::new(reader), None),
369            Self::Brotli => Body::new_streaming(BrotliDecoder::new(reader), None),
370            Self::Gzip => Body::new_streaming(GzipDecoder::new(reader), None),
371        }
372    }
373}
374
375impl Handler for Compression {
376    async fn run(&self, mut conn: Conn) -> Conn {
377        if let Some(header) = conn
378            .request_headers()
379            .get_str(AcceptEncoding)
380            .and_then(|h| self.negotiate(h))
381        {
382            conn.insert_state(header);
383        }
384        conn
385    }
386
387    async fn before_send(&self, mut conn: Conn) -> Conn {
388        // Already encoded upstream (precompressed sidecar, or another
389        // middleware ahead of us) — leave it alone.
390        if conn.response_headers().get_str(ContentEncoding).is_some() {
391            return conn;
392        }
393
394        // Skip already-compressed payloads (images, fonts, archives, ...).
395        if conn
396            .response_headers()
397            .get_str(ContentType)
398            .is_some_and(is_already_compressed)
399        {
400            return conn;
401        }
402
403        let Some(algo) = conn.state::<CompressionAlgorithm>().copied() else {
404            return conn;
405        };
406
407        let body = conn_unwrap!(conn.take_response_body(), conn);
408        let (body, compression_used) = algo.encode(body, self.levels()).await;
409
410        if compression_used {
411            let vary = conn
412                .response_headers()
413                .get_str(Vary)
414                .map(|vary| HeaderValues::from(format!("{vary}, Accept-Encoding")))
415                .unwrap_or_else(|| HeaderValues::from("Accept-Encoding"));
416
417            conn.response_headers_mut().extend([
418                (ContentEncoding, HeaderValues::from(algo.as_str())),
419                (Vary, vary),
420            ]);
421        }
422
423        conn.with_body(body)
424    }
425}
426
427/// Alias for [`Compression::new`](crate::Compression::new)
428pub fn compression() -> Compression {
429    Compression::new()
430}