1use crate::{Error, Pool, pool::PoolEntry};
2use encoding_rs::Encoding;
3use futures_lite::{AsyncRead, AsyncReadExt, AsyncWriteExt};
4use std::{
5 fmt::{self, Debug, Formatter},
6 io, mem,
7 pin::Pin,
8 task::{Context, Poll, ready},
9};
10use trillium_http::{
11 Body, BodySource, Headers, HttpConfig, MutCow, ReceivedBody, ReceivedBodyState,
12};
13use trillium_server_common::{Runtime, Transport, url::Origin};
14
/// An HTTP response body.
///
/// Backed either by a body received over a transport or by an override [`Body`]. For owned
/// received bodies it also carries a `CleanupContext` describing what to do with the
/// transport (return it to the HTTP/1 pool, or close it) once the body is finished or dropped.
pub struct ResponseBody<'a> {
    // Current body source / lifecycle state.
    inner: ResponseBodyInner<'a>,
    // Only set by `received_owned`; taken by `prepare_for_recycle` when the transport is
    // handed back, so recycling happens at most once.
    cleanup: Option<CleanupContext>,
    // Trailers captured from the received body when its transport is released during
    // `poll_read`, so they remain readable after `inner` stops being `Received`.
    trailers: Option<Headers>,
}
60
// NOTE(review): `Received` stores a full `ReceivedBody` inline and is much larger than the
// other variants; presumably boxing it was judged not worth an extra allocation per
// response — confirm before changing.
#[allow(clippy::large_enum_variant)]
enum ResponseBodyInner<'a> {
    // A body being read from the peer over a transport.
    Received(ReceivedBody<'a, Box<dyn Transport>>),
    // A `Body` read instead of a transport-backed body.
    Override(OverrideBody<'a>),
    // The transport is being closed after EOF; `poll_read` drives this future and then
    // moves to `Closed`.
    Closing(Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>),
    // Nothing left to read; every read returns `Ok(0)`.
    Closed,
}
68
/// Pool of reusable HTTP/1 transports, keyed by origin.
type H1Pool = Pool<Origin, Box<dyn Transport>>;
70
/// What to do with a transport once a response body no longer needs it.
#[derive(Clone)]
pub(crate) struct CleanupContext {
    /// Runtime used to spawn background drain/close tasks.
    pub(crate) runtime: Runtime,
    /// When present, the HTTP/1 pool and the origin key the transport is returned under;
    /// when absent, the transport is closed instead of pooled.
    pub(crate) h1_pool_origin: Option<(H1Pool, Origin)>,
}
85
86impl CleanupContext {
87 pub(crate) fn handoff(&self, mut transport: Box<dyn Transport>) {
90 match &self.h1_pool_origin {
91 Some((pool, origin)) => {
92 log::trace!("body transferred, returning to pool");
93 pool.insert(origin.clone(), PoolEntry::new(transport, None));
94 }
95 None => {
96 self.runtime.clone().spawn(async move {
97 log_close_result(transport.close().await);
98 });
99 }
100 }
101 }
102}
103
/// A [`Body`] read in place of a transport-backed body, carrying the same length limits and
/// text encoding a received body would use.
pub(crate) struct OverrideBody<'a> {
    // The body itself (owned or borrowed, per `MutCow`).
    body: MutCow<'a, Body>,
    // Encoding used by `read_string` to decode the bytes.
    encoding: &'static Encoding,
    // Maximum number of bytes that may be read before `ReceivedBodyTooLong` is reported.
    max_len: u64,
    // Initial buffer capacity used by `read_bytes` when the length is unknown.
    initial_len: usize,
    // Upper bound on the up-front allocation in `read_bytes` when the length is known.
    max_preallocate: usize,
}
111
112impl AsyncRead for OverrideBody<'_> {
113 fn poll_read(
114 mut self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 buf: &mut [u8],
117 ) -> Poll<io::Result<usize>> {
118 let remaining = self.max_len.saturating_sub(self.body.bytes_read());
119 if remaining == 0 && !buf.is_empty() {
120 return Poll::Ready(Err(io::Error::other(Error::ReceivedBodyTooLong(
121 self.max_len,
122 ))));
123 }
124 let cap = remaining.min(buf.len() as u64) as usize;
125 Pin::new(&mut *self.body).poll_read(cx, &mut buf[..cap])
126 }
127}
128
129impl<'a> OverrideBody<'a> {
130 pub(crate) fn new(
131 body: impl Into<MutCow<'a, Body>>,
132 encoding: &'static Encoding,
133 http_config: &HttpConfig,
134 ) -> Self {
135 Self {
136 body: body.into(),
137 encoding,
138 max_len: http_config.received_body_max_len(),
139 max_preallocate: http_config.received_body_max_preallocate(),
140 initial_len: http_config.received_body_initial_len(),
141 }
142 }
143}
144
145impl Debug for ResponseBody<'_> {
146 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
147 match &self.inner {
148 ResponseBodyInner::Received(rb) => f.debug_tuple("ResponseBody").field(rb).finish(),
149 ResponseBodyInner::Override(o) => f
150 .debug_struct("ResponseBody (override)")
151 .field("body", &*o.body)
152 .field("encoding", &o.encoding.name())
153 .field("max_len", &o.max_len)
154 .finish(),
155 ResponseBodyInner::Closing(_) => f.debug_tuple("ResponseBody (closing)").finish(),
156 ResponseBodyInner::Closed => f.debug_tuple("ResponseBody (closed)").finish(),
157 }
158 }
159}
160
161impl AsyncRead for ResponseBody<'_> {
162 fn poll_read(
163 mut self: Pin<&mut Self>,
164 cx: &mut Context<'_>,
165 buf: &mut [u8],
166 ) -> Poll<io::Result<usize>> {
167 let mut bytes = 0;
168 loop {
169 match &mut self.inner {
170 ResponseBodyInner::Received(rb) => bytes = ready!(Pin::new(rb).poll_read(cx, buf))?,
171 ResponseBodyInner::Override(o) => bytes = ready!(Pin::new(o).poll_read(cx, buf))?,
172 ResponseBodyInner::Closing(fut) => {
173 ready!(fut.as_mut().poll(cx));
174 self.inner = ResponseBodyInner::Closed;
175 break;
176 }
177
178 ResponseBodyInner::Closed => break,
179 };
180
181 if bytes == 0
184 && let Some((mut rb, cleanup)) = self.prepare_for_recycle()
185 && rb.state() == ReceivedBodyState::End
186 && let Some(mut transport) = rb.take_transport()
187 {
188 self.trailers = Pin::new(&mut rb).trailers();
189 if let Some((pool, origin)) = cleanup.h1_pool_origin {
190 pool.insert(origin, PoolEntry::new(transport, None));
191 } else {
192 self.inner = ResponseBodyInner::Closing(Box::pin(async move {
193 log_close_result(transport.close().await);
194 }));
195 }
196 } else {
197 break;
198 }
199 }
200
201 Poll::Ready(Ok(bytes))
202 }
203}
204
205impl ResponseBody<'_> {
206 fn take_inner(&mut self) -> ResponseBodyInner<'_> {
207 mem::replace(&mut self.inner, ResponseBodyInner::Closed)
208 }
209
210 fn max_preallocate(&self) -> usize {
211 match &self.inner {
212 ResponseBodyInner::Received(rb) => rb.max_preallocate(),
213 ResponseBodyInner::Override(override_body) => override_body.max_preallocate,
214 _ => 0,
215 }
216 }
217
218 fn max_len(&self) -> u64 {
219 match &self.inner {
220 ResponseBodyInner::Received(rb) => rb.max_len(),
221 ResponseBodyInner::Override(override_body) => override_body.max_len,
222 _ => 0,
223 }
224 }
225
226 fn initial_len(&self) -> usize {
227 match &self.inner {
228 ResponseBodyInner::Received(rb) => rb.initial_len(),
229 ResponseBodyInner::Override(override_body) => override_body.initial_len,
230 _ => 0,
231 }
232 }
233
234 fn encoding(&self) -> &'static Encoding {
235 match &self.inner {
236 ResponseBodyInner::Received(rb) => rb.encoding(),
237 ResponseBodyInner::Override(override_body) => override_body.encoding,
238 _ => encoding_rs::WINDOWS_1252,
239 }
240 }
241
242 pub async fn read_bytes(mut self) -> Result<Vec<u8>, Error> {
259 let mut vec = if let Some(len) = self.content_length() {
260 if len > self.max_len() {
261 return Err(Error::ReceivedBodyTooLong(self.max_len()));
262 }
263
264 let len =
265 usize::try_from(len).map_err(|_| Error::ReceivedBodyTooLong(self.max_len()))?;
266
267 Vec::with_capacity(len.min(self.max_preallocate()))
268 } else {
269 Vec::with_capacity(self.initial_len())
270 };
271
272 self.read_to_end(&mut vec).await?;
273
274 Ok(vec)
275 }
276
277 pub async fn read_string(self) -> Result<String, Error> {
297 let encoding = self.encoding();
298 let bytes = self.read_bytes().await?;
299 let (s, _, _) = encoding.decode(&bytes);
300 Ok(s.to_string())
301 }
302
303 #[must_use]
313 pub fn with_max_len(mut self, max_len: u64) -> Self {
314 self.set_max_len(max_len);
315 self
316 }
317
318 pub fn set_max_len(&mut self, max_len: u64) -> &mut Self {
328 match &mut self.inner {
329 ResponseBodyInner::Received(rb) => {
330 rb.set_max_len(max_len);
331 }
332 ResponseBodyInner::Override(o) => {
333 o.max_len = max_len;
334 }
335 _ => {}
336 }
337 self
338 }
339
340 pub fn trailers(&self) -> Option<&Headers> {
348 match &self.inner {
349 ResponseBodyInner::Received(rb) => rb.trailers_ref(),
350 _ => self.trailers.as_ref(),
352 }
353 }
354
355 pub fn content_length(&self) -> Option<u64> {
360 match &self.inner {
361 ResponseBodyInner::Received(rb) => rb.content_length(),
362 ResponseBodyInner::Override(o) => o.body.len(),
363 _ => None,
364 }
365 }
366
367 fn prepare_for_recycle(
368 &mut self,
369 ) -> Option<(
370 ReceivedBody<'static, Box<dyn Transport + 'static>>,
371 CleanupContext,
372 )> {
373 let cleanup = self.cleanup.take()?;
374
375 let ResponseBodyInner::Received(rb) = self.take_inner() else {
376 return None;
377 };
378
379 let rb = rb.try_into_owned()?;
380
381 Some((rb, cleanup))
382 }
383}
384
385async fn drain(rb: &mut ReceivedBody<'static, Box<dyn Transport + 'static>>) -> io::Result<u64> {
386 let copy_loops_per_yield = rb.copy_loops_per_yield();
387 trillium_http::copy(rb, futures_lite::io::sink(), copy_loops_per_yield).await
388}
389
390fn log_close_result(result: io::Result<()>) {
395 match result {
396 Ok(()) => {}
397 Err(e) if e.kind() == io::ErrorKind::NotConnected => {
398 log::trace!("transport already closed during recycle: {e}");
399 }
400 Err(e) => log::warn!("transport close failed during recycle: {e}"),
401 }
402}
403
404async fn recycle(
405 mut rb: ReceivedBody<'static, Box<dyn Transport + 'static>>,
406 h1_pool_origin: Option<(H1Pool, Origin)>,
407) {
408 if let Some((pool, origin)) = h1_pool_origin {
409 match drain(&mut rb).await {
410 Ok(drained) => {
411 if rb.state() == ReceivedBodyState::End
412 && let Some(transport) = rb.take_transport()
413 {
414 log::trace!(
415 "drained {drained} bytes, returning transport to pool for {origin:?}"
416 );
417 pool.insert(origin, PoolEntry::new(transport, None));
418 return;
419 }
420 }
421 Err(e) => log::warn!("drain failed during recycle: {e}"),
422 }
423 }
424
425 if let Some(mut transport) = rb.take_transport() {
426 log_close_result(transport.close().await);
427 }
428}
429
430impl Drop for ResponseBody<'_> {
431 fn drop(&mut self) {
432 let Some((mut rb, cleanup)) = self.prepare_for_recycle() else {
433 return;
434 };
435
436 if rb.state() == ReceivedBodyState::End
438 && cleanup.h1_pool_origin.is_some()
439 && let Some(transport) = rb.take_transport()
440 && let Some((pool, origin)) = cleanup.h1_pool_origin
441 {
442 pool.insert(origin, PoolEntry::new(transport, None));
443 } else {
444 cleanup.runtime.spawn(recycle(rb, cleanup.h1_pool_origin));
445 }
446 }
447}
448
449impl BodySource for ResponseBody<'static> {
450 fn trailers(self: Pin<&mut Self>) -> Option<Headers> {
451 let this = self.get_mut();
452 match &mut this.inner {
453 ResponseBodyInner::Received(rb) => Pin::new(rb).trailers(),
454 ResponseBodyInner::Override(o) => o.body.trailers(),
455 _ => this.trailers.take(),
458 }
459 }
460}
461
462impl<'a> From<ReceivedBody<'a, Box<dyn Transport>>> for ResponseBody<'a> {
463 fn from(received_body: ReceivedBody<'a, Box<dyn Transport>>) -> Self {
464 Self {
465 inner: ResponseBodyInner::Received(received_body),
466 cleanup: None,
467 trailers: None,
468 }
469 }
470}
471
472impl<'a> From<OverrideBody<'a>> for ResponseBody<'a> {
473 fn from(o: OverrideBody<'a>) -> Self {
474 Self {
475 inner: ResponseBodyInner::Override(o),
476 cleanup: None,
477 trailers: None,
478 }
479 }
480}
481
482impl ResponseBody<'static> {
483 pub(crate) fn received_owned(
484 body: ReceivedBody<'static, Box<dyn Transport>>,
485 cleanup: CleanupContext,
486 ) -> Self {
487 Self {
488 inner: ResponseBodyInner::Received(body),
489 cleanup: Some(cleanup),
490 trailers: None,
491 }
492 }
493
494 pub async fn recycle(mut self) {
505 let Some((rb, cleanup)) = self.prepare_for_recycle() else {
506 return;
507 };
508
509 recycle(rb, cleanup.h1_pool_origin).await;
510 }
511}
512
513impl<'a> IntoFuture for ResponseBody<'a> {
514 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;
515 type Output = trillium_http::Result<String>;
516
517 fn into_future(self) -> Self::IntoFuture {
518 Box::pin(async move { self.read_string().await })
519 }
520}
521
522const _: fn() = || {
523 fn assert_send_sync<T: Send + Sync + ?Sized>() {}
524 assert_send_sync::<ResponseBody<'static>>();
525 assert_send_sync::<ResponseBody<'_>>();
526};