Skip to main content

trillium_client/
response_body.rs

1use crate::{Error, Pool, pool::PoolEntry};
2use encoding_rs::Encoding;
3use futures_lite::{AsyncRead, AsyncReadExt, AsyncWriteExt};
4use std::{
5    fmt::{self, Debug, Formatter},
6    io, mem,
7    pin::Pin,
8    task::{Context, Poll, ready},
9};
10use trillium_http::{
11    Body, BodySource, Headers, HttpConfig, MutCow, ReceivedBody, ReceivedBodyState,
12};
13use trillium_server_common::{Runtime, Transport, url::Origin};
14
/// A response body received from a server.
///
/// Most of the time this represents a body that will be read from the underlying transport, but it
/// can also wrap an override body installed by middleware via [`ConnExt::set_response_body`]
/// — e.g. cache hits, mocked responses, or circuit-breaker short-circuits. Reads, encoding
/// handling, and `max_len` enforcement work transparently across both cases.
///
/// [`ConnExt::set_response_body`]: crate::ConnExt::set_response_body
///
/// ```rust
/// use trillium_client::Client;
/// use trillium_testing::{client_config, with_server};
///
/// with_server("hello from trillium", |url| async move {
///     let client = Client::new(client_config());
///     let mut conn = client.get(url).await?;
///     let body = conn.response_body(); //<-
///     assert_eq!(Some(19), body.content_length());
///     assert_eq!("hello from trillium", body.read_string().await?);
///     Ok(())
/// });
/// ```
///
/// ## Bounds checking
///
/// Every `ResponseBody` has a maximum length beyond which it will return an error, expressed as a
/// `u64`. To override this on the specific `ResponseBody`, use [`ResponseBody::with_max_len`] or
/// [`ResponseBody::set_max_len`]. The bound is enforced on override bodies as well as
/// transport-backed ones, so a user-set memory cap holds even when middleware has replaced the
/// body with externally-sourced bytes.
pub struct ResponseBody<'a> {
    /// Current state of the body: transport-backed, override, closing, or closed.
    inner: ResponseBodyInner<'a>,
    /// Set on `'static` Received bodies built via
    /// [`Conn::take_response_body`][crate::Conn::take_response_body]. `recycle` and `Drop`
    /// consult it to decide whether to drain (keepalive) or close the underlying transport.
    /// `None` for borrowed bodies (the conn handles their cleanup) and for Override bodies (no
    /// transport is attached at this layer — `take_response_body` already evicted any leftover
    /// transport before returning).
    cleanup: Option<CleanupContext>,
    /// Trailers harvested off the inner [`ReceivedBody`] when it reaches `End`. The
    /// EOF-driven recycle in `poll_read` moves the `ReceivedBody` out before the caller can
    /// observe its trailers, so they're captured here to outlive it and surfaced through
    /// [`BodySource::trailers`].
    trailers: Option<Headers>,
}
60
/// The state a [`ResponseBody`] is currently in.
// NOTE(review): the lint is silenced without a stated reason — presumably to keep the common
// `Received` variant unboxed; confirm and record the rationale here.
#[allow(clippy::large_enum_variant)]
enum ResponseBodyInner<'a> {
    /// A body read from a boxed transport through a [`ReceivedBody`].
    Received(ReceivedBody<'a, Box<dyn Transport>>),
    /// A replacement body that is not read from a transport at this layer.
    Override(OverrideBody<'a>),
    /// A pending future that `poll_read` drives to completion before moving to `Closed`; it is
    /// created there to close a transport that is not being returned to a pool.
    Closing(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
    /// Terminal state: the inner body has been moved out or the close has completed.
    Closed,
}
68
/// A pool of boxed transports keyed by [`Origin`].
// NOTE(review): the `H1` prefix suggests this pool only holds HTTP/1.x keepalive transports —
// confirm against the pool's owner.
type H1Pool = Pool<Origin, Box<dyn Transport>>;
70
/// Carries everything `Drop for ResponseBody` and [`ResponseBody::recycle`] need to release
/// a transport without re-deriving from a [`crate::Conn`] (which is gone by then).
///
/// `h1_pool_origin: Some` means "keepalive transport, pool is configured — insert here on
/// completion." `None` means "close on completion" (non-keepalive *or* no pool). The same
/// instance is cloned into the body's `on_completion` callback in
/// [`Conn::take_received_body`][crate::Conn::take_received_body], so the user-driven
/// drain/read paths and the Drop/recycle paths share one source of truth for what to do
/// with the transport when the body finishes.
#[derive(Clone)]
pub(crate) struct CleanupContext {
    /// Runtime used to spawn the background close in `handoff` and the recycle task started
    /// from `Drop for ResponseBody`.
    pub(crate) runtime: Runtime,
    /// The pool to return the transport to and the origin key to insert it under, or `None`
    /// when the transport should be closed instead.
    pub(crate) h1_pool_origin: Option<(H1Pool, Origin)>,
}
85
86impl CleanupContext {
87    /// Hand a freshly-released transport off to its destination — pool insert (sync) or
88    /// spawn close.
89    pub(crate) fn handoff(&self, mut transport: Box<dyn Transport>) {
90        match &self.h1_pool_origin {
91            Some((pool, origin)) => {
92                log::trace!("body transferred, returning to pool");
93                pool.insert(origin.clone(), PoolEntry::new(transport, None));
94            }
95            None => {
96                self.runtime.clone().spawn(async move {
97                    log_close_result(transport.close().await);
98                });
99            }
100        }
101    }
102}
103
/// A body that stands in for a transport-backed one, carrying the read limits and encoding
/// that [`ResponseBody`] would otherwise take from a [`ReceivedBody`].
pub(crate) struct OverrideBody<'a> {
    /// The replacement body that reads are delegated to.
    body: MutCow<'a, Body>,
    /// Character encoding reported for this body and used by `ResponseBody::read_string`.
    encoding: &'static Encoding,
    /// Maximum number of bytes that may be read before `poll_read` returns an error.
    max_len: u64,
    /// Initial buffer capacity used by `ResponseBody::read_bytes` when the length is unknown.
    initial_len: usize,
    /// Upper bound on the capacity `ResponseBody::read_bytes` preallocates when the length is
    /// known.
    max_preallocate: usize,
}
111
112impl AsyncRead for OverrideBody<'_> {
113    fn poll_read(
114        mut self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116        buf: &mut [u8],
117    ) -> Poll<io::Result<usize>> {
118        let remaining = self.max_len.saturating_sub(self.body.bytes_read());
119        if remaining == 0 && !buf.is_empty() {
120            return Poll::Ready(Err(io::Error::other(Error::ReceivedBodyTooLong(
121                self.max_len,
122            ))));
123        }
124        let cap = remaining.min(buf.len() as u64) as usize;
125        Pin::new(&mut *self.body).poll_read(cx, &mut buf[..cap])
126    }
127}
128
129impl<'a> OverrideBody<'a> {
130    pub(crate) fn new(
131        body: impl Into<MutCow<'a, Body>>,
132        encoding: &'static Encoding,
133        http_config: &HttpConfig,
134    ) -> Self {
135        Self {
136            body: body.into(),
137            encoding,
138            max_len: http_config.received_body_max_len(),
139            max_preallocate: http_config.received_body_max_preallocate(),
140            initial_len: http_config.received_body_initial_len(),
141        }
142    }
143}
144
145impl Debug for ResponseBody<'_> {
146    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
147        match &self.inner {
148            ResponseBodyInner::Received(rb) => f.debug_tuple("ResponseBody").field(rb).finish(),
149            ResponseBodyInner::Override(o) => f
150                .debug_struct("ResponseBody (override)")
151                .field("body", &*o.body)
152                .field("encoding", &o.encoding.name())
153                .field("max_len", &o.max_len)
154                .finish(),
155            ResponseBodyInner::Closing(_) => f.debug_tuple("ResponseBody (closing)").finish(),
156            ResponseBodyInner::Closed => f.debug_tuple("ResponseBody (closed)").finish(),
157        }
158    }
159}
160
161impl AsyncRead for ResponseBody<'_> {
162    fn poll_read(
163        mut self: Pin<&mut Self>,
164        cx: &mut Context<'_>,
165        buf: &mut [u8],
166    ) -> Poll<io::Result<usize>> {
167        let mut bytes = 0;
168        loop {
169            match &mut self.inner {
170                ResponseBodyInner::Received(rb) => bytes = ready!(Pin::new(rb).poll_read(cx, buf))?,
171                ResponseBodyInner::Override(o) => bytes = ready!(Pin::new(o).poll_read(cx, buf))?,
172                ResponseBodyInner::Closing(fut) => {
173                    ready!(fut.as_mut().poll(cx));
174                    self.inner = ResponseBodyInner::Closed;
175                    break;
176                }
177
178                ResponseBodyInner::Closed => break,
179            };
180
181            // Inline transport settlement — see take_received_body's `cleanup` param for
182            // why this isn't done via on_completion.
183            if bytes == 0
184                && let Some((mut rb, cleanup)) = self.prepare_for_recycle()
185                && rb.state() == ReceivedBodyState::End
186                && let Some(mut transport) = rb.take_transport()
187            {
188                self.trailers = Pin::new(&mut rb).trailers();
189                if let Some((pool, origin)) = cleanup.h1_pool_origin {
190                    pool.insert(origin, PoolEntry::new(transport, None));
191                } else {
192                    self.inner = ResponseBodyInner::Closing(Box::pin(async move {
193                        log_close_result(transport.close().await);
194                    }));
195                }
196            } else {
197                break;
198            }
199        }
200
201        Poll::Ready(Ok(bytes))
202    }
203}
204
205impl ResponseBody<'_> {
206    fn take_inner(&mut self) -> ResponseBodyInner<'_> {
207        mem::replace(&mut self.inner, ResponseBodyInner::Closed)
208    }
209
210    fn max_preallocate(&self) -> usize {
211        match &self.inner {
212            ResponseBodyInner::Received(rb) => rb.max_preallocate(),
213            ResponseBodyInner::Override(override_body) => override_body.max_preallocate,
214            _ => 0,
215        }
216    }
217
218    fn max_len(&self) -> u64 {
219        match &self.inner {
220            ResponseBodyInner::Received(rb) => rb.max_len(),
221            ResponseBodyInner::Override(override_body) => override_body.max_len,
222            _ => 0,
223        }
224    }
225
226    fn initial_len(&self) -> usize {
227        match &self.inner {
228            ResponseBodyInner::Received(rb) => rb.initial_len(),
229            ResponseBodyInner::Override(override_body) => override_body.initial_len,
230            _ => 0,
231        }
232    }
233
234    fn encoding(&self) -> &'static Encoding {
235        match &self.inner {
236            ResponseBodyInner::Received(rb) => rb.encoding(),
237            ResponseBodyInner::Override(override_body) => override_body.encoding,
238            _ => encoding_rs::WINDOWS_1252,
239        }
240    }
241
242    /// Similar to [`ResponseBody::read_string`], but returns the raw bytes. This is useful for
243    /// bodies that are not text.
244    ///
245    /// You can use this in conjunction with `encoding` if you need different handling of malformed
246    /// character encoding than the lossy conversion provided by [`ResponseBody::read_string`].
247    ///
248    /// An empty or nonexistent body will yield an empty Vec, not an error.
249    ///
250    /// # Errors
251    ///
252    /// This will return an error if there is an IO error on the underlying transport such as a
253    /// disconnect.
254    ///
255    /// This will also return an error if the length exceeds the maximum length. To configure the
256    /// value on this specific request body, use [`ResponseBody::with_max_len`] or
257    /// [`ResponseBody::set_max_len`]
258    pub async fn read_bytes(mut self) -> Result<Vec<u8>, Error> {
259        let mut vec = if let Some(len) = self.content_length() {
260            if len > self.max_len() {
261                return Err(Error::ReceivedBodyTooLong(self.max_len()));
262            }
263
264            let len =
265                usize::try_from(len).map_err(|_| Error::ReceivedBodyTooLong(self.max_len()))?;
266
267            Vec::with_capacity(len.min(self.max_preallocate()))
268        } else {
269            Vec::with_capacity(self.initial_len())
270        };
271
272        self.read_to_end(&mut vec).await?;
273
274        Ok(vec)
275    }
276
277    /// Reads the entire body to a `String`.
278    ///
279    /// Uses the encoding determined by the content-type (mime) charset. If an encoding problem
280    /// is encountered, the returned `String` will contain utf8 replacement characters.
281    ///
282    /// Note that this can only be performed once per Conn, as the underlying data is not cached
283    /// anywhere. This is the only copy of the body contents.
284    ///
285    /// An empty or nonexistent body will yield an empty String, not an error
286    ///
287    /// # Errors
288    ///
289    /// This will return an error if there is an IO error on the
290    /// underlying transport such as a disconnect
291    ///
292    ///
293    /// This will also return an error if the length exceeds the maximum length. To configure the
294    /// value on this specific response body, use [`ResponseBody::with_max_len`] or
295    /// [`ResponseBody::set_max_len`].
296    pub async fn read_string(self) -> Result<String, Error> {
297        let encoding = self.encoding();
298        let bytes = self.read_bytes().await?;
299        let (s, _, _) = encoding.decode(&bytes);
300        Ok(s.to_string())
301    }
302
303    /// Set the maximum content length to read, returning self
304    ///
305    /// This protects against a memory-use denial-of-service attack wherein an untrusted peer sends
306    /// an unbounded request body. This is especially important when using
307    /// [`ResponseBody::read_string`] and [`ResponseBody::read_bytes`] instead of streaming with
308    /// `AsyncRead`.
309    ///
310    /// The default value can be found documented [in the trillium-http
311    /// crate](https://docs.trillium.rs/trillium_http/struct.HttpConfig#method.received_body_max_len)
312    #[must_use]
313    pub fn with_max_len(mut self, max_len: u64) -> Self {
314        self.set_max_len(max_len);
315        self
316    }
317
318    /// Set the maximum content length to read
319    ///
320    /// This protects against a memory-use denial-of-service attack wherein an untrusted peer sends
321    /// an unbounded request body. This is especially important when using
322    /// [`ResponseBody::read_string`] and [`ResponseBody::read_bytes`] instead of streaming with
323    /// `AsyncRead`.
324    ///
325    /// The default value can be found documented [in the trillium-http
326    /// crate](https://docs.trillium.rs/trillium_http/struct.HttpConfig#method.received_body_max_len)
327    pub fn set_max_len(&mut self, max_len: u64) -> &mut Self {
328        match &mut self.inner {
329            ResponseBodyInner::Received(rb) => {
330                rb.set_max_len(max_len);
331            }
332            ResponseBodyInner::Override(o) => {
333                o.max_len = max_len;
334            }
335            _ => {}
336        }
337        self
338    }
339
340    /// The trailers received after the response body, if any.
341    ///
342    /// Returns `None` until the body has been read to end-of-stream, and only on protocols
343    /// that delivered a trailer section (HTTP/1.1 chunked with trailers, HTTP/2, HTTP/3).
344    /// Reading the body via [`read_string`](Self::read_string)/[`read_bytes`](Self::read_bytes)
345    /// consumes it, so to observe trailers drive the body to completion through its
346    /// [`AsyncRead`](futures_lite::AsyncRead) interface and then call this.
347    pub fn trailers(&self) -> Option<&Headers> {
348        match &self.inner {
349            ResponseBodyInner::Received(rb) => rb.trailers_ref(),
350            // Captured off the inner ReceivedBody when it was recycled at end-of-stream.
351            _ => self.trailers.as_ref(),
352        }
353    }
354
355    /// The content-length of this body, if available.
356    ///
357    /// Usually derived from the content-length header. If the response uses
358    /// transfer-encoding chunked, this will be `None`.
359    pub fn content_length(&self) -> Option<u64> {
360        match &self.inner {
361            ResponseBodyInner::Received(rb) => rb.content_length(),
362            ResponseBodyInner::Override(o) => o.body.len(),
363            _ => None,
364        }
365    }
366
367    fn prepare_for_recycle(
368        &mut self,
369    ) -> Option<(
370        ReceivedBody<'static, Box<dyn Transport + 'static>>,
371        CleanupContext,
372    )> {
373        let cleanup = self.cleanup.take()?;
374
375        let ResponseBodyInner::Received(rb) = self.take_inner() else {
376            return None;
377        };
378
379        let rb = rb.try_into_owned()?;
380
381        Some((rb, cleanup))
382    }
383}
384
385async fn drain(rb: &mut ReceivedBody<'static, Box<dyn Transport + 'static>>) -> io::Result<u64> {
386    let copy_loops_per_yield = rb.copy_loops_per_yield();
387    trillium_http::copy(rb, futures_lite::io::sink(), copy_loops_per_yield).await
388}
389
390/// Report the result of closing a transport we're discarding. `NotConnected` is the expected
391/// "already closed" signal from a finished multiplexed stream — notably h3/QUIC, whose `close`
392/// (unlike h2's) isn't idempotent and errors once the stream has finished — so it's absorbed at
393/// trace. Other errors are unexpected and warned.
394fn log_close_result(result: io::Result<()>) {
395    match result {
396        Ok(()) => {}
397        Err(e) if e.kind() == io::ErrorKind::NotConnected => {
398            log::trace!("transport already closed during recycle: {e}");
399        }
400        Err(e) => log::warn!("transport close failed during recycle: {e}"),
401    }
402}
403
404async fn recycle(
405    mut rb: ReceivedBody<'static, Box<dyn Transport + 'static>>,
406    h1_pool_origin: Option<(H1Pool, Origin)>,
407) {
408    if let Some((pool, origin)) = h1_pool_origin {
409        match drain(&mut rb).await {
410            Ok(drained) => {
411                if rb.state() == ReceivedBodyState::End
412                    && let Some(transport) = rb.take_transport()
413                {
414                    log::trace!(
415                        "drained {drained} bytes, returning transport to pool for {origin:?}"
416                    );
417                    pool.insert(origin, PoolEntry::new(transport, None));
418                    return;
419                }
420            }
421            Err(e) => log::warn!("drain failed during recycle: {e}"),
422        }
423    }
424
425    if let Some(mut transport) = rb.take_transport() {
426        log_close_result(transport.close().await);
427    }
428}
429
430impl Drop for ResponseBody<'_> {
431    fn drop(&mut self) {
432        let Some((mut rb, cleanup)) = self.prepare_for_recycle() else {
433            return;
434        };
435
436        // fast sync path for reclaiming an owned http/1.1 received body that's at End
437        if rb.state() == ReceivedBodyState::End
438            && cleanup.h1_pool_origin.is_some()
439            && let Some(transport) = rb.take_transport()
440            && let Some((pool, origin)) = cleanup.h1_pool_origin
441        {
442            pool.insert(origin, PoolEntry::new(transport, None));
443        } else {
444            cleanup.runtime.spawn(recycle(rb, cleanup.h1_pool_origin));
445        }
446    }
447}
448
449impl BodySource for ResponseBody<'static> {
450    fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
451        let this = self.get_mut();
452        match &mut this.inner {
453            ResponseBodyInner::Received(rb) => Pin::new(rb).trailers(),
454            ResponseBodyInner::Override(o) => o.body.trailers(),
455            // Recycled at EOF — trailers were captured off the ReceivedBody before it was
456            // moved out. See `ResponseBody::trailers`.
457            _ => this.trailers.take(),
458        }
459    }
460}
461
462impl<'a> From<ReceivedBody<'a, Box<dyn Transport>>> for ResponseBody<'a> {
463    fn from(received_body: ReceivedBody<'a, Box<dyn Transport>>) -> Self {
464        Self {
465            inner: ResponseBodyInner::Received(received_body),
466            cleanup: None,
467            trailers: None,
468        }
469    }
470}
471
472impl<'a> From<OverrideBody<'a>> for ResponseBody<'a> {
473    fn from(o: OverrideBody<'a>) -> Self {
474        Self {
475            inner: ResponseBodyInner::Override(o),
476            cleanup: None,
477            trailers: None,
478        }
479    }
480}
481
482impl ResponseBody<'static> {
483    pub(crate) fn received_owned(
484        body: ReceivedBody<'static, Box<dyn Transport>>,
485        cleanup: CleanupContext,
486    ) -> Self {
487        Self {
488            inner: ResponseBodyInner::Received(body),
489            cleanup: Some(cleanup),
490            trailers: None,
491        }
492    }
493
494    /// Drains and pools the underlying transport when worthwhile, closes it otherwise.
495    ///
496    /// Use this to release a keepalive transport synchronously before reissuing a request on
497    /// the same client — the redirect/retry handler pattern. For an h1.1 keepalive transport
498    /// this drives the body to EOF and returns the transport to the pool. For a non-keepalive
499    /// transport this calls `transport.close()` directly without draining (since draining
500    /// would just waste bytes on a connection we're about to close).
501    ///
502    /// For an Override body (cache hit, mocked response, tee), this is a no-op — the body's
503    /// own components handle their own teardown when dropped.
504    pub async fn recycle(mut self) {
505        let Some((rb, cleanup)) = self.prepare_for_recycle() else {
506            return;
507        };
508
509        recycle(rb, cleanup.h1_pool_origin).await;
510    }
511}
512
513impl<'a> IntoFuture for ResponseBody<'a> {
514    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
515    type Output = trillium_http::Result<String>;
516
517    fn into_future(self) -> Self::IntoFuture {
518        Box::pin(async move { self.read_string().await })
519    }
520}
521
522const _: fn() = || {
523    fn assert_send_sync<T: Send + Sync + ?Sized>() {}
524    assert_send_sync::<ResponseBody<'static>>();
525    assert_send_sync::<ResponseBody<'_>>();
526};