Skip to main content

trillium_client/
response_body.rs

1use crate::{Error, Pool, pool::PoolEntry};
2use encoding_rs::Encoding;
3use futures_lite::{AsyncRead, AsyncReadExt, AsyncWriteExt};
4use std::{
5    fmt::{self, Debug, Formatter},
6    io, mem,
7    pin::Pin,
8    task::{Context, Poll, ready},
9};
10use trillium_http::{
11    Body, BodySource, Headers, HttpConfig, MutCow, ReceivedBody, ReceivedBodyState,
12};
13use trillium_server_common::{Runtime, Transport, url::Origin};
14
/// A response body received from a server.
///
/// Most of the time this represents a body that will be read from the underlying transport, but it
/// can also wrap an override body installed by middleware via [`ConnExt::set_response_body`]
/// — e.g. cache hits, mocked responses, or circuit-breaker short-circuits. Reads, encoding
/// handling, and `max_len` enforcement work transparently across both cases.
///
/// [`ConnExt::set_response_body`]: crate::ConnExt::set_response_body
///
/// ```rust
/// use trillium_client::Client;
/// use trillium_testing::{client_config, with_server};
///
/// with_server("hello from trillium", |url| async move {
///     let client = Client::new(client_config());
///     let mut conn = client.get(url).await?;
///     let body = conn.response_body(); //<-
///     assert_eq!(Some(19), body.content_length());
///     assert_eq!("hello from trillium", body.read_string().await?);
///     Ok(())
/// });
/// ```
///
/// ## Bounds checking
///
/// Every `ResponseBody` has a maximum length beyond which it will return an error, expressed as a
/// u64. To override this on the specific `ResponseBody`, use [`ResponseBody::with_max_len`] or
/// [`ResponseBody::set_max_len`]. The bound is enforced on override bodies as well as
/// transport-backed ones, so a user-set memory cap holds even when middleware has replaced the
/// body with externally-sourced bytes.
pub struct ResponseBody<'a> {
    /// The current source of bytes: the transport (`Received`), a middleware-supplied body
    /// (`Override`), an in-flight transport close (`Closing`), or nothing (`Closed`).
    inner: ResponseBodyInner<'a>,
    /// Set on `'static` Received bodies built via
    /// [`Conn::take_response_body`][crate::Conn::take_response_body]. `recycle` and `Drop`
    /// consult it to decide whether to drain (keepalive) or close the underlying transport.
    /// `None` for borrowed bodies (the conn handles their cleanup) and for Override bodies (no
    /// transport is attached at this layer — `take_response_body` already evicted any leftover
    /// transport before returning).
    cleanup: Option<CleanupContext>,
}
55
/// The source a [`ResponseBody`] currently reads from.
// NOTE(review): the size lint is silenced rather than boxing the large variant (presumably
// `Received`), which would cost an extra allocation per body — confirm this is intentional.
#[allow(clippy::large_enum_variant)]
enum ResponseBodyInner<'a> {
    /// Bytes read from the HTTP transport.
    Received(ReceivedBody<'a, Box<dyn Transport>>),
    /// Bytes supplied by middleware in place of the transport body.
    Override(OverrideBody<'a>),
    /// EOF was reached on an owned transport that is not being pooled, and its `close()` is in
    /// flight. `poll_read` drives this future to completion and then switches to `Closed`.
    Closing(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
    /// Nothing left to read; `poll_read` reports `Ok(0)`.
    Closed,
}
63
/// Pool of reusable HTTP/1.x transports, keyed by the origin they are connected to.
type H1Pool = Pool<Origin, Box<dyn Transport>>;
65
/// Carries everything `Drop for ResponseBody` and [`ResponseBody::recycle`] need to release
/// a transport without re-deriving from a [`crate::Conn`] (which is gone by then).
///
/// `h1_pool_origin: Some` means "keepalive transport, pool is configured — insert here on
/// completion." `None` means "close on completion" (non-keepalive *or* no pool). The same
/// instance is cloned into the body's `on_completion` callback in
/// [`Conn::take_received_body`][crate::Conn::take_received_body], so the user-driven
/// drain/read paths and the Drop/recycle paths share one source of truth for what to do
/// with the transport when the body finishes.
// NOTE(review): `ResponseBody::poll_read` says EOF settlement is done inline rather than via
// `on_completion`; confirm the `on_completion` sentence above is still accurate.
#[derive(Clone)]
pub(crate) struct CleanupContext {
    /// Runtime used to spawn transport work that cannot complete synchronously (closing a
    /// transport, or draining it from `Drop`).
    pub(crate) runtime: Runtime,
    /// Pool and pool key for a reusable HTTP/1.x transport. `None` means the transport is
    /// closed instead of pooled.
    pub(crate) h1_pool_origin: Option<(H1Pool, Origin)>,
}
80
81impl CleanupContext {
82    /// Hand a freshly-released transport off to its destination — pool insert (sync) or
83    /// spawn close.
84    pub(crate) fn handoff(&self, mut transport: Box<dyn Transport>) {
85        match &self.h1_pool_origin {
86            Some((pool, origin)) => {
87                log::trace!("body transferred, returning to pool");
88                pool.insert(origin.clone(), PoolEntry::new(transport, None));
89            }
90            None => {
91                self.runtime.clone().spawn(async move {
92                    let _ = transport.close().await;
93                });
94            }
95        }
96    }
97}
98
/// A middleware-supplied body that stands in for the transport body. It carries the same
/// read-time settings (encoding, length limit, buffer sizing) that a [`ReceivedBody`] exposes.
pub(crate) struct OverrideBody<'a> {
    /// The replacement bytes (presumably owned or `&mut`-borrowed, per `MutCow` — confirm).
    body: MutCow<'a, Body>,
    /// Character encoding used by `ResponseBody::read_string`.
    encoding: &'static Encoding,
    /// Limit on the total bytes `poll_read` will hand out. Going past it yields
    /// `Error::ReceivedBodyTooLong`.
    max_len: u64,
    /// Buffer capacity `ResponseBody::read_bytes` starts with when the length is unknown.
    initial_len: usize,
    /// Cap on the up-front allocation `ResponseBody::read_bytes` makes for a known length.
    max_preallocate: usize,
}
106
107impl AsyncRead for OverrideBody<'_> {
108    fn poll_read(
109        mut self: Pin<&mut Self>,
110        cx: &mut Context<'_>,
111        buf: &mut [u8],
112    ) -> Poll<io::Result<usize>> {
113        let remaining = self.max_len.saturating_sub(self.body.bytes_read());
114        if remaining == 0 && !buf.is_empty() {
115            return Poll::Ready(Err(io::Error::other(Error::ReceivedBodyTooLong(
116                self.max_len,
117            ))));
118        }
119        let cap = remaining.min(buf.len() as u64) as usize;
120        Pin::new(&mut *self.body).poll_read(cx, &mut buf[..cap])
121    }
122}
123
124impl<'a> OverrideBody<'a> {
125    pub(crate) fn new(
126        body: impl Into<MutCow<'a, Body>>,
127        encoding: &'static Encoding,
128        http_config: &HttpConfig,
129    ) -> Self {
130        Self {
131            body: body.into(),
132            encoding,
133            max_len: http_config.received_body_max_len(),
134            max_preallocate: http_config.received_body_max_preallocate(),
135            initial_len: http_config.received_body_initial_len(),
136        }
137    }
138}
139
140impl Debug for ResponseBody<'_> {
141    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
142        match &self.inner {
143            ResponseBodyInner::Received(rb) => f.debug_tuple("ResponseBody").field(rb).finish(),
144            ResponseBodyInner::Override(o) => f
145                .debug_struct("ResponseBody (override)")
146                .field("body", &*o.body)
147                .field("encoding", &o.encoding.name())
148                .field("max_len", &o.max_len)
149                .finish(),
150            ResponseBodyInner::Closing(_) => f.debug_tuple("ResponseBody (closing)").finish(),
151            ResponseBodyInner::Closed => f.debug_tuple("ResponseBody (closed)").finish(),
152        }
153    }
154}
155
156impl AsyncRead for ResponseBody<'_> {
157    fn poll_read(
158        mut self: Pin<&mut Self>,
159        cx: &mut Context<'_>,
160        buf: &mut [u8],
161    ) -> Poll<io::Result<usize>> {
162        let mut bytes = 0;
163        loop {
164            match &mut self.inner {
165                ResponseBodyInner::Received(rb) => bytes = ready!(Pin::new(rb).poll_read(cx, buf))?,
166                ResponseBodyInner::Override(o) => bytes = ready!(Pin::new(o).poll_read(cx, buf))?,
167                ResponseBodyInner::Closing(fut) => {
168                    ready!(fut.as_mut().poll(cx));
169                    self.inner = ResponseBodyInner::Closed;
170                    break;
171                }
172
173                ResponseBodyInner::Closed => break,
174            };
175
176            // Inline transport settlement — see take_received_body's `cleanup` param for
177            // why this isn't done via on_completion.
178            if bytes == 0
179                && let Some((mut rb, cleanup)) = self.prepare_for_recycle()
180                && rb.state() == ReceivedBodyState::End
181                && let Some(mut transport) = rb.take_transport()
182            {
183                if let Some((pool, origin)) = cleanup.h1_pool_origin {
184                    pool.insert(origin, PoolEntry::new(transport, None));
185                } else {
186                    self.inner = ResponseBodyInner::Closing(Box::pin(async move {
187                        if let Err(e) = transport.close().await {
188                            log::warn!("transport close failed during ResponseBody EOF: {e}");
189                        }
190                    }));
191                }
192            } else {
193                break;
194            }
195        }
196
197        Poll::Ready(Ok(bytes))
198    }
199}
200
201impl ResponseBody<'_> {
202    fn take_inner(&mut self) -> ResponseBodyInner<'_> {
203        mem::replace(&mut self.inner, ResponseBodyInner::Closed)
204    }
205
206    fn max_preallocate(&self) -> usize {
207        match &self.inner {
208            ResponseBodyInner::Received(rb) => rb.max_preallocate(),
209            ResponseBodyInner::Override(override_body) => override_body.max_preallocate,
210            _ => 0,
211        }
212    }
213
214    fn max_len(&self) -> u64 {
215        match &self.inner {
216            ResponseBodyInner::Received(rb) => rb.max_len(),
217            ResponseBodyInner::Override(override_body) => override_body.max_len,
218            _ => 0,
219        }
220    }
221
222    fn initial_len(&self) -> usize {
223        match &self.inner {
224            ResponseBodyInner::Received(rb) => rb.initial_len(),
225            ResponseBodyInner::Override(override_body) => override_body.initial_len,
226            _ => 0,
227        }
228    }
229
230    fn encoding(&self) -> &'static Encoding {
231        match &self.inner {
232            ResponseBodyInner::Received(rb) => rb.encoding(),
233            ResponseBodyInner::Override(override_body) => override_body.encoding,
234            _ => encoding_rs::WINDOWS_1252,
235        }
236    }
237
238    /// Similar to [`ResponseBody::read_string`], but returns the raw bytes. This is useful for
239    /// bodies that are not text.
240    ///
241    /// You can use this in conjunction with `encoding` if you need different handling of malformed
242    /// character encoding than the lossy conversion provided by [`ResponseBody::read_string`].
243    ///
244    /// An empty or nonexistent body will yield an empty Vec, not an error.
245    ///
246    /// # Errors
247    ///
248    /// This will return an error if there is an IO error on the underlying transport such as a
249    /// disconnect.
250    ///
251    /// This will also return an error if the length exceeds the maximum length. To configure the
252    /// value on this specific request body, use [`ResponseBody::with_max_len`] or
253    /// [`ResponseBody::set_max_len`]
254    pub async fn read_bytes(mut self) -> Result<Vec<u8>, Error> {
255        let mut vec = if let Some(len) = self.content_length() {
256            if len > self.max_len() {
257                return Err(Error::ReceivedBodyTooLong(self.max_len()));
258            }
259
260            let len =
261                usize::try_from(len).map_err(|_| Error::ReceivedBodyTooLong(self.max_len()))?;
262
263            Vec::with_capacity(len.min(self.max_preallocate()))
264        } else {
265            Vec::with_capacity(self.initial_len())
266        };
267
268        self.read_to_end(&mut vec).await?;
269
270        Ok(vec)
271    }
272
273    /// Reads the entire body to a `String`.
274    ///
275    /// Uses the encoding determined by the content-type (mime) charset. If an encoding problem
276    /// is encountered, the returned `String` will contain utf8 replacement characters.
277    ///
278    /// Note that this can only be performed once per Conn, as the underlying data is not cached
279    /// anywhere. This is the only copy of the body contents.
280    ///
281    /// An empty or nonexistent body will yield an empty String, not an error
282    ///
283    /// # Errors
284    ///
285    /// This will return an error if there is an IO error on the
286    /// underlying transport such as a disconnect
287    ///
288    ///
289    /// This will also return an error if the length exceeds the maximum length. To configure the
290    /// value on this specific response body, use [`ResponseBody::with_max_len`] or
291    /// [`ResponseBody::set_max_len`].
292    pub async fn read_string(self) -> Result<String, Error> {
293        let encoding = self.encoding();
294        let bytes = self.read_bytes().await?;
295        let (s, _, _) = encoding.decode(&bytes);
296        Ok(s.to_string())
297    }
298
299    /// Set the maximum content length to read, returning self
300    ///
301    /// This protects against a memory-use denial-of-service attack wherein an untrusted peer sends
302    /// an unbounded request body. This is especially important when using
303    /// [`ResponseBody::read_string`] and [`ResponseBody::read_bytes`] instead of streaming with
304    /// `AsyncRead`.
305    ///
306    /// The default value can be found documented [in the trillium-http
307    /// crate](https://docs.trillium.rs/trillium_http/struct.httpconfig#received_body_max_len)
308    #[must_use]
309    pub fn with_max_len(mut self, max_len: u64) -> Self {
310        self.set_max_len(max_len);
311        self
312    }
313
314    /// Set the maximum content length to read
315    ///
316    /// This protects against a memory-use denial-of-service attack wherein an untrusted peer sends
317    /// an unbounded request body. This is especially important when using
318    /// [`ResponseBody::read_string`] and [`ResponseBody::read_bytes`] instead of streaming with
319    /// `AsyncRead`.
320    ///
321    /// The default value can be found documented [in the trillium-http
322    /// crate](https://docs.trillium.rs/trillium_http/struct.httpconfig#received_body_max_len)
323    pub fn set_max_len(&mut self, max_len: u64) -> &mut Self {
324        match &mut self.inner {
325            ResponseBodyInner::Received(rb) => {
326                rb.set_max_len(max_len);
327            }
328            ResponseBodyInner::Override(o) => {
329                o.max_len = max_len;
330            }
331            _ => {}
332        }
333        self
334    }
335
336    /// The content-length of this body, if available.
337    ///
338    /// Usually derived from the content-length header. If the response uses
339    /// transfer-encoding chunked, this will be `None`.
340    pub fn content_length(&self) -> Option<u64> {
341        match &self.inner {
342            ResponseBodyInner::Received(rb) => rb.content_length(),
343            ResponseBodyInner::Override(o) => o.body.len(),
344            _ => None,
345        }
346    }
347
348    fn prepare_for_recycle(
349        &mut self,
350    ) -> Option<(
351        ReceivedBody<'static, Box<dyn Transport + 'static>>,
352        CleanupContext,
353    )> {
354        let cleanup = self.cleanup.take()?;
355
356        let ResponseBodyInner::Received(rb) = self.take_inner() else {
357            return None;
358        };
359
360        let rb = rb.try_into_owned()?;
361
362        Some((rb, cleanup))
363    }
364}
365
366async fn drain(rb: &mut ReceivedBody<'static, Box<dyn Transport + 'static>>) -> io::Result<u64> {
367    let copy_loops_per_yield = rb.copy_loops_per_yield();
368    trillium_http::copy(rb, futures_lite::io::sink(), copy_loops_per_yield).await
369}
370
371async fn recycle(
372    mut rb: ReceivedBody<'static, Box<dyn Transport + 'static>>,
373    h1_pool_origin: Option<(H1Pool, Origin)>,
374) {
375    if let Some((pool, origin)) = h1_pool_origin {
376        match drain(&mut rb).await {
377            Ok(drained) => {
378                if rb.state() == ReceivedBodyState::End
379                    && let Some(transport) = rb.take_transport()
380                {
381                    log::trace!(
382                        "drained {drained} bytes, returning transport to pool for {origin:?}"
383                    );
384                    pool.insert(origin, PoolEntry::new(transport, None));
385                    return;
386                }
387            }
388            Err(e) => log::warn!("drain failed during recycle: {e}"),
389        }
390    }
391
392    if let Some(mut transport) = rb.take_transport()
393        && let Err(e) = transport.close().await
394    {
395        log::warn!("transport close failed during recycle: {e}");
396    }
397}
398
399impl Drop for ResponseBody<'_> {
400    fn drop(&mut self) {
401        let Some((mut rb, cleanup)) = self.prepare_for_recycle() else {
402            return;
403        };
404
405        // fast sync path for reclaiming an owned http/1.1 received body that's at End
406        if rb.state() == ReceivedBodyState::End
407            && cleanup.h1_pool_origin.is_some()
408            && let Some(transport) = rb.take_transport()
409            && let Some((pool, origin)) = cleanup.h1_pool_origin
410        {
411            pool.insert(origin, PoolEntry::new(transport, None));
412        } else {
413            cleanup.runtime.spawn(recycle(rb, cleanup.h1_pool_origin));
414        }
415    }
416}
417
418impl BodySource for ResponseBody<'static> {
419    fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
420        match &mut self.get_mut().inner {
421            ResponseBodyInner::Received(rb) => Pin::new(rb).trailers(),
422            ResponseBodyInner::Override(o) => o.body.trailers(),
423            _ => None,
424        }
425    }
426}
427
428impl<'a> From<ReceivedBody<'a, Box<dyn Transport>>> for ResponseBody<'a> {
429    fn from(received_body: ReceivedBody<'a, Box<dyn Transport>>) -> Self {
430        Self {
431            inner: ResponseBodyInner::Received(received_body),
432            cleanup: None,
433        }
434    }
435}
436
437impl<'a> From<OverrideBody<'a>> for ResponseBody<'a> {
438    fn from(o: OverrideBody<'a>) -> Self {
439        Self {
440            inner: ResponseBodyInner::Override(o),
441            cleanup: None,
442        }
443    }
444}
445
446impl ResponseBody<'static> {
447    pub(crate) fn received_owned(
448        body: ReceivedBody<'static, Box<dyn Transport>>,
449        cleanup: CleanupContext,
450    ) -> Self {
451        Self {
452            inner: ResponseBodyInner::Received(body),
453            cleanup: Some(cleanup),
454        }
455    }
456
457    /// Drains and pools the underlying transport when worthwhile, closes it otherwise.
458    ///
459    /// Use this to release a keepalive transport synchronously before reissuing a request on
460    /// the same client — the redirect/retry handler pattern. For an h1.1 keepalive transport
461    /// this drives the body to EOF and returns the transport to the pool. For a non-keepalive
462    /// transport this calls `transport.close()` directly without draining (since draining
463    /// would just waste bytes on a connection we're about to close).
464    ///
465    /// For an Override body (cache hit, mocked response, tee), this is a no-op — the body's
466    /// own components handle their own teardown when dropped.
467    pub async fn recycle(mut self) {
468        let Some((rb, cleanup)) = self.prepare_for_recycle() else {
469            return;
470        };
471
472        recycle(rb, cleanup.h1_pool_origin).await;
473    }
474}
475
476impl<'a> IntoFuture for ResponseBody<'a> {
477    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
478    type Output = trillium_http::Result<String>;
479
480    fn into_future(self) -> Self::IntoFuture {
481        Box::pin(async move { self.read_string().await })
482    }
483}
484
485const _: fn() = || {
486    fn assert_send_sync<T: Send + Sync + ?Sized>() {}
487    assert_send_sync::<ResponseBody<'static>>();
488    assert_send_sync::<ResponseBody<'_>>();
489};