Skip to main content

trillium_http/h3/
body_wrapper.rs

1use crate::{Body, Headers, body::BodyType, h3::Frame};
2use futures_lite::{AsyncRead, ready};
3use std::{io, pin::Pin, task::Poll};
4
5/// This is a temporary wrapper type that will eventually be integrated into Body's `AsyncRead`
6/// through a Version switch, but for now it's easier to keep it distinct
7#[derive(Debug)]
8pub struct H3Body {
9    body: BodyType,
10    /// Whether the DATA frame header has been written for known-length bodies.
11    /// Always false for unknown-length (each poll emits its own frame).
12    header_written: bool,
13}
14
15impl From<BodyType> for H3Body {
16    fn from(body: BodyType) -> Self {
17        Self {
18            body,
19            header_written: false,
20        }
21    }
22}
23
24impl H3Body {
25    pub(crate) fn new(body: Body) -> Self {
26        body.0.into()
27    }
28
29    /// Returns trailers from the body source, if any.
30    ///
31    /// Only meaningful after the body has been fully read (`done == true`).
32    pub fn trailers(&mut self) -> Option<Headers> {
33        match &mut self.body {
34            BodyType::Streaming {
35                async_read, done, ..
36            } if *done => async_read.get_mut().as_mut().trailers(),
37            _ => None,
38        }
39    }
40}
41
42impl AsyncRead for H3Body {
43    fn poll_read(
44        self: Pin<&mut Self>,
45        cx: &mut std::task::Context<'_>,
46        buf: &mut [u8],
47    ) -> Poll<io::Result<usize>> {
48        let this = self.get_mut();
49        match &mut this.body {
50            BodyType::Empty => Poll::Ready(Ok(0)),
51
52            BodyType::Static { content, cursor } => {
53                let remaining = content.len() - *cursor;
54                if remaining == 0 {
55                    return Poll::Ready(Ok(0));
56                }
57
58                let mut written = 0;
59                if !this.header_written {
60                    let frame = Frame::Data(remaining as u64);
61                    written += frame.encode(buf).ok_or_else(|| {
62                        io::Error::new(
63                            io::ErrorKind::WriteZero,
64                            "buffer too small for frame header",
65                        )
66                    })?;
67                    this.header_written = true;
68                }
69
70                let bytes = remaining.min(buf.len() - written);
71                buf[written..written + bytes].copy_from_slice(&content[*cursor..*cursor + bytes]);
72                *cursor += bytes;
73                Poll::Ready(Ok(written + bytes))
74            }
75
76            BodyType::Streaming {
77                async_read,
78                len: Some(len),
79                done,
80                progress,
81            } => {
82                if *done {
83                    return Poll::Ready(Ok(0));
84                }
85
86                let header_len = if this.header_written {
87                    0
88                } else {
89                    Frame::Data(*len).encoded_len()
90                };
91
92                let max_bytes = (*len - *progress)
93                    .try_into()
94                    .unwrap_or(buf.len() - header_len)
95                    .min(buf.len() - header_len);
96
97                let bytes = ready!(
98                    async_read
99                        .get_mut()
100                        .as_mut()
101                        .poll_read(cx, &mut buf[header_len..header_len + max_bytes])
102                )?;
103
104                if !this.header_written {
105                    Frame::Data(*len).encode(buf);
106                    this.header_written = true;
107                }
108
109                if bytes == 0 {
110                    *done = true;
111                } else {
112                    *progress += bytes as u64;
113                }
114
115                Poll::Ready(Ok(header_len + bytes))
116            }
117
118            BodyType::Streaming {
119                async_read,
120                len: None,
121                done,
122                progress,
123            } => {
124                if *done {
125                    return Poll::Ready(Ok(0));
126                }
127
128                let reserved = Frame::Data(buf.len() as u64).encoded_len();
129                if buf.len() <= reserved {
130                    return Poll::Ready(Err(io::Error::new(
131                        io::ErrorKind::WriteZero,
132                        "buffer too small for DATA frame",
133                    )));
134                }
135
136                let bytes = ready!(
137                    async_read
138                        .get_mut()
139                        .as_mut()
140                        .poll_read(cx, &mut buf[reserved..])
141                )?;
142
143                if bytes == 0 {
144                    *done = true;
145                    return Poll::Ready(Ok(0));
146                }
147
148                *progress += bytes as u64;
149
150                let frame = Frame::Data(bytes as u64);
151                let header_len = frame.encode(buf).unwrap();
152                if header_len < reserved {
153                    buf.copy_within(reserved..reserved + bytes, header_len);
154                }
155
156                Poll::Ready(Ok(header_len + bytes))
157            }
158        }
159    }
160}