Skip to main content

trillium_http/
body.rs

1use crate::{Headers, h3::H3Body};
2use BodyType::{Empty, Static, Streaming};
3use futures_lite::{AsyncRead, AsyncReadExt, io::Cursor, ready};
4use pin_project_lite::pin_project;
5use std::{
6    borrow::Cow,
7    fmt::{self, Debug, Formatter},
8    io::{Error, Result},
9    pin::Pin,
10    task::{Context, Poll},
11};
12use sync_wrapper::SyncWrapper;
13
14/// Trait for streaming body sources that can optionally produce trailers.
15///
16/// Implement this on types that compute trailer headers dynamically as the body
17/// is read — for example, a hashing wrapper that produces a `Digest` trailer
18/// after all bytes have been streamed.
19///
20/// For plain [`AsyncRead`] sources with no trailers, use [`Body::new_streaming`].
21/// `BodySource` is only needed when trailers must be produced.
22pub trait BodySource: AsyncRead + Send + 'static {
23    /// Returns the trailers for this body, called after the body has been fully read.
24    ///
25    /// Implementations may clear internal state on this call; the result is
26    /// only meaningful after [`AsyncRead::poll_read`] has returned `Ok(0)`.
27    fn trailers(self: Pin<&mut Self>) -> Option<Headers>;
28}
29
30pin_project! {
31    struct PlainBody<T> {
32        #[pin]
33        async_read: T,
34    }
35}
36
37impl<T: AsyncRead> AsyncRead for PlainBody<T> {
38    fn poll_read(
39        self: Pin<&mut Self>,
40        cx: &mut Context<'_>,
41        buf: &mut [u8],
42    ) -> Poll<Result<usize>> {
43        self.project().async_read.poll_read(cx, buf)
44    }
45}
46
47impl<T: AsyncRead + Send + 'static> BodySource for PlainBody<T> {
48    fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
49        None
50    }
51}
52
53/// The trillium representation of a http body. This can contain
54/// either `&'static [u8]` content, `Vec<u8>` content, or a boxed
55/// [`AsyncRead`]/[`BodySource`] type.
56#[derive(Debug, Default)]
57pub struct Body(pub(crate) BodyType);
58
59impl Body {
60    /// Construct a new body from a streaming [`AsyncRead`] source. If
61    /// you have the body content in memory already, prefer
62    /// [`Body::new_static`] or one of the From conversions.
63    pub fn new_streaming(async_read: impl AsyncRead + Send + 'static, len: Option<u64>) -> Self {
64        Self::new_with_trailers(PlainBody { async_read }, len)
65    }
66
67    /// Construct a new body from a [`BodySource`] that can produce trailers after
68    /// the body has been fully read.
69    ///
70    /// Use this when trailers must be computed dynamically from the body bytes,
71    /// for example to append a content hash.
72    pub fn new_with_trailers(body: impl BodySource, len: Option<u64>) -> Self {
73        Self(Streaming {
74            async_read: SyncWrapper::new(Box::pin(body)),
75            len,
76            done: false,
77            progress: 0,
78        })
79    }
80
81    /// Returns trailers from the body source, if any.
82    ///
83    /// Only meaningful after the body has been fully read (i.e., [`AsyncRead::poll_read`]
84    /// has returned `Ok(0)`). Returns `None` for bodies constructed with
85    /// [`Body::new_streaming`] or [`Body::new_static`].
86    #[doc(hidden)] // this isn't really a user-facing interface
87    pub fn trailers(&mut self) -> Option<Headers> {
88        match &mut self.0 {
89            Streaming {
90                async_read, done, ..
91            } if *done => async_read.get_mut().as_mut().trailers(),
92            _ => None,
93        }
94    }
95
96    /// Construct a fixed-length Body from a `Vec<u8>` or `&'static
97    /// [u8]`.
98    pub fn new_static(content: impl Into<Cow<'static, [u8]>>) -> Self {
99        Self(Static {
100            content: content.into(),
101            cursor: 0,
102        })
103    }
104
105    /// Retrieve a borrow of the static content in this body. If this
106    /// body is a streaming body or an empty body, this will return
107    /// None.
108    pub fn static_bytes(&self) -> Option<&[u8]> {
109        match &self.0 {
110            Static { content, .. } => Some(content.as_ref()),
111            _ => None,
112        }
113    }
114
115    /// Transform this Body into a dyn [`AsyncRead`]. This will wrap
116    /// static content in a [`Cursor`]. Note that this is different
117    /// from reading directly from the Body, which includes chunked
118    /// encoding.
119    pub fn into_reader(self) -> Pin<Box<dyn AsyncRead + Send + Sync + 'static>> {
120        match self.0 {
121            Streaming { async_read, .. } => Box::pin(SyncAsyncReader(async_read)),
122            Static { content, .. } => Box::pin(Cursor::new(content)),
123            Empty => Box::pin(Cursor::new("")),
124        }
125    }
126
127    /// Consume this body and return the full content. If the body was
128    /// constructed with [`Body::new_streaming`], this will read the
129    /// entire streaming body into memory, awaiting the streaming
130    /// source's completion. This function will return an error if a
131    /// streaming body has already been partially or fully read.
132    ///
133    /// # Errors
134    ///
135    /// This returns an error variant if either of the following conditions are met:
136    ///
137    /// there is an io error when reading from the underlying transport such as a disconnect
138    /// the body has already been read to completion
139    pub async fn into_bytes(self) -> Result<Cow<'static, [u8]>> {
140        match self.0 {
141            Static { content, .. } => Ok(content),
142
143            Streaming {
144                async_read,
145                len,
146                progress: 0,
147                done: false,
148            } => {
149                let mut async_read = async_read.into_inner();
150                let mut buf = len
151                    .and_then(|c| c.try_into().ok())
152                    .map(Vec::with_capacity)
153                    .unwrap_or_default();
154
155                async_read.read_to_end(&mut buf).await?;
156
157                Ok(Cow::Owned(buf))
158            }
159
160            Empty => Ok(Cow::Borrowed(b"")),
161
162            Streaming { .. } => Err(Error::other("body already read to completion")),
163        }
164    }
165
166    /// Retrieve the number of bytes that have been read from this
167    /// body
168    pub fn bytes_read(&self) -> u64 {
169        self.0.bytes_read()
170    }
171
172    /// returns the content length of this body, if known and
173    /// available.
174    pub fn len(&self) -> Option<u64> {
175        self.0.len()
176    }
177
178    /// determine if the this body represents no data
179    pub fn is_empty(&self) -> bool {
180        self.0.is_empty()
181    }
182
183    /// determine if the this body represents static content
184    pub fn is_static(&self) -> bool {
185        matches!(self.0, Static { .. })
186    }
187
188    /// determine if the this body represents streaming content
189    pub fn is_streaming(&self) -> bool {
190        matches!(self.0, Streaming { .. })
191    }
192
193    /// Convert this body into an `H3Body` for reading
194    #[cfg(feature = "unstable")]
195    pub fn into_h3(self) -> H3Body {
196        H3Body::new(self)
197    }
198
199    /// Convert this body into an `H3Body` for reading
200    #[cfg(not(feature = "unstable"))]
201    pub(crate) fn into_h3(self) -> H3Body {
202        H3Body::new(self)
203    }
204}
205
/// Returns the largest number of payload bytes that can be read into a
/// buffer of `buf_len` bytes while still leaving room for HTTP/1.1 chunk
/// framing: the hexadecimal chunk size, a CRLF after it, and a CRLF after
/// the payload.
///
/// The framing is reserved for the worst case, so near a hex-digit boundary
/// (F->10, FF->100, ...) one byte of the buffer may go unused.
///
/// # Panics
///
/// Panics if `buf_len < 6`, the smallest buffer that fits a one-byte chunk
/// (`"1\r\nX\r\n"`).
fn max_bytes_to_read(buf_len: usize) -> usize {
    /// Number of hexadecimal digits needed to write `n`.
    fn hex_digits(mut n: usize) -> usize {
        let mut digits = 1;
        while n >= 0x10 {
            n >>= 4;
            digits += 1;
        }
        digits
    }

    assert!(
        buf_len >= 6,
        "buffers of length {buf_len} are too small for this implementation.
            if this is a problem for you, please open an issue"
    );

    // Room left for the size digits plus payload once both CRLFs are reserved.
    // Always >= 2 thanks to the assertion above.
    let bytes_remaining_after_two_cr_lns = buf_len - 4;

    // Reserve ceil(log16(remaining)) hex digits: enough to write any chunk size
    // that can still fit. For an integer r >= 2 that equals the number of hex
    // digits of r - 1, so this is exact integer arithmetic with no floats.
    let max_bytes_of_hex_framing = hex_digits(bytes_remaining_after_two_cr_lns - 1);

    bytes_remaining_after_two_cr_lns - max_bytes_of_hex_framing
}
238
239impl AsyncRead for Body {
240    fn poll_read(
241        mut self: Pin<&mut Self>,
242        cx: &mut Context<'_>,
243        buf: &mut [u8],
244    ) -> Poll<Result<usize>> {
245        match &mut self.0 {
246            Empty => Poll::Ready(Ok(0)),
247            Static { content, cursor } => {
248                let length = content.len();
249                if length == *cursor {
250                    return Poll::Ready(Ok(0));
251                }
252                let bytes = (length - *cursor).min(buf.len());
253                buf[0..bytes].copy_from_slice(&content[*cursor..*cursor + bytes]);
254                *cursor += bytes;
255                Poll::Ready(Ok(bytes))
256            }
257
258            Streaming {
259                async_read,
260                len: Some(len),
261                done,
262                progress,
263            } => {
264                if *done {
265                    return Poll::Ready(Ok(0));
266                }
267
268                let max_bytes_to_read = (*len - *progress)
269                    .try_into()
270                    .unwrap_or(buf.len())
271                    .min(buf.len());
272
273                let bytes = ready!(
274                    async_read
275                        .get_mut()
276                        .as_mut()
277                        .poll_read(cx, &mut buf[..max_bytes_to_read])
278                )?;
279
280                if bytes == 0 {
281                    *done = true;
282                } else {
283                    *progress += bytes as u64;
284                }
285
286                Poll::Ready(Ok(bytes))
287            }
288
289            Streaming {
290                async_read,
291                len: None,
292                done,
293                progress,
294            } => {
295                if *done {
296                    return Poll::Ready(Ok(0));
297                }
298
299                let max_bytes_to_read = max_bytes_to_read(buf.len());
300
301                let bytes = ready!(
302                    async_read
303                        .get_mut()
304                        .as_mut()
305                        .poll_read(cx, &mut buf[..max_bytes_to_read])
306                )?;
307
308                if bytes == 0 {
309                    *done = true;
310                    // Write only the last-chunk marker; the caller must write the
311                    // trailer-section (possibly empty) followed by the terminating CRLF.
312                    buf[..3].copy_from_slice(b"0\r\n");
313                    return Poll::Ready(Ok(3));
314                }
315
316                *progress += bytes as u64;
317
318                let start = format!("{bytes:X}\r\n");
319                let start_length = start.len();
320                let total = bytes + start_length + 2;
321                buf.copy_within(..bytes, start_length);
322                buf[..start_length].copy_from_slice(start.as_bytes());
323                buf[total - 2..total].copy_from_slice(b"\r\n");
324                Poll::Ready(Ok(total))
325            }
326        }
327    }
328}
329
330struct SyncAsyncReader(SyncWrapper<Pin<Box<dyn BodySource>>>);
331impl Debug for SyncAsyncReader {
332    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
333        f.debug_struct("SyncAsyncReader").finish()
334    }
335}
336impl AsyncRead for SyncAsyncReader {
337    fn poll_read(
338        self: Pin<&mut Self>,
339        cx: &mut Context<'_>,
340        buf: &mut [u8],
341    ) -> Poll<Result<usize>> {
342        self.get_mut().0.get_mut().as_mut().poll_read(cx, buf)
343    }
344}
345
346#[derive(Default)]
347pub(crate) enum BodyType {
348    #[default]
349    Empty,
350
351    Static {
352        content: Cow<'static, [u8]>,
353        cursor: usize,
354    },
355
356    Streaming {
357        async_read: SyncWrapper<Pin<Box<dyn BodySource>>>,
358        progress: u64,
359        len: Option<u64>,
360        done: bool,
361    },
362}
363
364impl Debug for BodyType {
365    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
366        match self {
367            Empty => f.debug_tuple("BodyType::Empty").finish(),
368            Static { content, cursor } => f
369                .debug_struct("BodyType::Static")
370                .field("content", &String::from_utf8_lossy(content))
371                .field("cursor", cursor)
372                .finish(),
373            Streaming {
374                len,
375                done,
376                progress,
377                ..
378            } => f
379                .debug_struct("BodyType::Streaming")
380                .field("async_read", &format_args!(".."))
381                .field("len", &len)
382                .field("done", &done)
383                .field("progress", &progress)
384                .finish(),
385        }
386    }
387}
388
389impl BodyType {
390    fn is_empty(&self) -> bool {
391        match *self {
392            Empty => true,
393            Static { ref content, .. } => content.is_empty(),
394            Streaming { len, .. } => len == Some(0),
395        }
396    }
397
398    fn len(&self) -> Option<u64> {
399        match *self {
400            Empty => Some(0),
401            Static { ref content, .. } => Some(content.len() as u64),
402            Streaming { len, .. } => len,
403        }
404    }
405
406    fn bytes_read(&self) -> u64 {
407        match *self {
408            Empty => 0,
409            Static { cursor, .. } => cursor as u64,
410            Streaming { progress, .. } => progress,
411        }
412    }
413}
414
415impl From<String> for Body {
416    fn from(s: String) -> Self {
417        s.into_bytes().into()
418    }
419}
420
421impl From<&'static str> for Body {
422    fn from(s: &'static str) -> Self {
423        s.as_bytes().into()
424    }
425}
426
427impl From<&'static [u8]> for Body {
428    fn from(content: &'static [u8]) -> Self {
429        Self::new_static(content)
430    }
431}
432
433impl From<Vec<u8>> for Body {
434    fn from(content: Vec<u8>) -> Self {
435        Self::new_static(content)
436    }
437}
438
439impl From<Cow<'static, [u8]>> for Body {
440    fn from(value: Cow<'static, [u8]>) -> Self {
441        Self::new_static(value)
442    }
443}
444
445impl From<Cow<'static, str>> for Body {
446    fn from(value: Cow<'static, str>) -> Self {
447        match value {
448            Cow::Borrowed(b) => b.into(),
449            Cow::Owned(o) => o.into(),
450        }
451    }
452}
453
#[cfg(test)]
mod test_bytes_to_read {
    use super::max_bytes_to_read;

    #[test]
    fn simple_check_of_known_values() {
        // The rows marked `<-` matter most. They are a nonobvious but
        // intentional consequence of the implementation. To avoid
        // overflowing, one fewer than the available buffer bytes is used
        // there. Growing the read size by one at those points would grow
        // the framed total by two: one payload byte plus one extra hex
        // digit. This happens whenever the hex representation of the
        // content length is about to gain a digit (F->10, FF->100,
        // FFF->1000, etc).
        let cases: [(usize, usize); 15] = [
            (6, 1),       // 1
            (7, 2),       // 2
            (20, 15),     // F
            (21, 15),     // F <-
            (22, 16),     // 10
            (23, 17),     // 11
            (260, 254),   // FE
            (261, 254),   // FE <-
            (262, 255),   // FF <-
            (263, 256),   // 100
            (4100, 4093), // FFD
            (4101, 4093), // FFD <-
            (4102, 4094), // FFE <-
            (4103, 4095), // FFF <-
            (4104, 4096), // 1000
        ];

        for &(input, expected) in &cases {
            let actual = max_bytes_to_read(input);
            assert_eq!(
                actual, expected,
                "\n\nexpected max_bytes_to_read({input}) to be {expected}, but it was {actual}"
            );

            // Sanity-check the table itself: the framed chunk must fill the
            // buffer exactly or leave a single byte unused.
            let used_bytes = expected + 4 + format!("{expected:X}").len();
            assert!(
                used_bytes == input || used_bytes == input - 1,
                "\n\nfor an input of {}, expected used bytes to be {} or {}, but was {}",
                input,
                input,
                input - 1,
                used_bytes
            );
        }
    }
}