Skip to main content

trillium_caching_headers/
cache_control.rs

1use CacheControlDirective::*;
2use std::{
3    fmt::{Display, Write},
4    ops::{Deref, DerefMut},
5    time::Duration,
6};
7use trillium::{Conn, Handler, HeaderValues, KnownHeaderName};
/// An enum representation of the
/// [`Cache-Control`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control)
/// directives.
///
/// Directives carrying a delta-seconds argument store it as a [`Duration`];
/// only whole seconds are significant when the directive is written back out.
#[derive(Debug, Clone, Eq, PartialEq)]
#[non_exhaustive]
pub enum CacheControlDirective {
    /// [`immutable`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#revalidation_and_reloading)
    Immutable,

    /// [`max-age`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#expiration)
    MaxAge(Duration),

    /// [`min-fresh`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#min-fresh)
    MinFresh(Duration),

    /// [`max-stale`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#expiration)
    ///
    /// `None` represents a bare `max-stale` with no argument;
    /// `Some(duration)` represents `max-stale=<seconds>`.
    MaxStale(Option<Duration>),

    /// [`must-revalidate`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#revalidation_and_reloading)
    MustRevalidate,

    /// [`no-cache`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#cacheability)
    NoCache,

    /// [`no-store`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#cacheability)
    NoStore,

    /// [`no-transform`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#other)
    NoTransform,

    /// [`only-if-cached`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#other)
    OnlyIfCached,

    /// [`private`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#cacheability)
    Private,

    /// [`proxy-revalidate`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#revalidation_and_reloading)
    ProxyRevalidate,

    /// [`public`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#cacheability)
    Public,

    /// [`s-maxage`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#expiration)
    SMaxage(Duration),

    /// [`stale-if-error`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#expiration)
    StaleIfError(Duration),

    /// [`stale-while-revalidate`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control#expiration)
    StaleWhileRevalidate(Duration),

    /// Any directive that was not turned into one of the variants above.
    ///
    /// This covers both directives this crate does not recognize (e.g.
    /// `pre-check=0`) and recognized directive names whose argument is not
    /// valid delta-seconds (e.g. `max-age=soon`). When produced by
    /// [`CacheControlHeader::parse`], the string is the whole directive,
    /// including any `=argument`, trimmed and ASCII-lowercased. It is written
    /// back out verbatim when the header is displayed.
    UnknownDirective(String),
}
62
63impl Handler for CacheControlDirective {
64    async fn run(&self, conn: Conn) -> Conn {
65        conn.with_response_header(KnownHeaderName::CacheControl, self.clone())
66    }
67}
68
69impl Handler for CacheControlHeader {
70    async fn run(&self, conn: Conn) -> Conn {
71        conn.with_response_header(KnownHeaderName::CacheControl, self.clone())
72    }
73}
74
75/// A representation of the
76/// [`Cache-Control`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Cache-Control)
77/// header.
78#[derive(Debug, Clone, PartialEq, Eq)]
79pub struct CacheControlHeader(Vec<CacheControlDirective>);
80
81/// Construct a CacheControlHeader. Alias for [`CacheControlHeader::new`]
82pub fn cache_control(into: impl Into<CacheControlHeader>) -> CacheControlHeader {
83    into.into()
84}
85
86impl<T> From<T> for CacheControlHeader
87where
88    T: IntoIterator<Item = CacheControlDirective>,
89{
90    fn from(directives: T) -> Self {
91        directives.into_iter().collect()
92    }
93}
94
95impl From<CacheControlDirective> for CacheControlHeader {
96    fn from(directive: CacheControlDirective) -> Self {
97        Self(vec![directive])
98    }
99}
100
101impl FromIterator<CacheControlDirective> for CacheControlHeader {
102    fn from_iter<T: IntoIterator<Item = CacheControlDirective>>(iter: T) -> Self {
103        Self(iter.into_iter().collect())
104    }
105}
106
107impl From<CacheControlDirective> for HeaderValues {
108    fn from(ccd: CacheControlDirective) -> HeaderValues {
109        let header: CacheControlHeader = ccd.into();
110        header.into()
111    }
112}
113
114impl From<CacheControlHeader> for HeaderValues {
115    fn from(cch: CacheControlHeader) -> Self {
116        cch.to_string().into()
117    }
118}
119
120impl CacheControlHeader {
121    /// construct a new cache control header. alias for [`CacheControlHeader::from`]
122    pub fn new(into: impl Into<Self>) -> Self {
123        into.into()
124    }
125
126    /// returns true if one of the directives is `immutable`
127    pub fn is_immutable(&self) -> bool {
128        self.contains(&Immutable)
129    }
130
131    /// returns a duration if one of the directives is `max-age`
132    pub fn max_age(&self) -> Option<Duration> {
133        self.iter().find_map(|d| match d {
134            MaxAge(d) => Some(*d),
135            _ => None,
136        })
137    }
138
139    /// returns a duration if one of the directives is `min-fresh`
140    pub fn min_fresh(&self) -> Option<Duration> {
141        self.iter().find_map(|d| match d {
142            MinFresh(d) => Some(*d),
143            _ => None,
144        })
145    }
146
147    /// returns Some(None) if one of the directives is `max-stale` but
148    /// no value is provided. returns Some(Some(duration)) if one of
149    /// the directives is max-stale and includes a duration in
150    /// seconds, such as `max-stale=3600`. Returns None if there is no
151    /// `max-stale` directive
152    pub fn max_stale(&self) -> Option<Option<Duration>> {
153        self.iter().find_map(|d| match d {
154            MaxStale(d) => Some(*d),
155            _ => None,
156        })
157    }
158
159    /// returns true if this header contains a `must-revalidate` directive
160    pub fn must_revalidate(&self) -> bool {
161        self.contains(&MustRevalidate)
162    }
163
164    /// returns true if this header contains a `no-cache` directive
165    pub fn is_no_cache(&self) -> bool {
166        self.contains(&NoCache)
167    }
168
169    /// returns true if this header contains a `no-store` directive
170    pub fn is_no_store(&self) -> bool {
171        self.contains(&NoStore)
172    }
173
174    /// returns true if this header contains a `no-transform`
175    /// directive
176    pub fn is_no_transform(&self) -> bool {
177        self.contains(&NoTransform)
178    }
179
180    /// returns true if this header contains an `only-if-cached`
181    /// directive
182    pub fn is_only_if_cached(&self) -> bool {
183        self.contains(&OnlyIfCached)
184    }
185
186    /// returns true if this header contains a `private` directive
187    pub fn is_private(&self) -> bool {
188        self.contains(&Private)
189    }
190
191    /// returns true if this header contains a `proxy-revalidate`
192    /// directive
193    pub fn is_proxy_revalidate(&self) -> bool {
194        self.contains(&ProxyRevalidate)
195    }
196
197    /// returns true if this header contains a `proxy` directive
198    pub fn is_public(&self) -> bool {
199        self.contains(&Public)
200    }
201
202    /// returns a duration if this header contains an `s-maxage`
203    /// directive
204    pub fn s_maxage(&self) -> Option<Duration> {
205        self.iter().find_map(|h| match h {
206            SMaxage(d) => Some(*d),
207            _ => None,
208        })
209    }
210
211    /// returns a duration if this header contains a stale-if-error
212    /// directive
213    pub fn stale_if_error(&self) -> Option<Duration> {
214        self.iter().find_map(|h| match h {
215            StaleIfError(d) => Some(*d),
216            _ => None,
217        })
218    }
219
220    /// returns a duration if this header contains a
221    /// stale-while-revalidate directive
222    pub fn stale_while_revalidate(&self) -> Option<Duration> {
223        self.iter().find_map(|h| match h {
224            StaleWhileRevalidate(d) => Some(*d),
225            _ => None,
226        })
227    }
228
229    /// Parse a `Cache-Control` header value. Unrecognized directives are
230    /// preserved as [`CacheControlDirective::UnknownDirective`] per RFC 9111
231    /// §5.2; this parser is infallible.
232    pub fn parse(s: &str) -> Self {
233        Self(
234            s.to_ascii_lowercase()
235                .split(',')
236                .map(str::trim)
237                .filter(|directive| !directive.is_empty())
238                .map(|directive| match directive {
239                    "immutable" => Immutable,
240                    "must-revalidate" => MustRevalidate,
241                    "no-cache" => NoCache,
242                    "no-store" => NoStore,
243                    "no-transform" => NoTransform,
244                    "only-if-cached" => OnlyIfCached,
245                    "private" => Private,
246                    "proxy-revalidate" => ProxyRevalidate,
247                    "public" => Public,
248                    "max-stale" => MaxStale(None),
249                    other => match other.split_once('=') {
250                        Some((directive, value)) => {
251                            let seconds = value.parse().map(Duration::from_secs);
252                            match (directive, seconds) {
253                                ("max-age", Ok(d)) => MaxAge(d),
254                                ("min-fresh", Ok(d)) => MinFresh(d),
255                                ("max-stale", Ok(d)) => MaxStale(Some(d)),
256                                ("s-maxage", Ok(d)) => SMaxage(d),
257                                ("stale-if-error", Ok(d)) => StaleIfError(d),
258                                ("stale-while-revalidate", Ok(d)) => StaleWhileRevalidate(d),
259                                _ => UnknownDirective(String::from(other)),
260                            }
261                        }
262                        None => UnknownDirective(String::from(other)),
263                    },
264                })
265                .collect(),
266        )
267    }
268}
269
270impl Deref for CacheControlHeader {
271    type Target = [CacheControlDirective];
272
273    fn deref(&self) -> &Self::Target {
274        self.0.as_slice()
275    }
276}
277
278impl DerefMut for CacheControlHeader {
279    fn deref_mut(&mut self) -> &mut Self::Target {
280        self.0.as_mut_slice()
281    }
282}
283
284impl Display for CacheControlHeader {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        let mut first = true;
287        for directive in &self.0 {
288            if first {
289                first = false;
290            } else {
291                f.write_char(',')?;
292            }
293
294            match directive {
295                Immutable => write!(f, "immutable"),
296                MaxAge(d) => write!(f, "max-age={}", d.as_secs()),
297                MinFresh(d) => write!(f, "min-fresh={}", d.as_secs()),
298                MaxStale(Some(d)) => write!(f, "max-stale={}", d.as_secs()),
299                MaxStale(None) => write!(f, "max-stale"),
300                MustRevalidate => write!(f, "must-revalidate"),
301                NoCache => write!(f, "no-cache"),
302                NoStore => write!(f, "no-store"),
303                NoTransform => write!(f, "no-transform"),
304                OnlyIfCached => write!(f, "only-if-cached"),
305                Private => write!(f, "private"),
306                ProxyRevalidate => write!(f, "proxy-revalidate"),
307                Public => write!(f, "public"),
308                SMaxage(d) => write!(f, "s-maxage={}", d.as_secs()),
309                StaleIfError(d) => write!(f, "stale-if-error={}", d.as_secs()),
310                StaleWhileRevalidate(d) => write!(f, "stale-while-revalidate={}", d.as_secs()),
311                UnknownDirective(directive) => write!(f, "{directive}"),
312            }?;
313        }
314
315        Ok(())
316    }
317}
318
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse() {
        let single = CacheControlHeader::parse("no-store");
        assert_eq!(CacheControlHeader(vec![NoStore]), single);

        let source =
            "private,no-cache,no-store,max-age=0,must-revalidate,pre-check=0,post-check=0";
        let long = CacheControlHeader::parse(source);
        let expected = CacheControlHeader::from([
            Private,
            NoCache,
            NoStore,
            MaxAge(Duration::ZERO),
            MustRevalidate,
            UnknownDirective("pre-check=0".to_string()),
            UnknownDirective("post-check=0".to_string()),
        ]);
        assert_eq!(expected, long);

        // Serializing the parsed header reproduces the input text exactly.
        assert_eq!(long.to_string(), source);

        // Whitespace after the separating commas is tolerated.
        let spaced = CacheControlHeader::parse("public, max-age=604800, immutable");
        assert_eq!(
            CacheControlHeader::from([Public, MaxAge(Duration::from_secs(604800)), Immutable]),
            spaced
        );
    }

    #[test]
    fn min_fresh() {
        let header = CacheControlHeader::parse("min-fresh=300");
        assert_eq!(header.min_fresh(), Some(Duration::from_secs(300)));
        assert_eq!(header.to_string(), "min-fresh=300");
    }

    #[test]
    fn unknown_directive_with_value_does_not_fail_header() {
        // RFC 9111 §5.2: a cache must ignore directives it does not
        // recognize. An unknown directive with a non-numeric value must not
        // stop the remaining directives from being parsed; this once caused
        // the whole header to be rejected.
        let header = CacheControlHeader::parse("garbage=non-numeric, max-age=600");
        assert_eq!(header.max_age(), Some(Duration::from_secs(600)));

        let unknown = UnknownDirective("garbage=non-numeric".into());
        assert!(header.contains(&unknown));
    }
}
373}