Skip to main content

trillium_http/
upgrade.rs

1use crate::{
2    Buffer, Conn, Headers, HttpContext, Method, TypeSet, Version, h3::H3Connection,
3    received_body::read_buffered,
4};
5use fieldwork::Fieldwork;
6use futures_lite::{AsyncRead, AsyncWrite};
7use std::{
8    borrow::Cow,
9    fmt::{self, Debug, Formatter},
10    io,
11    net::IpAddr,
12    pin::Pin,
13    str,
14    sync::Arc,
15    task::{self, Poll},
16};
17use trillium_macros::AsyncWrite;
18
/// An HTTP upgrade: request data carried over from a [`Conn`], together with ownership of the
/// underlying transport.
///
/// The `From<Conn<Transport>>` implementation in this file shows exactly which `Conn` fields are
/// carried over.
///
/// **Important implementation note**: When reading directly from the transport, ensure that you
/// read from `buffer` first if there are bytes in it. Alternatively, read directly from the
/// Upgrade, as that [`AsyncRead`] implementation will drain the buffer first before reading from
/// the transport.
// NOTE(review): field order is also drop order; reorder fields only deliberately.
#[derive(AsyncWrite, Fieldwork)]
#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
pub struct Upgrade<Transport> {
    /// The http request headers
    request_headers: Headers,

    /// The request path
    // The derived getter is switched off here; the hand-written `Upgrade::path` strips the
    // query component instead.
    #[field(get = false)]
    path: Cow<'static, str>,

    /// The http request method
    #[field(copy)]
    method: Method,

    /// Any state that has been accumulated on the Conn before negotiating the upgrade
    state: TypeSet,

    /// The underlying io (often a `TcpStream` or similar)
    // NOTE(review): `#[async_write]` presumably selects this field as the target of the derived
    // `AsyncWrite` impl -- confirm against `trillium_macros`.
    #[async_write]
    transport: Transport,

    /// Any bytes that have been read from the underlying transport already.
    ///
    /// It is your responsibility to process these bytes before reading directly from the
    /// transport.
    #[field(deref = "[u8]", into_field = false, set = false, with = false)]
    buffer: Buffer,

    /// The [`HttpContext`] shared for this server
    #[field(deref = false)]
    context: Arc<HttpContext>,

    /// The ip address of the connection, if available
    #[field(copy)]
    peer_ip: Option<IpAddr>,

    /// The `:authority` http/3 pseudo-header
    authority: Option<Cow<'static, str>>,

    /// The `:scheme` http/3 pseudo-header
    scheme: Option<Cow<'static, str>>,

    /// The HTTP/3 connection associated with this upgrade, if this was an HTTP/3 connection
    #[field(get(deref = false))]
    h3_connection: Option<Arc<H3Connection>>,

    /// The `:protocol` http/3 pseudo-header
    protocol: Option<Cow<'static, str>>,

    /// The http version
    // NOTE(review): presumably renames the generated accessors to `http_version` -- confirm
    // against the `fieldwork` documentation.
    #[field = "http_version"]
    version: Version,

    /// Whether this connection was deemed secure by the handler stack
    secure: bool,
}
82
83impl<Transport> Upgrade<Transport> {
84    #[doc(hidden)]
85    pub fn new(
86        request_headers: Headers,
87        path: impl Into<Cow<'static, str>>,
88        method: Method,
89        transport: Transport,
90        buffer: Buffer,
91        version: Version,
92    ) -> Self {
93        Self {
94            request_headers,
95            path: path.into(),
96            method,
97            transport,
98            buffer,
99            state: TypeSet::new(),
100            context: Arc::default(),
101            peer_ip: None,
102            authority: None,
103            scheme: None,
104            h3_connection: None,
105            protocol: None,
106            secure: false,
107            version,
108        }
109    }
110
111    /// Take any buffered bytes
112    pub fn take_buffer(&mut self) -> Vec<u8> {
113        std::mem::take(&mut self.buffer).into()
114    }
115
116    #[doc(hidden)]
117    pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
118        (&mut self.buffer, &mut self.transport)
119    }
120
121    /// borrow the shared state [`TypeSet`] for this application
122    pub fn shared_state(&self) -> &TypeSet {
123        self.context.shared_state()
124    }
125
126    /// the http request path up to but excluding any query component
127    pub fn path(&self) -> &str {
128        match self.path.split_once('?') {
129            Some((path, _)) => path,
130            None => &self.path,
131        }
132    }
133
134    /// retrieves the query component of the path
135    pub fn querystring(&self) -> &str {
136        self.path
137            .split_once('?')
138            .map(|(_, query)| query)
139            .unwrap_or_default()
140    }
141
142    /// Modify the transport type of this upgrade.
143    ///
144    /// This is useful for boxing the transport in order to erase the type argument.
145    pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
146        self,
147        f: impl Fn(Transport) -> T,
148    ) -> Upgrade<T> {
149        Upgrade {
150            transport: f(self.transport),
151            path: self.path,
152            method: self.method,
153            state: self.state,
154            buffer: self.buffer,
155            request_headers: self.request_headers,
156            context: self.context,
157            peer_ip: self.peer_ip,
158            authority: self.authority,
159            scheme: self.scheme,
160            h3_connection: self.h3_connection,
161            protocol: self.protocol,
162            version: self.version,
163            secure: self.secure,
164        }
165    }
166}
167
168impl<Transport> Debug for Upgrade<Transport> {
169    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
170        f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
171            .field("request_headers", &self.request_headers)
172            .field("path", &self.path)
173            .field("method", &self.method)
174            .field("buffer", &self.buffer)
175            .field("context", &self.context)
176            .field("state", &self.state)
177            .field("transport", &format_args!(".."))
178            .field("peer_ip", &self.peer_ip)
179            .field("authority", &self.authority)
180            .field("scheme", &self.scheme)
181            .field("h3_connection", &self.h3_connection)
182            .field("protocol", &self.protocol)
183            .field("version", &self.version)
184            .field("secure", &self.secure)
185            .finish()
186    }
187}
188
189impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
190    fn from(conn: Conn<Transport>) -> Self {
191        let Conn {
192            request_headers,
193            path,
194            method,
195            state,
196            transport,
197            buffer,
198            context,
199            peer_ip,
200            authority,
201            scheme,
202            h3_connection,
203            protocol,
204            version,
205            secure,
206            ..
207        } = conn;
208
209        Self {
210            request_headers,
211            path,
212            method,
213            state,
214            transport,
215            buffer,
216            context,
217            peer_ip,
218            authority,
219            scheme,
220            h3_connection,
221            protocol,
222            version,
223            secure,
224        }
225    }
226}
227
228impl<Transport: AsyncRead + Unpin> AsyncRead for Upgrade<Transport> {
229    fn poll_read(
230        mut self: Pin<&mut Self>,
231        cx: &mut task::Context<'_>,
232        buf: &mut [u8],
233    ) -> Poll<io::Result<usize>> {
234        let Self {
235            transport, buffer, ..
236        } = &mut *self;
237        read_buffered(buffer, transport, cx, buf)
238    }
239}