1#![forbid(unsafe_code)]
19#![deny(
20 missing_copy_implementations,
21 rustdoc::missing_crate_level_docs,
22 missing_debug_implementations,
23 nonstandard_style,
24 unused_qualifications
25)]
26#![warn(missing_docs)]
27
// Attaches README.md as the documentation of an empty module so that rustdoc collects the
// README's code samples as doctests. rustdoc sets `cfg(doctest)` (not `cfg(test)`) while
// collecting doctests, so a `cfg(test)` gate keeps this module out of that pass and the
// README examples would never be compiled or run.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
31
32#[cfg(feature = "client")]
33pub mod client;
34
35pub use async_compression::Level;
36#[cfg(feature = "client")]
37use async_compression::futures::bufread::{BrotliDecoder, GzipDecoder, ZstdDecoder};
38use async_compression::futures::bufread::{BrotliEncoder, GzipEncoder, ZstdEncoder};
39use futures_lite::{
40 AsyncBufRead, AsyncReadExt,
41 io::{BufReader, Cursor},
42};
43use std::{
44 collections::BTreeSet,
45 fmt::{self, Display, Formatter},
46 str::FromStr,
47};
48use trillium::{
49 Body, Conn, Handler, HeaderValues,
50 KnownHeaderName::{AcceptEncoding, ContentEncoding, ContentType, Vary},
51 conn_unwrap,
52};
53
/// An HTTP content coding that this crate can apply to a response body
/// (and, with the `client` feature, remove from one).
///
/// Variant declaration order is significant: the derived [`Ord`] is used to break ties
/// between equally weighted `Accept-Encoding` entries, so earlier variants are preferred.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum CompressionAlgorithm {
    /// Brotli, content-coding token `br`.
    Brotli,

    /// Gzip, content-coding token `gzip` (`x-gzip` is also accepted when parsing).
    Gzip,

    /// Zstandard, content-coding token `zstd`.
    Zstd,

    /// The `identity` coding: bodies are passed through without any transformation.
    Identity,
}
71
72impl CompressionAlgorithm {
73 fn as_str(&self) -> &'static str {
74 match self {
75 CompressionAlgorithm::Brotli => "br",
76 CompressionAlgorithm::Gzip => "gzip",
77 CompressionAlgorithm::Zstd => "zstd",
78 CompressionAlgorithm::Identity => "identity",
79 }
80 }
81
82 fn from_str_exact(s: &str) -> Option<Self> {
83 match s {
84 "br" => Some(CompressionAlgorithm::Brotli),
85 "gzip" => Some(CompressionAlgorithm::Gzip),
86 "x-gzip" => Some(CompressionAlgorithm::Gzip),
87 "zstd" => Some(CompressionAlgorithm::Zstd),
88 "identity" => Some(CompressionAlgorithm::Identity),
89 _ => None,
90 }
91 }
92}
93
94impl AsRef<str> for CompressionAlgorithm {
95 fn as_ref(&self) -> &str {
96 self.as_str()
97 }
98}
99
100impl Display for CompressionAlgorithm {
101 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
102 f.write_str(self.as_str())
103 }
104}
105
106impl FromStr for CompressionAlgorithm {
107 type Err = String;
108
109 fn from_str(s: &str) -> Result<Self, Self::Err> {
110 Self::from_str_exact(s)
111 .or_else(|| Self::from_str_exact(&s.to_ascii_lowercase()))
112 .ok_or_else(|| format!("unrecognized coding {s}"))
113 }
114}
115
116#[derive(Clone, Debug)]
118pub struct Compression {
119 algorithms: BTreeSet<CompressionAlgorithm>,
120 brotli_level: Level,
121 gzip_level: Level,
122 zstd_level: Level,
123}
124
125impl Default for Compression {
126 fn default() -> Self {
127 use CompressionAlgorithm::*;
128 Self {
129 algorithms: [Zstd, Brotli, Gzip].into_iter().collect(),
130 brotli_level: Level::Precise(4),
134 gzip_level: Level::Default,
135 zstd_level: Level::Default,
136 }
137 }
138}
139
140impl Compression {
141 pub fn new() -> Self {
143 Self::default()
144 }
145
146 fn set_algorithms(&mut self, algos: &[CompressionAlgorithm]) {
147 self.algorithms = algos.iter().copied().collect();
148 }
149
150 pub fn with_algorithms(mut self, algorithms: &[CompressionAlgorithm]) -> Self {
154 self.set_algorithms(algorithms);
155 self
156 }
157
158 pub fn with_brotli_level(mut self, level: Level) -> Self {
163 self.brotli_level = level;
164 self
165 }
166
167 pub fn with_gzip_level(mut self, level: Level) -> Self {
170 self.gzip_level = level;
171 self
172 }
173
174 pub fn with_zstd_level(mut self, level: Level) -> Self {
177 self.zstd_level = level;
178 self
179 }
180
181 fn levels(&self) -> Levels {
182 Levels {
183 brotli: self.brotli_level,
184 gzip: self.gzip_level,
185 zstd: self.zstd_level,
186 }
187 }
188
189 fn negotiate(&self, header: &str) -> Option<CompressionAlgorithm> {
190 parse_accept_encoding(header)
191 .into_iter()
192 .find_map(|(algo, _)| {
193 if self.algorithms.contains(&algo) {
194 Some(algo)
195 } else {
196 None
197 }
198 })
199 }
200}
201
202fn parse_accept_encoding(header: &str) -> Vec<(CompressionAlgorithm, u8)> {
203 let mut vec = header
204 .split(',')
205 .filter_map(|s| {
206 let mut iter = s.trim().split(';');
207 let (algo, q) = (iter.next()?, iter.next());
208 let algo = algo.trim().parse().ok()?;
209 let q = q
210 .and_then(|q| {
211 q.trim()
212 .strip_prefix("q=")
213 .and_then(|q| q.parse::<f32>().map(|f| (f * 100.0) as u8).ok())
214 })
215 .unwrap_or(100u8);
216 Some((algo, q))
217 })
218 .collect::<Vec<(CompressionAlgorithm, u8)>>();
219
220 vec.sort_by(|(algo_a, a), (algo_b, b)| match b.cmp(a) {
221 std::cmp::Ordering::Equal => algo_a.cmp(algo_b),
222 other => other,
223 });
224
225 vec
226}
227
/// Returns `true` when a `Content-Type` value names a format whose payload is already
/// compressed. Running such a body through a general-purpose compressor would cost CPU for
/// little or no size reduction.
///
/// Only the media type before any `;` parameters is considered. The comparison ignores ASCII
/// case, because media type and subtype names are case-insensitive (RFC 9110 §8.3.1). Every
/// `video/*` and `audio/*` type is treated as already compressed.
fn is_already_compressed(content_type: &str) -> bool {
    /// Individual media types with built-in compression (images, fonts, archives).
    const COMPRESSED_TYPES: &[&str] = &[
        "image/png",
        "image/jpeg",
        "image/jpg",
        "image/gif",
        "image/webp",
        "image/avif",
        "image/heic",
        "image/heif",
        "image/apng",
        "image/x-icon",
        "font/woff",
        "font/woff2",
        "application/zip",
        "application/gzip",
        "application/x-gzip",
        "application/x-bzip2",
        "application/x-xz",
        "application/x-7z-compressed",
        "application/x-rar-compressed",
        "application/zstd",
    ];
    /// Top-level types that are treated as compressed wholesale.
    const COMPRESSED_PREFIXES: &[&str] = &["video/", "audio/"];

    let media_type = content_type
        .split_once(';')
        .map_or(content_type, |(media_type, _params)| media_type)
        .trim();

    COMPRESSED_TYPES
        .iter()
        .any(|known| media_type.eq_ignore_ascii_case(known))
        || COMPRESSED_PREFIXES.iter().any(|prefix| {
            // `get` returns None (instead of panicking) if the input is shorter than the
            // prefix or the cut falls inside a multi-byte character.
            media_type
                .get(..prefix.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
        })
}
275
276#[derive(Clone, Copy, Debug)]
278pub(crate) struct Levels {
279 brotli: Level,
280 gzip: Level,
281 zstd: Level,
282}
283
284impl Default for Levels {
285 fn default() -> Self {
286 Self {
287 brotli: Level::Precise(4),
288 gzip: Level::Default,
289 zstd: Level::Default,
290 }
291 }
292}
293
294impl CompressionAlgorithm {
295 pub(crate) async fn encode(self, body: Body, levels: Levels) -> (Body, bool) {
299 if self == Self::Identity {
300 return (body, false);
301 }
302
303 if body.is_static() {
304 let bytes = body.static_bytes().unwrap();
305 match self.encode_static(bytes, levels).await {
306 Some(data) if data.len() < bytes.len() => {
307 log::trace!(
308 "compressed {} body {} → {} bytes",
309 self.as_str(),
310 bytes.len(),
311 data.len()
312 );
313 (Body::new_static(data), true)
314 }
315 _ => (body, false),
316 }
317 } else if body.is_streaming() {
318 (
319 self.encode_streaming(BufReader::new(body.into_reader()), levels),
320 true,
321 )
322 } else {
323 (body, false)
324 }
325 }
326
327 async fn encode_static(self, bytes: &[u8], levels: Levels) -> Option<Vec<u8>> {
329 let mut data = vec![];
330 let result = match self {
331 Self::Identity => return None,
332 Self::Zstd => {
333 ZstdEncoder::with_quality(Cursor::new(bytes), levels.zstd)
334 .read_to_end(&mut data)
335 .await
336 }
337 Self::Brotli => {
338 BrotliEncoder::with_quality(Cursor::new(bytes), levels.brotli)
339 .read_to_end(&mut data)
340 .await
341 }
342 Self::Gzip => {
343 GzipEncoder::with_quality(Cursor::new(bytes), levels.gzip)
344 .read_to_end(&mut data)
345 .await
346 }
347 };
348 result.ok().map(|_| data)
349 }
350
351 fn encode_streaming(self, reader: impl AsyncBufRead + Send + 'static, levels: Levels) -> Body {
353 match self {
354 Self::Identity => Body::new_streaming(reader, None),
355 Self::Zstd => Body::new_streaming(ZstdEncoder::with_quality(reader, levels.zstd), None),
356 Self::Brotli => {
357 Body::new_streaming(BrotliEncoder::with_quality(reader, levels.brotli), None)
358 }
359 Self::Gzip => Body::new_streaming(GzipEncoder::with_quality(reader, levels.gzip), None),
360 }
361 }
362
363 #[cfg(feature = "client")]
365 pub(crate) fn decode_streaming(self, reader: impl AsyncBufRead + Send + 'static) -> Body {
366 match self {
367 Self::Identity => Body::new_streaming(reader, None),
368 Self::Zstd => Body::new_streaming(ZstdDecoder::new(reader), None),
369 Self::Brotli => Body::new_streaming(BrotliDecoder::new(reader), None),
370 Self::Gzip => Body::new_streaming(GzipDecoder::new(reader), None),
371 }
372 }
373}
374
375impl Handler for Compression {
376 async fn run(&self, mut conn: Conn) -> Conn {
377 if let Some(header) = conn
378 .request_headers()
379 .get_str(AcceptEncoding)
380 .and_then(|h| self.negotiate(h))
381 {
382 conn.insert_state(header);
383 }
384 conn
385 }
386
387 async fn before_send(&self, mut conn: Conn) -> Conn {
388 if conn.response_headers().get_str(ContentEncoding).is_some() {
391 return conn;
392 }
393
394 if conn
396 .response_headers()
397 .get_str(ContentType)
398 .is_some_and(is_already_compressed)
399 {
400 return conn;
401 }
402
403 let Some(algo) = conn.state::<CompressionAlgorithm>().copied() else {
404 return conn;
405 };
406
407 let body = conn_unwrap!(conn.take_response_body(), conn);
408 let (body, compression_used) = algo.encode(body, self.levels()).await;
409
410 if compression_used {
411 let vary = conn
412 .response_headers()
413 .get_str(Vary)
414 .map(|vary| HeaderValues::from(format!("{vary}, Accept-Encoding")))
415 .unwrap_or_else(|| HeaderValues::from("Accept-Encoding"));
416
417 conn.response_headers_mut().extend([
418 (ContentEncoding, HeaderValues::from(algo.as_str())),
419 (Vary, vary),
420 ]);
421 }
422
423 conn.with_body(body)
424 }
425}
426
427pub fn compression() -> Compression {
429 Compression::new()
430}