1use crate::{Error, Pool, pool::PoolEntry};
2use encoding_rs::Encoding;
3use futures_lite::{AsyncRead, AsyncReadExt, AsyncWriteExt};
4use std::{
5 fmt::{self, Debug, Formatter},
6 io, mem,
7 pin::Pin,
8 task::{Context, Poll, ready},
9};
10use trillium_http::{
11 Body, BodySource, Headers, HttpConfig, MutCow, ReceivedBody, ReceivedBodyState,
12};
13use trillium_server_common::{Runtime, Transport, url::Origin};
14
/// The body of a response received by the client.
///
/// It can be streamed through its `AsyncRead` implementation, collected with
/// `read_bytes` / `read_string`, or `.await`ed directly (via `IntoFuture`) to obtain a `String`.
pub struct ResponseBody<'a> {
    /// Current read state: a body received from the transport, an override body, or the
    /// closing/closed states entered after the transport has been handed off.
    inner: ResponseBodyInner<'a>,
    /// What to do with the transport once the body is finished. Only bodies built with
    /// `received_owned` carry one; it is taken (left as `None`) the first time the transport is
    /// recycled, so recycling is attempted at most once.
    cleanup: Option<CleanupContext>,
}
55
/// Internal state machine of a `ResponseBody`.
// One variant (presumably `Received`, which embeds a whole `ReceivedBody`) is much larger than the
// others, and clippy's lint is silenced instead of boxing it. NOTE(review): the rationale is not
// recorded here — likely to avoid a heap allocation per body; confirm.
#[allow(clippy::large_enum_variant)]
enum ResponseBodyInner<'a> {
    /// A body being read from the connection's transport.
    Received(ReceivedBody<'a, Box<dyn Transport>>),
    /// An `OverrideBody`: a `trillium_http::Body` that is read instead of a transport-backed body,
    /// with its own length limit.
    Override(OverrideBody<'a>),
    /// EOF was reached with no pool to return the transport to; this future closes the transport
    /// and `poll_read` drives it to completion before reporting EOF.
    Closing(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
    /// Terminal state: every read returns `Ok(0)`.
    Closed,
}
63
/// Pool of idle transports keyed by `Origin`; finished transports are inserted here for reuse.
type H1Pool = Pool<Origin, Box<dyn Transport>>;
65
/// Describes what happens to a transport once a `ResponseBody` is done with it.
///
/// When `h1_pool_origin` is `Some`, a transport whose body reached its end is inserted back into
/// that pool under that origin; otherwise the transport is closed, in a task spawned on `runtime`
/// where the call site is synchronous.
#[derive(Clone)]
pub(crate) struct CleanupContext {
    /// Runtime used to spawn background drain/close tasks.
    pub(crate) runtime: Runtime,
    /// Pool and origin to return a reusable transport to; `None` means "close it instead".
    pub(crate) h1_pool_origin: Option<(H1Pool, Origin)>,
}
80
81impl CleanupContext {
82 pub(crate) fn handoff(&self, mut transport: Box<dyn Transport>) {
85 match &self.h1_pool_origin {
86 Some((pool, origin)) => {
87 log::trace!("body transferred, returning to pool");
88 pool.insert(origin.clone(), PoolEntry::new(transport, None));
89 }
90 None => {
91 self.runtime.clone().spawn(async move {
92 let _ = transport.close().await;
93 });
94 }
95 }
96 }
97}
98
/// A `trillium_http::Body` read in place of a transport-backed body, carrying the same limits a
/// received body would get from `HttpConfig`.
pub(crate) struct OverrideBody<'a> {
    /// The body data (borrowed or owned, via `MutCow`).
    body: MutCow<'a, Body>,
    /// Text encoding used by `ResponseBody::read_string`.
    encoding: &'static Encoding,
    /// Maximum number of bytes `poll_read` will hand out.
    max_len: u64,
    /// Initial buffer capacity for `read_bytes` when the length is unknown.
    initial_len: usize,
    /// Cap on the up-front allocation for `read_bytes` when the length is known.
    max_preallocate: usize,
}
106
107impl AsyncRead for OverrideBody<'_> {
108 fn poll_read(
109 mut self: Pin<&mut Self>,
110 cx: &mut Context<'_>,
111 buf: &mut [u8],
112 ) -> Poll<io::Result<usize>> {
113 let remaining = self.max_len.saturating_sub(self.body.bytes_read());
114 if remaining == 0 && !buf.is_empty() {
115 return Poll::Ready(Err(io::Error::other(Error::ReceivedBodyTooLong(
116 self.max_len,
117 ))));
118 }
119 let cap = remaining.min(buf.len() as u64) as usize;
120 Pin::new(&mut *self.body).poll_read(cx, &mut buf[..cap])
121 }
122}
123
124impl<'a> OverrideBody<'a> {
125 pub(crate) fn new(
126 body: impl Into<MutCow<'a, Body>>,
127 encoding: &'static Encoding,
128 http_config: &HttpConfig,
129 ) -> Self {
130 Self {
131 body: body.into(),
132 encoding,
133 max_len: http_config.received_body_max_len(),
134 max_preallocate: http_config.received_body_max_preallocate(),
135 initial_len: http_config.received_body_initial_len(),
136 }
137 }
138}
139
140impl Debug for ResponseBody<'_> {
141 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
142 match &self.inner {
143 ResponseBodyInner::Received(rb) => f.debug_tuple("ResponseBody").field(rb).finish(),
144 ResponseBodyInner::Override(o) => f
145 .debug_struct("ResponseBody (override)")
146 .field("body", &*o.body)
147 .field("encoding", &o.encoding.name())
148 .field("max_len", &o.max_len)
149 .finish(),
150 ResponseBodyInner::Closing(_) => f.debug_tuple("ResponseBody (closing)").finish(),
151 ResponseBodyInner::Closed => f.debug_tuple("ResponseBody (closed)").finish(),
152 }
153 }
154}
155
156impl AsyncRead for ResponseBody<'_> {
157 fn poll_read(
158 mut self: Pin<&mut Self>,
159 cx: &mut Context<'_>,
160 buf: &mut [u8],
161 ) -> Poll<io::Result<usize>> {
162 let mut bytes = 0;
163 loop {
164 match &mut self.inner {
165 ResponseBodyInner::Received(rb) => bytes = ready!(Pin::new(rb).poll_read(cx, buf))?,
166 ResponseBodyInner::Override(o) => bytes = ready!(Pin::new(o).poll_read(cx, buf))?,
167 ResponseBodyInner::Closing(fut) => {
168 ready!(fut.as_mut().poll(cx));
169 self.inner = ResponseBodyInner::Closed;
170 break;
171 }
172
173 ResponseBodyInner::Closed => break,
174 };
175
176 if bytes == 0
179 && let Some((mut rb, cleanup)) = self.prepare_for_recycle()
180 && rb.state() == ReceivedBodyState::End
181 && let Some(mut transport) = rb.take_transport()
182 {
183 if let Some((pool, origin)) = cleanup.h1_pool_origin {
184 pool.insert(origin, PoolEntry::new(transport, None));
185 } else {
186 self.inner = ResponseBodyInner::Closing(Box::pin(async move {
187 if let Err(e) = transport.close().await {
188 log::warn!("transport close failed during ResponseBody EOF: {e}");
189 }
190 }));
191 }
192 } else {
193 break;
194 }
195 }
196
197 Poll::Ready(Ok(bytes))
198 }
199}
200
201impl ResponseBody<'_> {
202 fn take_inner(&mut self) -> ResponseBodyInner<'_> {
203 mem::replace(&mut self.inner, ResponseBodyInner::Closed)
204 }
205
206 fn max_preallocate(&self) -> usize {
207 match &self.inner {
208 ResponseBodyInner::Received(rb) => rb.max_preallocate(),
209 ResponseBodyInner::Override(override_body) => override_body.max_preallocate,
210 _ => 0,
211 }
212 }
213
214 fn max_len(&self) -> u64 {
215 match &self.inner {
216 ResponseBodyInner::Received(rb) => rb.max_len(),
217 ResponseBodyInner::Override(override_body) => override_body.max_len,
218 _ => 0,
219 }
220 }
221
222 fn initial_len(&self) -> usize {
223 match &self.inner {
224 ResponseBodyInner::Received(rb) => rb.initial_len(),
225 ResponseBodyInner::Override(override_body) => override_body.initial_len,
226 _ => 0,
227 }
228 }
229
230 fn encoding(&self) -> &'static Encoding {
231 match &self.inner {
232 ResponseBodyInner::Received(rb) => rb.encoding(),
233 ResponseBodyInner::Override(override_body) => override_body.encoding,
234 _ => encoding_rs::WINDOWS_1252,
235 }
236 }
237
238 pub async fn read_bytes(mut self) -> Result<Vec<u8>, Error> {
255 let mut vec = if let Some(len) = self.content_length() {
256 if len > self.max_len() {
257 return Err(Error::ReceivedBodyTooLong(self.max_len()));
258 }
259
260 let len =
261 usize::try_from(len).map_err(|_| Error::ReceivedBodyTooLong(self.max_len()))?;
262
263 Vec::with_capacity(len.min(self.max_preallocate()))
264 } else {
265 Vec::with_capacity(self.initial_len())
266 };
267
268 self.read_to_end(&mut vec).await?;
269
270 Ok(vec)
271 }
272
273 pub async fn read_string(self) -> Result<String, Error> {
293 let encoding = self.encoding();
294 let bytes = self.read_bytes().await?;
295 let (s, _, _) = encoding.decode(&bytes);
296 Ok(s.to_string())
297 }
298
299 #[must_use]
309 pub fn with_max_len(mut self, max_len: u64) -> Self {
310 self.set_max_len(max_len);
311 self
312 }
313
314 pub fn set_max_len(&mut self, max_len: u64) -> &mut Self {
324 match &mut self.inner {
325 ResponseBodyInner::Received(rb) => {
326 rb.set_max_len(max_len);
327 }
328 ResponseBodyInner::Override(o) => {
329 o.max_len = max_len;
330 }
331 _ => {}
332 }
333 self
334 }
335
336 pub fn content_length(&self) -> Option<u64> {
341 match &self.inner {
342 ResponseBodyInner::Received(rb) => rb.content_length(),
343 ResponseBodyInner::Override(o) => o.body.len(),
344 _ => None,
345 }
346 }
347
348 fn prepare_for_recycle(
349 &mut self,
350 ) -> Option<(
351 ReceivedBody<'static, Box<dyn Transport + 'static>>,
352 CleanupContext,
353 )> {
354 let cleanup = self.cleanup.take()?;
355
356 let ResponseBodyInner::Received(rb) = self.take_inner() else {
357 return None;
358 };
359
360 let rb = rb.try_into_owned()?;
361
362 Some((rb, cleanup))
363 }
364}
365
366async fn drain(rb: &mut ReceivedBody<'static, Box<dyn Transport + 'static>>) -> io::Result<u64> {
367 let copy_loops_per_yield = rb.copy_loops_per_yield();
368 trillium_http::copy(rb, futures_lite::io::sink(), copy_loops_per_yield).await
369}
370
371async fn recycle(
372 mut rb: ReceivedBody<'static, Box<dyn Transport + 'static>>,
373 h1_pool_origin: Option<(H1Pool, Origin)>,
374) {
375 if let Some((pool, origin)) = h1_pool_origin {
376 match drain(&mut rb).await {
377 Ok(drained) => {
378 if rb.state() == ReceivedBodyState::End
379 && let Some(transport) = rb.take_transport()
380 {
381 log::trace!(
382 "drained {drained} bytes, returning transport to pool for {origin:?}"
383 );
384 pool.insert(origin, PoolEntry::new(transport, None));
385 return;
386 }
387 }
388 Err(e) => log::warn!("drain failed during recycle: {e}"),
389 }
390 }
391
392 if let Some(mut transport) = rb.take_transport()
393 && let Err(e) = transport.close().await
394 {
395 log::warn!("transport close failed during recycle: {e}");
396 }
397}
398
399impl Drop for ResponseBody<'_> {
400 fn drop(&mut self) {
401 let Some((mut rb, cleanup)) = self.prepare_for_recycle() else {
402 return;
403 };
404
405 if rb.state() == ReceivedBodyState::End
407 && cleanup.h1_pool_origin.is_some()
408 && let Some(transport) = rb.take_transport()
409 && let Some((pool, origin)) = cleanup.h1_pool_origin
410 {
411 pool.insert(origin, PoolEntry::new(transport, None));
412 } else {
413 cleanup.runtime.spawn(recycle(rb, cleanup.h1_pool_origin));
414 }
415 }
416}
417
418impl BodySource for ResponseBody<'static> {
419 fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
420 match &mut self.get_mut().inner {
421 ResponseBodyInner::Received(rb) => Pin::new(rb).trailers(),
422 ResponseBodyInner::Override(o) => o.body.trailers(),
423 _ => None,
424 }
425 }
426}
427
428impl<'a> From<ReceivedBody<'a, Box<dyn Transport>>> for ResponseBody<'a> {
429 fn from(received_body: ReceivedBody<'a, Box<dyn Transport>>) -> Self {
430 Self {
431 inner: ResponseBodyInner::Received(received_body),
432 cleanup: None,
433 }
434 }
435}
436
437impl<'a> From<OverrideBody<'a>> for ResponseBody<'a> {
438 fn from(o: OverrideBody<'a>) -> Self {
439 Self {
440 inner: ResponseBodyInner::Override(o),
441 cleanup: None,
442 }
443 }
444}
445
446impl ResponseBody<'static> {
447 pub(crate) fn received_owned(
448 body: ReceivedBody<'static, Box<dyn Transport>>,
449 cleanup: CleanupContext,
450 ) -> Self {
451 Self {
452 inner: ResponseBodyInner::Received(body),
453 cleanup: Some(cleanup),
454 }
455 }
456
457 pub async fn recycle(mut self) {
468 let Some((rb, cleanup)) = self.prepare_for_recycle() else {
469 return;
470 };
471
472 recycle(rb, cleanup.h1_pool_origin).await;
473 }
474}
475
476impl<'a> IntoFuture for ResponseBody<'a> {
477 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
478 type Output = trillium_http::Result<String>;
479
480 fn into_future(self) -> Self::IntoFuture {
481 Box::pin(async move { self.read_string().await })
482 }
483}
484
485const _: fn() = || {
486 fn assert_send_sync<T: Send + Sync + ?Sized>() {}
487 assert_send_sync::<ResponseBody<'static>>();
488 assert_send_sync::<ResponseBody<'_>>();
489};