Skip to main content

trillium_http/
received_body.rs

1use crate::{Body, Buffer, Error, Headers, HttpConfig, MutCow, copy, h3::H3Connection};
2use Poll::{Pending, Ready};
3use ReceivedBodyState::{Chunked, End, FixedLength, PartialChunkSize, Start};
4use encoding_rs::Encoding;
5use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, ready};
6use std::{
7    fmt::{self, Debug, Formatter},
8    io::{self, ErrorKind},
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13
14mod chunked;
15mod fixed_length;
16mod h3_data;
17
/// A received http body
///
/// This type represents a body that will be read from the underlying transport, which it may either
/// borrow from a [`Conn`](crate::Conn) or own.
///
/// ```rust
/// # use trillium_testing::HttpTest;
/// let app = HttpTest::new(|mut conn| async move {
///     let body = conn.request_body();
///     let body_string = body.read_string().await.unwrap();
///     conn.with_response_body(format!("received: {body_string}"))
/// });
///
/// app.get("/").block().assert_body("received: ");
/// app.post("/")
///     .with_body("hello")
///     .block()
///     .assert_body("received: hello");
/// ```
///
/// ## Bounds checking
///
/// Every `ReceivedBody` has a maximum length beyond which it will return an error, expressed as a
/// u64. To override this on the specific `ReceivedBody`, use [`ReceivedBody::with_max_len`] or
/// [`ReceivedBody::set_max_len`].
///
/// The default maximum length is 10mb; see [`HttpConfig::received_body_max_len`] to configure
/// this server-wide.
///
/// ## Large chunks, small read buffers
///
/// Attempting to read a chunked body with a buffer that is shorter than the chunk size in hex will
/// result in an error.
#[derive(fieldwork::Fieldwork)]
pub struct ReceivedBody<'conn, Transport> {
    /// The content-length of this body, if available. This
    /// usually is derived from the content-length header. If the http
    /// request or response that this body is attached to uses
    /// transfer-encoding chunked, this will be None.
    ///
    /// ```rust
    /// # use trillium_testing::HttpTest;
    /// HttpTest::new(|mut conn| async move {
    ///     let body = conn.request_body();
    ///     assert_eq!(body.content_length(), Some(5));
    ///     let body_string = body.read_string().await.unwrap();
    ///     conn.with_status(200)
    ///         .with_response_body(format!("received: {body_string}"))
    /// })
    /// .post("/")
    /// .with_body("hello")
    /// .block()
    /// .assert_ok()
    /// .assert_body("received: hello");
    /// ```
    #[field(get)]
    content_length: Option<u64>,

    // Bytes already pulled off the transport but not yet handed to the reader.
    // `read_buffered` serves these before polling the transport again.
    buffer: MutCow<'conn, Buffer>,

    // The transport the body is read from. Becomes `None` once taken: either handed to
    // `on_completion` when the body reaches `End`, or removed via `take_transport`.
    transport: Option<MutCow<'conn, Transport>>,

    // Current position in the read state machine; overwritten on every loop iteration of
    // `poll_read`. Held in a `MutCow`, so it may be borrowed rather than owned.
    state: MutCow<'conn, ReceivedBodyState>,

    // Invoked with the owned transport once the body reaches `End`. It is only called when the
    // transport is owned (not borrowed), and at most once since it is `take`n.
    on_completion: Option<Box<dyn FnOnce(Transport) + Send + Sync + 'static>>,

    /// the character encoding of this body, usually determined from the content type
    /// (mime-type) of the associated Conn.
    #[field(get)]
    encoding: &'static Encoding,

    /// The maximum length that can be read from this body before error
    ///
    /// See also [`HttpConfig::received_body_max_len`]
    #[field(with, get, set)]
    max_len: u64,

    /// The initial buffer capacity allocated when reading the body to bytes or a string
    ///
    /// See [`HttpConfig::received_body_initial_len`]
    #[field(with, get, set)]
    initial_len: usize,

    /// The maximum number of read loops that reading this received body will perform before
    /// yielding back to the runtime
    ///
    /// See [`HttpConfig::copy_loops_per_yield`]
    #[field(with, get, set)]
    copy_loops_per_yield: usize,

    /// Maximum size to pre-allocate based on content-length for buffering this received body
    ///
    /// See [`HttpConfig::received_body_max_preallocate`]
    #[field(with, get, set)]
    max_preallocate: usize,

    // Copied from `HttpConfig::h3_max_field_section_size` at construction. Not read in this
    // file; NOTE(review): presumably bounds the size of decoded HTTP/3 trailers — confirm in
    // h3_data.rs.
    h3_max_field_section_size: u64,

    // Destination for trailers decoded from the body (see `with_trailers`). `poll_read` stores
    // the result of `h3_trailer_future` here once the body reaches `End`.
    trailers: MutCow<'conn, Option<Headers>>,

    /// Byte offset into `b"HTTP/1.1 100 Continue\r\n\r\n"` that remains to be written before the
    /// first read. `None` means no pending write.
    send_100_continue_offset: Option<usize>,

    /// holds connection and stream id
    // `None` for bodies that are not read over an HTTP/3 connection.
    h3_connection: Option<(Arc<H3Connection>, u64)>,

    /// a boxed future that handles decoding trailers
    // Polled by `poll_read` after the body reaches `End`; cleared once it resolves.
    h3_trailer_future:
        Option<Pin<Box<dyn Future<Output = io::Result<Headers>> + Send + Sync + 'static>>>,
}
129
/// Returns the non-empty tail of `buf` beginning at byte offset `min`.
///
/// Yields `None` when `min` is at or beyond the end of `buf`, including offsets too large to be
/// represented as a `usize`.
fn slice_from(min: u64, buf: &[u8]) -> Option<&[u8]> {
    let start = usize::try_from(min).ok()?;
    match buf.get(start..) {
        Some(tail) if !tail.is_empty() => Some(tail),
        _ => None,
    }
}
134
135impl<'conn, Transport> ReceivedBody<'conn, Transport>
136where
137    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
138{
139    #[allow(missing_docs)]
140    #[doc(hidden)]
141    pub fn new(
142        content_length: Option<u64>,
143        buffer: impl Into<MutCow<'conn, Buffer>>,
144        transport: impl Into<MutCow<'conn, Transport>>,
145        state: impl Into<MutCow<'conn, ReceivedBodyState>>,
146        on_completion: Option<Box<dyn FnOnce(Transport) + Send + Sync + 'static>>,
147        encoding: &'static Encoding,
148    ) -> Self {
149        Self::new_with_config(
150            content_length,
151            buffer,
152            transport,
153            state,
154            on_completion,
155            encoding,
156            &HttpConfig::DEFAULT,
157        )
158    }
159
160    #[allow(missing_docs)]
161    #[doc(hidden)]
162    pub(crate) fn new_with_config(
163        content_length: Option<u64>,
164        buffer: impl Into<MutCow<'conn, Buffer>>,
165        transport: impl Into<MutCow<'conn, Transport>>,
166        state: impl Into<MutCow<'conn, ReceivedBodyState>>,
167        on_completion: Option<Box<dyn FnOnce(Transport) + Send + Sync + 'static>>,
168        encoding: &'static Encoding,
169        config: &HttpConfig,
170    ) -> Self {
171        Self {
172            content_length,
173            buffer: buffer.into(),
174            transport: Some(transport.into()),
175            state: state.into(),
176            on_completion,
177            encoding,
178            max_len: config.received_body_max_len,
179            initial_len: config.received_body_initial_len,
180            copy_loops_per_yield: config.copy_loops_per_yield,
181            max_preallocate: config.received_body_max_preallocate,
182            h3_max_field_section_size: config.h3_max_field_section_size,
183            trailers: None.into(),
184            send_100_continue_offset: None,
185            h3_connection: None,
186            h3_trailer_future: None,
187        }
188    }
189
190    /// Sets the destination for trailers decoded from the request body.
191    ///
192    /// When the body is fully read, any trailers will be written to the provided storage.
193    #[doc(hidden)]
194    #[must_use]
195    pub fn with_trailers(mut self, trailers: impl Into<MutCow<'conn, Option<Headers>>>) -> Self {
196        self.trailers = trailers.into();
197        self
198    }
199
200    #[doc(hidden)]
201    #[must_use]
202    #[cfg(feature = "unstable")]
203    pub fn with_h3_connection(mut self, h3_connection: Option<(Arc<H3Connection>, u64)>) -> Self {
204        self.h3_connection = h3_connection;
205        self
206    }
207
208    #[doc(hidden)]
209    #[must_use]
210    #[cfg(not(feature = "unstable"))]
211    pub(crate) fn with_h3_connection(
212        mut self,
213        h3_connection: Option<(Arc<H3Connection>, u64)>,
214    ) -> Self {
215        self.h3_connection = h3_connection;
216        self
217    }
218
219    /// Arranges for `HTTP/1.1 100 Continue\r\n\r\n` to be written to the transport before the
220    /// first body read. Used to implement lazy 100-continue for HTTP/1.1 request bodies.
221    #[must_use]
222    pub(crate) fn with_send_100_continue(mut self) -> Self {
223        self.send_100_continue_offset = Some(0);
224        self
225    }
226
227    // pub fn content_length(&self) -> Option<u64> {
228    //     self.content_length
229    // }
230
231    /// # Reads entire body to String.
232    ///
233    /// This uses the encoding determined by the content-type (mime)
234    /// charset. If an encoding problem is encountered, the String
235    /// returned by [`ReceivedBody::read_string`] will contain utf8
236    /// replacement characters.
237    ///
238    /// Note that this can only be performed once per Conn, as the
239    /// underlying data is not cached anywhere. This is the only copy of
240    /// the body contents.
241    ///
242    /// # Errors
243    ///
244    /// This will return an error if there is an IO error on the
245    /// underlying transport such as a disconnect
246    ///
247    /// This will also return an error if the length exceeds the maximum length. To override this
248    /// value on this specific body, use [`ReceivedBody::with_max_len`] or
249    /// [`ReceivedBody::set_max_len`]
250    pub async fn read_string(self) -> crate::Result<String> {
251        let encoding = self.encoding();
252        let bytes = self.read_bytes().await?;
253        let (s, _, _) = encoding.decode(&bytes);
254        Ok(s.to_string())
255    }
256
257    fn owns_transport(&self) -> bool {
258        self.transport.as_ref().is_some_and(MutCow::is_owned)
259    }
260
261    /// Similar to [`ReceivedBody::read_string`], but returns the raw bytes. This is useful for
262    /// bodies that are not text.
263    ///
264    /// You can use this in conjunction with `encoding` if you need different handling of malformed
265    /// character encoding than the lossy conversion provided by [`ReceivedBody::read_string`].
266    ///
267    /// # Errors
268    ///
269    /// This will return an error if there is an IO error on the underlying transport such as a
270    /// disconnect
271    ///
272    /// This will also return an error if the length exceeds
273    /// [`received_body_max_len`][HttpConfig::with_received_body_max_len]. To override this value on
274    /// this specific body, use [`ReceivedBody::with_max_len`] or [`ReceivedBody::set_max_len`]
275    pub async fn read_bytes(mut self) -> crate::Result<Vec<u8>> {
276        let mut vec = if let Some(len) = self.content_length {
277            if len > self.max_len {
278                return Err(Error::ReceivedBodyTooLong(self.max_len));
279            }
280
281            let len = usize::try_from(len).map_err(|_| Error::ReceivedBodyTooLong(self.max_len))?;
282
283            Vec::with_capacity(len.min(self.max_preallocate))
284        } else {
285            Vec::with_capacity(self.initial_len)
286        };
287
288        self.read_to_end(&mut vec).await?;
289        Ok(vec)
290    }
291
292    // pub fn encoding(&self) -> &'static Encoding {
293    //     self.encoding
294    // }
295
296    fn read_raw(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
297        if let Some(transport) = self.transport.as_deref_mut() {
298            read_buffered(&mut self.buffer, transport, cx, buf)
299        } else {
300            Ready(Err(ErrorKind::NotConnected.into()))
301        }
302    }
303
304    /// Consumes the remainder of this body from the underlying transport by reading it to the end
305    /// and discarding the contents. This is important for http1.1 keepalive, but most of the
306    /// time you do not need to directly call this. It returns the number of bytes consumed.
307    ///
308    /// # Errors
309    ///
310    /// This will return an [`std::io::Result::Err`] if there is an io error on the underlying
311    /// transport, such as a disconnect
312    #[allow(clippy::missing_errors_doc)] // false positive
313    pub async fn drain(self) -> io::Result<u64> {
314        let copy_loops_per_yield = self.copy_loops_per_yield;
315        copy(self, futures_lite::io::sink(), copy_loops_per_yield).await
316    }
317}
318
319impl<T> ReceivedBody<'static, T> {
320    /// takes the static transport from this received body
321    pub fn take_transport(&mut self) -> Option<T> {
322        self.transport.take().map(MutCow::unwrap_owned)
323    }
324}
325
326pub(crate) fn read_buffered<Transport>(
327    buffer: &mut Buffer,
328    transport: &mut Transport,
329    cx: &mut Context<'_>,
330    buf: &mut [u8],
331) -> Poll<io::Result<usize>>
332where
333    Transport: AsyncRead + Unpin,
334{
335    if buffer.is_empty() {
336        Pin::new(transport).poll_read(cx, buf)
337    } else if buffer.len() >= buf.len() {
338        let len = buf.len();
339        buf.copy_from_slice(&buffer[..len]);
340        buffer.ignore_front(len);
341        Ready(Ok(len))
342    } else {
343        let self_buffer_len = buffer.len();
344        buf[..self_buffer_len].copy_from_slice(buffer);
345        buffer.truncate(0);
346        match Pin::new(transport).poll_read(cx, &mut buf[self_buffer_len..]) {
347            Ready(Ok(additional)) => Ready(Ok(additional + self_buffer_len)),
348            Pending => Ready(Ok(self_buffer_len)),
349            other @ Ready(_) => other,
350        }
351    }
352}
353
354type StateOutput = Poll<io::Result<(ReceivedBodyState, usize)>>;
355
356impl<Transport> ReceivedBody<'_, Transport>
357where
358    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
359{
360    #[inline]
361    fn handle_start(&mut self) -> StateOutput {
362        Ready(Ok((
363            match self.content_length {
364                Some(0) => End,
365
366                Some(total_length) if total_length <= self.max_len => FixedLength {
367                    current_index: 0,
368                    total: total_length,
369                },
370
371                Some(_) => {
372                    return Ready(Err(io::Error::new(
373                        ErrorKind::Unsupported,
374                        "content too long",
375                    )));
376                }
377
378                None => Chunked {
379                    remaining: 0,
380                    total: 0,
381                },
382            },
383            0,
384        )))
385    }
386}
387
388impl<Transport> AsyncRead for ReceivedBody<'_, Transport>
389where
390    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
391{
392    fn poll_read(
393        mut self: Pin<&mut Self>,
394        cx: &mut Context<'_>,
395        buf: &mut [u8],
396    ) -> Poll<io::Result<usize>> {
397        const CONTINUE: &[u8] = b"HTTP/1.1 100 Continue\r\n\r\n";
398        while let Some(offset) = self.send_100_continue_offset {
399            let n = {
400                let Some(transport) = self.transport.as_deref_mut() else {
401                    return Ready(Err(ErrorKind::NotConnected.into()));
402                };
403                if offset == 0 {
404                    log::trace!("sending 100-continue");
405                }
406                ready!(Pin::new(transport).poll_write(cx, &CONTINUE[offset..]))?
407            };
408            if n == 0 {
409                return Ready(Err(ErrorKind::WriteZero.into()));
410            }
411            let new_offset = offset + n;
412            self.send_100_continue_offset = if new_offset >= CONTINUE.len() {
413                None
414            } else {
415                Some(new_offset)
416            };
417        }
418
419        for _ in 0..self.copy_loops_per_yield {
420            let (new_body_state, bytes) = ready!(match *self.state {
421                Start => self.handle_start(),
422                Chunked { remaining, total } => self.handle_chunked(cx, buf, remaining, total),
423                PartialChunkSize { total } => self.handle_partial(cx, buf, total),
424                FixedLength {
425                    current_index,
426                    total,
427                } => self.handle_fixed_length(cx, buf, current_index, total),
428                ReceivedBodyState::H3Data {
429                    remaining_in_frame,
430                    total,
431                    frame_type,
432                    partial_frame_header,
433                } => self.handle_h3_data(
434                    cx,
435                    buf,
436                    remaining_in_frame,
437                    total,
438                    frame_type,
439                    partial_frame_header,
440                ),
441                ReceivedBodyState::ReadingH1Trailers { total } => {
442                    self.handle_reading_h1_trailers(cx, buf, total)
443                }
444                End => Ready(Ok((End, 0))),
445            })?;
446
447            *self.state = new_body_state;
448
449            if *self.state == End {
450                if bytes == 0
451                    && let Some(h3_trailer_future) = &mut self.h3_trailer_future
452                {
453                    let trailers = ready!(h3_trailer_future.as_mut().poll(cx))?;
454                    *self.trailers = Some(trailers);
455                    self.h3_trailer_future = None;
456                }
457
458                if self.on_completion.is_some() && self.owns_transport() {
459                    let transport = self.transport.take().unwrap().unwrap_owned();
460                    let on_completion = self.on_completion.take().unwrap();
461                    on_completion(transport);
462                }
463                return Ready(Ok(bytes));
464            } else if bytes != 0 {
465                return Ready(Ok(bytes));
466            }
467        }
468
469        cx.waker().wake_by_ref();
470        Pending
471    }
472}
473
474impl<Transport> Debug for ReceivedBody<'_, Transport> {
475    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
476        f.debug_struct("ReceivedBody")
477            .field("state", &*self.state)
478            .field("content_length", &self.content_length)
479            .field("buffer", &format_args!(".."))
480            .field("on_completion", &self.on_completion.is_some())
481            .finish()
482    }
483}
484
/// The type of H3 frame currently being processed in [`ReceivedBodyState::H3Data`].
///
/// `Start` is the [`Default`] variant and is the value [`ReceivedBodyState::new_h3`] begins
/// with.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
#[allow(missing_docs)]
#[doc(hidden)]
pub enum H3BodyFrameType {
    /// Initial state — no frame decoded yet.
    #[default]
    Start,
    /// Inside a DATA frame — body bytes to keep.
    Data,
    /// Inside a frame of an unknown type — payload bytes to discard.
    Unknown,
    /// Inside a trailing HEADERS frame — accumulate into buffer for parsing.
    Trailers,
}
500
/// Position of a [`ReceivedBody`] in its read state machine.
///
/// `ReceivedBody`'s `AsyncRead` implementation dispatches on the current variant and stores the
/// state returned by the matching handler after every iteration.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Default)]
#[allow(missing_docs)]
#[doc(hidden)]
pub enum ReceivedBodyState {
    /// initial state. The first read picks `Chunked`, `FixedLength` or `End` from the
    /// content-length without consuming any bytes.
    #[default]
    Start,

    /// read state for a chunked-encoded body.
    // NOTE(review): this doc previously said the bytes read from the current chunk are "the
    // difference between remaining and total". That does not follow from the field docs below
    // (`remaining` is per-chunk, `total` spans all chunks) — confirm against chunked.rs.
    Chunked {
        /// remaining indicates the bytes left _in the current
        /// chunk_. initial state is zero.
        remaining: u64,

        /// total indicates the absolute number of bytes read from all chunks
        total: u64,
    },

    /// read state when we have buffered content between subsequent polls because chunk framing
    /// overlapped a buffer boundary
    // NOTE(review): `total` is presumably the same running body total carried by `Chunked` —
    // confirm in chunked.rs.
    PartialChunkSize { total: u64 },

    /// read state for a fixed-length body.
    FixedLength {
        /// current index represents the bytes that have already been
        /// read. initial state is zero
        current_index: u64,

        /// total length indicates the claimed length, usually
        /// determined by the content-length header
        total: u64,
    },

    /// read state for an H3 body framed as DATA frames.
    H3Data {
        /// bytes remaining in the current frame (DATA, Unknown, or Trailers). zero means we need
        /// to read the next frame header.
        remaining_in_frame: u64,

        /// total body bytes read across all DATA frames.
        total: u64,

        /// what kind of frame we're currently inside.
        frame_type: H3BodyFrameType,

        /// when true, a partial frame header is sitting in `self.buffer` and needs more bytes
        /// before we can decode it.
        partial_frame_header: bool,
    },

    /// accumulating the HTTP/1.1 chunked trailer-section after the last-chunk (`0\r\n`).
    ///
    /// The trailer bytes (including any partially-received trailer headers) live in
    /// `ReceivedBody::buffer` until a final empty line (`\r\n\r\n` or bare `\r\n`) is found.
    ReadingH1Trailers {
        /// total body bytes read across all chunks (for bounds-checking)
        total: u64,
    },

    /// the terminal read state; further reads produce zero body bytes.
    End,
}
564
565impl ReceivedBodyState {
566    pub fn new_h3() -> Self {
567        Self::H3Data {
568            remaining_in_frame: 0,
569            total: 0,
570            frame_type: H3BodyFrameType::Start,
571            partial_frame_header: false,
572        }
573    }
574}
575
576impl<Transport> From<ReceivedBody<'static, Transport>> for Body
577where
578    Transport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
579{
580    fn from(rb: ReceivedBody<'static, Transport>) -> Self {
581        let len = rb.content_length;
582        Body::new_streaming(rb, len)
583    }
584}