1use crate::{
2 Buffer, Conn, Headers, HttpContext, Method, TypeSet, Version, h3::H3Connection,
3 received_body::read_buffered,
4};
5use fieldwork::Fieldwork;
6use futures_lite::{AsyncRead, AsyncWrite};
7use std::{
8 borrow::Cow,
9 fmt::{self, Debug, Formatter},
10 io,
11 net::IpAddr,
12 pin::Pin,
13 str,
14 sync::Arc,
15 task::{self, Poll},
16};
17use trillium_macros::AsyncWrite;
18
/// A connection taken over from a [`Conn`] (for example to hand it to another protocol),
/// holding the originating request's metadata together with the raw transport and a buffer
/// of bytes associated with it.
///
/// Reading (the [`AsyncRead`] impl) goes through `read_buffered` with the `buffer` and
/// `transport` fields. Writing comes from `#[derive(AsyncWrite)]`, which targets the field
/// marked `#[async_write]` (presumably forwarding straight to the transport — confirm against
/// `trillium_macros`).
///
/// Field accessors are generated by [`Fieldwork`] according to the `#[fieldwork(...)]`
/// defaults below and the per-field `#[field(...)]` overrides.
#[derive(AsyncWrite, Fieldwork)]
#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
pub struct Upgrade<Transport> {
    /// Headers of the request this value was created from.
    request_headers: Headers,

    /// The full request target, including any `?querystring` suffix.
    ///
    /// No generated getter (`get = false`). The hand-written `path()` returns the part
    /// before the first `?`, and `querystring()` returns the part after it.
    #[field(get = false)]
    path: Cow<'static, str>,

    /// The request method (`copy`: accessors presumably hand it out by value).
    #[field(copy)]
    method: Method,

    /// Per-connection state, carried over from the originating [`Conn`] by `From<Conn<_>>`.
    state: TypeSet,

    /// The underlying transport. Marked `#[async_write]` so the derived `AsyncWrite` impl
    /// uses this field.
    #[async_write]
    transport: Transport,

    /// Bytes handed to `read_buffered` alongside `transport` on every `poll_read`
    /// (presumably data already read from the transport but not yet consumed — confirm in
    /// `received_body::read_buffered`).
    ///
    /// Accessors deref to `[u8]`. No setter, `with_*` or `into_*` accessor is generated for
    /// this field; `take_buffer()` moves the bytes out.
    #[field(deref = "[u8]", into_field = false, set = false, with = false)]
    buffer: Buffer,

    /// Shared HTTP context; `shared_state()` reads from it. (`deref = false`: the getter
    /// presumably yields the `Arc` itself rather than `&HttpContext`.)
    #[field(deref = false)]
    context: Arc<HttpContext>,

    /// IP address of the peer, if known.
    #[field(copy)]
    peer_ip: Option<IpAddr>,

    /// Request authority, if one was recorded.
    authority: Option<Cow<'static, str>>,

    /// Request scheme, if one was recorded.
    scheme: Option<Cow<'static, str>>,

    /// The HTTP/3 connection associated with this request, if any. (The getter has
    /// `deref = false`, presumably returning the `Option<Arc<_>>` form.)
    #[field(get(deref = false))]
    h3_connection: Option<Arc<H3Connection>>,

    /// The `protocol` value carried over from the originating [`Conn`], if any.
    protocol: Option<Cow<'static, str>>,

    /// HTTP version of the originating request. Its accessors are generated under the name
    /// `http_version` instead of `version`.
    #[field = "http_version"]
    version: Version,

    /// Whether the originating connection was marked secure. (With `rename_predicates`, the
    /// bool getter is presumably exposed as `is_secure` — confirm against fieldwork.)
    secure: bool,
}
82
83impl<Transport> Upgrade<Transport> {
84 #[doc(hidden)]
85 pub fn new(
86 request_headers: Headers,
87 path: impl Into<Cow<'static, str>>,
88 method: Method,
89 transport: Transport,
90 buffer: Buffer,
91 version: Version,
92 ) -> Self {
93 Self {
94 request_headers,
95 path: path.into(),
96 method,
97 transport,
98 buffer,
99 state: TypeSet::new(),
100 context: Arc::default(),
101 peer_ip: None,
102 authority: None,
103 scheme: None,
104 h3_connection: None,
105 protocol: None,
106 secure: false,
107 version,
108 }
109 }
110
111 pub fn take_buffer(&mut self) -> Vec<u8> {
113 std::mem::take(&mut self.buffer).into()
114 }
115
116 #[doc(hidden)]
117 pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
118 (&mut self.buffer, &mut self.transport)
119 }
120
121 pub fn shared_state(&self) -> &TypeSet {
123 self.context.shared_state()
124 }
125
126 pub fn path(&self) -> &str {
128 match self.path.split_once('?') {
129 Some((path, _)) => path,
130 None => &self.path,
131 }
132 }
133
134 pub fn querystring(&self) -> &str {
136 self.path
137 .split_once('?')
138 .map(|(_, query)| query)
139 .unwrap_or_default()
140 }
141
142 pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
146 self,
147 f: impl Fn(Transport) -> T,
148 ) -> Upgrade<T> {
149 Upgrade {
150 transport: f(self.transport),
151 path: self.path,
152 method: self.method,
153 state: self.state,
154 buffer: self.buffer,
155 request_headers: self.request_headers,
156 context: self.context,
157 peer_ip: self.peer_ip,
158 authority: self.authority,
159 scheme: self.scheme,
160 h3_connection: self.h3_connection,
161 protocol: self.protocol,
162 version: self.version,
163 secure: self.secure,
164 }
165 }
166}
167
168impl<Transport> Debug for Upgrade<Transport> {
169 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
170 f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
171 .field("request_headers", &self.request_headers)
172 .field("path", &self.path)
173 .field("method", &self.method)
174 .field("buffer", &self.buffer)
175 .field("context", &self.context)
176 .field("state", &self.state)
177 .field("transport", &format_args!(".."))
178 .field("peer_ip", &self.peer_ip)
179 .field("authority", &self.authority)
180 .field("scheme", &self.scheme)
181 .field("h3_connection", &self.h3_connection)
182 .field("protocol", &self.protocol)
183 .field("version", &self.version)
184 .field("secure", &self.secure)
185 .finish()
186 }
187}
188
189impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
190 fn from(conn: Conn<Transport>) -> Self {
191 let Conn {
192 request_headers,
193 path,
194 method,
195 state,
196 transport,
197 buffer,
198 context,
199 peer_ip,
200 authority,
201 scheme,
202 h3_connection,
203 protocol,
204 version,
205 secure,
206 ..
207 } = conn;
208
209 Self {
210 request_headers,
211 path,
212 method,
213 state,
214 transport,
215 buffer,
216 context,
217 peer_ip,
218 authority,
219 scheme,
220 h3_connection,
221 protocol,
222 version,
223 secure,
224 }
225 }
226}
227
228impl<Transport: AsyncRead + Unpin> AsyncRead for Upgrade<Transport> {
229 fn poll_read(
230 mut self: Pin<&mut Self>,
231 cx: &mut task::Context<'_>,
232 buf: &mut [u8],
233 ) -> Poll<io::Result<usize>> {
234 let Self {
235 transport, buffer, ..
236 } = &mut *self;
237 read_buffered(buffer, transport, cx, buf)
238 }
239}