1use crate::{
2 Buffer, Conn, Headers, HttpContext, Method, ProtocolSession, Status, TypeSet, Version,
3 h2::H2Connection, h3::H3Connection, received_body::read_buffered,
4};
5use fieldwork::Fieldwork;
6use futures_lite::{AsyncRead, AsyncWrite};
7use std::{
8 borrow::Cow,
9 fmt::{self, Debug, Formatter},
10 io,
11 net::IpAddr,
12 pin::Pin,
13 str,
14 sync::Arc,
15 task::{self, Poll},
16 time::Instant,
17};
18use trillium_macros::AsyncWrite;
19
/// The request metadata and underlying transport of an HTTP connection, detached from
/// request/response body handling.
///
/// An `Upgrade` can be built from a [`Conn`] via `From`. That conversion keeps every field
/// below and discards the conn's response body, request body state, `after_send`, and
/// request trailers.
///
/// Reads go through the `AsyncRead` impl (see `buffer`). Writes come from the derived
/// `AsyncWrite`, which presumably forwards to the `#[async_write]` field (`transport`).
///
/// Accessors are generated by `Fieldwork`, as configured by the `#[fieldwork(...)]` attribute
/// and the per-field `#[field(...)]` overrides.
#[derive(AsyncWrite, Fieldwork)]
#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
pub struct Upgrade<Transport> {
    /// Request headers carried over from the originating connection.
    request_headers: Headers,

    /// Response headers carried over from the originating connection (`Upgrade::new`
    /// starts them as `Headers::new()`).
    response_headers: Headers,

    /// The raw request target, including any `?query` suffix.
    ///
    /// The generated getter is disabled (`get = false`) because the hand-written `path()`
    /// returns only the part before the first `?`. `querystring()` returns the part after it.
    #[field(get = false)]
    path: Cow<'static, str>,

    /// The request method.
    #[field(copy)]
    method: Method,

    /// Per-connection `TypeSet`, carried over from the originating `Conn`.
    /// `Upgrade::new` starts it as `TypeSet::new()`.
    state: TypeSet,

    /// The underlying I/O transport. `#[async_write]` marks it as the field the derived
    /// `AsyncWrite` implementation writes to (presumably by delegation).
    #[async_write]
    transport: Transport,

    /// Bytes held alongside the transport.
    ///
    /// The `AsyncRead` impl passes this buffer and `transport` to `read_buffered`, presumably
    /// so that buffered bytes are returned before the transport is read. The generated getter
    /// presumably derefs to `&[u8]`. The `set`, `with`, and `into_field` accessors are
    /// disabled; `take_buffer` moves the bytes out.
    #[field(deref = "[u8]", into_field = false, set = false, with = false)]
    buffer: Buffer,

    /// Shared HTTP context; `shared_state()` reads from it. `Upgrade::new` uses
    /// `Arc::default()`.
    #[field(deref = false)]
    context: Arc<HttpContext>,

    /// Peer IP address, if known (`None` from `Upgrade::new`).
    #[field(copy)]
    peer_ip: Option<IpAddr>,

    /// When the connection started. `Upgrade::new` records `Instant::now()`; conversion
    /// from a `Conn` keeps the conn's own start time.
    #[field(copy)]
    start_time: Instant,

    /// Request authority, if any (`None` from `Upgrade::new`).
    authority: Option<Cow<'static, str>>,

    /// Request scheme, if any (`None` from `Upgrade::new`).
    scheme: Option<Cow<'static, str>>,

    /// HTTP/1, HTTP/2, or HTTP/3 session details. No accessors are generated
    /// (`#[field = false]`); the hand-written `h2_*` / `h3_*` methods delegate to it.
    #[field = false]
    protocol_session: ProtocolSession,

    /// Optional protocol identifier carried over from the originating `Conn` (`None` from
    /// `Upgrade::new`).
    protocol: Option<Cow<'static, str>>,

    /// HTTP version of the request. Generated accessors use the name `http_version`.
    #[field = "http_version"]
    version: Version,

    /// Response status carried over from the originating `Conn`, if one was set.
    #[field(copy)]
    status: Option<Status>,

    /// Whether the connection is marked secure (`false` from `Upgrade::new`).
    secure: bool,
}
98
99impl<Transport> Upgrade<Transport> {
100 #[doc(hidden)]
101 pub fn new(
102 request_headers: Headers,
103 path: impl Into<Cow<'static, str>>,
104 method: Method,
105 transport: Transport,
106 buffer: Buffer,
107 version: Version,
108 ) -> Self {
109 Self {
110 request_headers,
111 response_headers: Headers::new(),
112 path: path.into(),
113 method,
114 transport,
115 buffer,
116 state: TypeSet::new(),
117 context: Arc::default(),
118 peer_ip: None,
119 start_time: Instant::now(),
120 authority: None,
121 scheme: None,
122 protocol_session: ProtocolSession::Http1,
123 protocol: None,
124 secure: false,
125 version,
126 status: None,
127 }
128 }
129
130 pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
132 self.protocol_session.h2_connection()
133 }
134
135 pub fn h2_stream_id(&self) -> Option<u32> {
137 self.protocol_session.h2_stream_id()
138 }
139
140 pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
142 self.protocol_session.h3_connection()
143 }
144
145 pub fn h3_stream_id(&self) -> Option<u64> {
147 self.protocol_session.h3_stream_id()
148 }
149
150 pub fn take_buffer(&mut self) -> Vec<u8> {
152 std::mem::take(&mut self.buffer).into()
153 }
154
155 #[doc(hidden)]
156 pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
157 (&mut self.buffer, &mut self.transport)
158 }
159
160 pub fn shared_state(&self) -> &TypeSet {
162 self.context.shared_state()
163 }
164
165 pub fn path(&self) -> &str {
167 match self.path.split_once('?') {
168 Some((path, _)) => path,
169 None => &self.path,
170 }
171 }
172
173 pub fn querystring(&self) -> &str {
175 self.path
176 .split_once('?')
177 .map(|(_, query)| query)
178 .unwrap_or_default()
179 }
180
181 pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
185 self,
186 f: impl Fn(Transport) -> T,
187 ) -> Upgrade<T> {
188 Upgrade {
193 transport: f(self.transport),
194 path: self.path,
195 method: self.method,
196 state: self.state,
197 buffer: self.buffer,
198 request_headers: self.request_headers,
199 response_headers: self.response_headers,
200 context: self.context,
201 peer_ip: self.peer_ip,
202 start_time: self.start_time,
203 authority: self.authority,
204 scheme: self.scheme,
205 protocol_session: self.protocol_session,
206 protocol: self.protocol,
207 version: self.version,
208 status: self.status,
209 secure: self.secure,
210 }
211 }
212}
213
214impl<Transport> Debug for Upgrade<Transport> {
215 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
216 f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
217 .field("request_headers", &self.request_headers)
218 .field("response_headers", &self.response_headers)
219 .field("path", &self.path)
220 .field("method", &self.method)
221 .field("buffer", &self.buffer)
222 .field("context", &self.context)
223 .field("state", &self.state)
224 .field("transport", &format_args!(".."))
225 .field("peer_ip", &self.peer_ip)
226 .field("start_time", &self.start_time)
227 .field("authority", &self.authority)
228 .field("scheme", &self.scheme)
229 .field("protocol_session", &self.protocol_session)
230 .field("protocol", &self.protocol)
231 .field("version", &self.version)
232 .field("status", &self.status)
233 .field("secure", &self.secure)
234 .finish()
235 }
236}
237
238impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
239 fn from(conn: Conn<Transport>) -> Self {
240 let Conn {
247 request_headers,
248 response_headers,
249 path,
250 method,
251 state,
252 transport,
253 buffer,
254 context,
255 peer_ip,
256 start_time,
257 authority,
258 scheme,
259 protocol_session,
260 protocol,
261 version,
262 status,
263 secure,
264 response_body: _,
267 request_body_state: _,
268 after_send: _,
269 request_trailers: _,
270 } = conn;
271
272 Self {
273 request_headers,
274 response_headers,
275 path,
276 method,
277 state,
278 transport,
279 buffer,
280 context,
281 peer_ip,
282 start_time,
283 authority,
284 scheme,
285 protocol_session,
286 protocol,
287 version,
288 status,
289 secure,
290 }
291 }
292}
293
294impl<Transport: AsyncRead + Unpin> AsyncRead for Upgrade<Transport> {
295 fn poll_read(
296 mut self: Pin<&mut Self>,
297 cx: &mut task::Context<'_>,
298 buf: &mut [u8],
299 ) -> Poll<io::Result<usize>> {
300 let Self {
301 transport, buffer, ..
302 } = &mut *self;
303 read_buffered(buffer, transport, cx, buf)
304 }
305}