Skip to main content

trillium_http/
upgrade.rs

1use crate::{
2    Buffer, Conn, Headers, HttpContext, Method, ProtocolSession, Status, TypeSet, Version,
3    h2::H2Connection, h3::H3Connection, received_body::read_buffered,
4};
5use fieldwork::Fieldwork;
6use futures_lite::{AsyncRead, AsyncWrite};
7use std::{
8    borrow::Cow,
9    fmt::{self, Debug, Formatter},
10    io,
11    net::IpAddr,
12    pin::Pin,
13    str,
14    sync::Arc,
15    task::{self, Poll},
16    time::Instant,
17};
18use trillium_macros::AsyncWrite;
19
20/// An HTTP upgrade — owns the underlying transport along with all the data from the
21/// originating [`Conn`].
22///
23/// **Reading the transport directly**: drain `buffer` first if it has bytes in it. Reading
24/// via the [`AsyncRead`] impl on `Upgrade` handles this automatically.
25#[derive(AsyncWrite, Fieldwork)]
26#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
27pub struct Upgrade<Transport> {
28    /// The http request headers
29    request_headers: Headers,
30
31    /// The http response headers as set on the underlying [`Conn`] before the upgrade was
32    /// negotiated. These have already been sent to the peer; preserved here for inspection.
33    response_headers: Headers,
34
35    /// The request path
36    #[field(get = false)]
37    path: Cow<'static, str>,
38
39    /// The http request method
40    #[field(copy)]
41    method: Method,
42
43    /// Any state that has been accumulated on the Conn before negotiating the upgrade
44    state: TypeSet,
45
46    /// The underlying io (often a `TcpStream` or similar)
47    #[async_write]
48    transport: Transport,
49
50    /// Any bytes that have been read from the underlying transport already.
51    ///
52    /// It is your responsibility to process these bytes before reading directly from the
53    /// transport.
54    #[field(deref = "[u8]", into_field = false, set = false, with = false)]
55    buffer: Buffer,
56
57    /// The [`HttpContext`] shared for this server
58    #[field(deref = false)]
59    context: Arc<HttpContext>,
60
61    /// the ip address of the connection, if available
62    #[field(copy)]
63    peer_ip: Option<IpAddr>,
64
65    /// the wall-clock time at which the underlying [`Conn`] was constructed. Useful for
66    /// instrumentation that wants elapsed time across the upgrade transition.
67    #[field(copy)]
68    start_time: Instant,
69
70    /// the :authority http/3 pseudo-header
71    authority: Option<Cow<'static, str>>,
72
73    /// the :scheme http/3 pseudo-header
74    scheme: Option<Cow<'static, str>>,
75
76    /// the [`ProtocolSession`] for this upgrade — bundles the per-protocol session state
77    /// (h2/h3 connection driver and stream id) that was attached to the originating Conn.
78    /// `Http1` for upgrades from h1 / synthetic conns.
79    #[field = false]
80    protocol_session: ProtocolSession,
81
82    /// the :protocol http/3 pseudo-header
83    protocol: Option<Cow<'static, str>>,
84
85    /// the http version
86    #[field = "http_version"]
87    version: Version,
88
89    /// the http response status set on the underlying [`Conn`] at the time the upgrade was
90    /// negotiated (typically `101 Switching Protocols` or `200 OK` for CONNECT). `None` if no
91    /// status was set explicitly.
92    #[field(copy)]
93    status: Option<Status>,
94
95    /// whether this connection was deemed secure by the handler stack
96    secure: bool,
97}
98
99impl<Transport> Upgrade<Transport> {
100    #[doc(hidden)]
101    pub fn new(
102        request_headers: Headers,
103        path: impl Into<Cow<'static, str>>,
104        method: Method,
105        transport: Transport,
106        buffer: Buffer,
107        version: Version,
108    ) -> Self {
109        Self {
110            request_headers,
111            response_headers: Headers::new(),
112            path: path.into(),
113            method,
114            transport,
115            buffer,
116            state: TypeSet::new(),
117            context: Arc::default(),
118            peer_ip: None,
119            start_time: Instant::now(),
120            authority: None,
121            scheme: None,
122            protocol_session: ProtocolSession::Http1,
123            protocol: None,
124            secure: false,
125            version,
126            status: None,
127        }
128    }
129
130    /// the [`H2Connection`] driver for this upgrade, if it originated from an HTTP/2 stream
131    pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
132        self.protocol_session.h2_connection()
133    }
134
135    /// the h2 stream id for this upgrade, if it originated from an HTTP/2 stream
136    pub fn h2_stream_id(&self) -> Option<u32> {
137        self.protocol_session.h2_stream_id()
138    }
139
140    /// the [`H3Connection`] driver for this upgrade, if it originated from an HTTP/3 stream
141    pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
142        self.protocol_session.h3_connection()
143    }
144
145    /// the h3 stream id for this upgrade, if it originated from an HTTP/3 stream
146    pub fn h3_stream_id(&self) -> Option<u64> {
147        self.protocol_session.h3_stream_id()
148    }
149
150    /// Take any buffered bytes
151    pub fn take_buffer(&mut self) -> Vec<u8> {
152        std::mem::take(&mut self.buffer).into()
153    }
154
155    #[doc(hidden)]
156    pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
157        (&mut self.buffer, &mut self.transport)
158    }
159
160    /// borrow the shared state [`TypeSet`] for this application
161    pub fn shared_state(&self) -> &TypeSet {
162        self.context.shared_state()
163    }
164
165    /// the http request path up to but excluding any query component
166    pub fn path(&self) -> &str {
167        match self.path.split_once('?') {
168            Some((path, _)) => path,
169            None => &self.path,
170        }
171    }
172
173    /// retrieves the query component of the path
174    pub fn querystring(&self) -> &str {
175        self.path
176            .split_once('?')
177            .map(|(_, query)| query)
178            .unwrap_or_default()
179    }
180
181    /// Modify the transport type of this upgrade.
182    ///
183    /// This is useful for boxing the transport in order to erase the type argument.
184    pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
185        self,
186        f: impl Fn(Transport) -> T,
187    ) -> Upgrade<T> {
188        // Manual respread: rustc treats `Upgrade<Transport>` and `Upgrade<T>` as disjoint
189        // and rejects `..self` without the unstable `type_changing_struct_update` feature.
190        // If a new field is added to `Upgrade`, update this respread, `Conn::map_transport`
191        // (`conn.rs`), and `From<Conn> for Upgrade` below — they share this drift hazard.
192        Upgrade {
193            transport: f(self.transport),
194            path: self.path,
195            method: self.method,
196            state: self.state,
197            buffer: self.buffer,
198            request_headers: self.request_headers,
199            response_headers: self.response_headers,
200            context: self.context,
201            peer_ip: self.peer_ip,
202            start_time: self.start_time,
203            authority: self.authority,
204            scheme: self.scheme,
205            protocol_session: self.protocol_session,
206            protocol: self.protocol,
207            version: self.version,
208            status: self.status,
209            secure: self.secure,
210        }
211    }
212}
213
214impl<Transport> Debug for Upgrade<Transport> {
215    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
216        f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
217            .field("request_headers", &self.request_headers)
218            .field("response_headers", &self.response_headers)
219            .field("path", &self.path)
220            .field("method", &self.method)
221            .field("buffer", &self.buffer)
222            .field("context", &self.context)
223            .field("state", &self.state)
224            .field("transport", &format_args!(".."))
225            .field("peer_ip", &self.peer_ip)
226            .field("start_time", &self.start_time)
227            .field("authority", &self.authority)
228            .field("scheme", &self.scheme)
229            .field("protocol_session", &self.protocol_session)
230            .field("protocol", &self.protocol)
231            .field("version", &self.version)
232            .field("status", &self.status)
233            .field("secure", &self.secure)
234            .finish()
235    }
236}
237
238impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
239    fn from(conn: Conn<Transport>) -> Self {
240        // Exhaustive destructure (no `..` rest pattern) so that adding a new field to
241        // `Conn` is a compile error here, forcing a deliberate carry-vs-drop decision
242        // for the upgrade transition. The discarded fields below are response-body /
243        // request-body / instrumentation state that is meaningless once the conn has
244        // crossed into the upgrade phase. This shares a drift hazard with
245        // `Conn::map_transport` (`conn.rs`) and `Upgrade::map_transport` above.
246        let Conn {
247            request_headers,
248            response_headers,
249            path,
250            method,
251            state,
252            transport,
253            buffer,
254            context,
255            peer_ip,
256            start_time,
257            authority,
258            scheme,
259            protocol_session,
260            protocol,
261            version,
262            status,
263            secure,
264            // Deliberately dropped — response-body / request-body lifecycle state with
265            // no role on the upgraded transport.
266            response_body: _,
267            request_body_state: _,
268            after_send: _,
269            request_trailers: _,
270        } = conn;
271
272        Self {
273            request_headers,
274            response_headers,
275            path,
276            method,
277            state,
278            transport,
279            buffer,
280            context,
281            peer_ip,
282            start_time,
283            authority,
284            scheme,
285            protocol_session,
286            protocol,
287            version,
288            status,
289            secure,
290        }
291    }
292}
293
294impl<Transport: AsyncRead + Unpin> AsyncRead for Upgrade<Transport> {
295    fn poll_read(
296        mut self: Pin<&mut Self>,
297        cx: &mut task::Context<'_>,
298        buf: &mut [u8],
299    ) -> Poll<io::Result<usize>> {
300        let Self {
301            transport, buffer, ..
302        } = &mut *self;
303        read_buffered(buffer, transport, cx, buf)
304    }
305}