1use super::{Body, Conn, Transport, TypeSet};
2use crate::{ClientHandler, ConnExt, Error, Result, Version};
3use smallvec::SmallVec;
4#[cfg(feature = "hickory")]
5use std::net::IpAddr;
6use std::{
7 borrow::Cow,
8 fmt::{self, Debug, Formatter},
9 future::{Future, IntoFuture},
10 mem,
11 net::SocketAddr,
12 pin::Pin,
13};
14use trillium_http::{ProtocolSession, Upgrade};
15use trillium_server_common::Destination;
16
/// An error that is either an HTTP-level [`Error`] or a JSON error from the
/// enabled JSON backend.
///
/// Only compiled when the `serde_json` or the `sonic-rs` feature is enabled.
/// Both backend variants are named `JsonError`, so enabling both features at
/// once would declare that variant twice.
#[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
#[derive(Debug, thiserror::Error)]
pub enum ClientSerdeError {
    /// An error from the HTTP layer, displayed transparently.
    #[error(transparent)]
    HttpError(#[from] Error),

    /// A JSON error from `serde_json`, displayed transparently.
    #[cfg(feature = "serde_json")]
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),

    /// A JSON error from `sonic_rs`, displayed transparently.
    #[cfg(feature = "sonic-rs")]
    #[error(transparent)]
    JsonError(#[from] sonic_rs::Error),
}
37
38impl Conn {
39 pub(crate) async fn exec(&mut self) -> Result<()> {
40 let handler = self.client.arc_handler().clone();
42 handler.run(self).await?;
43
44 if !self.halted {
45 if let Err(e) = self.exec_network().await {
48 self.error = Some(e);
49 }
50 } else {
51 log::trace!("conn is halted, skipping network round-trip");
52 }
53
54 handler.after_response(self).await?;
56
57 if let Some(e) = self.error.take() {
58 Err(e)
59 } else {
60 Ok(())
61 }
62 }
63
64 async fn exec_network(&mut self) -> Result<()> {
65 if self.http_version == Some(Version::Http0_9) {
66 return Err(Error::UnsupportedVersion(Version::Http0_9));
67 }
68
69 if self.try_reuse_h3_pool().await? {
75 return Ok(());
76 }
77 if self.try_exec_h2_pooled().await? {
78 return Ok(());
79 }
80
81 if self.try_establish_h3().await? {
84 return Ok(());
85 }
86
87 if self.http_version == Some(Version::Http2) {
91 return self.exec_h2_prior_knowledge().await;
92 }
93
94 self.exec_h1_or_promote_h2().await
95 }
96
97 pub(crate) fn body_len(&self) -> Option<u64> {
98 if let Some(ref body) = self.request_body {
99 body.len()
100 } else {
101 Some(0)
102 }
103 }
104
105 pub(crate) fn finalize_headers(&mut self) -> Result<()> {
106 match self.http_version() {
107 Version::Http1_0 | Version::Http1_1 => self.finalize_headers_h1(),
108 Version::Http2 => self.finalize_headers_h2(),
109 Version::Http3 if self.client.h3().is_some() => self.finalize_headers_h3(),
110 other => Err(Error::UnsupportedVersion(other)),
111 }
112 }
113
114 pub(crate) async fn origin_destination(&self) -> Result<Destination> {
123 let mut destination = Destination::from_url(&self.url)?;
124 let addrs = self.origin_socket_addrs().await?;
125 if !addrs.is_empty() {
126 destination.set_addrs(addrs);
127 }
128 match self.http_version {
129 Some(Version::Http1_0 | Version::Http1_1) => {
130 destination.set_alpn([Cow::Borrowed(b"http/1.1".as_slice())]);
131 }
132 Some(Version::Http2) => {
133 destination.set_alpn([Cow::Borrowed(b"h2".as_slice())]);
134 }
135 _ => {}
136 }
137 Ok(destination)
138 }
139
140 pub(crate) async fn origin_socket_addrs(&self) -> Result<SmallVec<[SocketAddr; 4]>> {
144 let Some(host) = self.url.host_str() else {
145 return Ok(SmallVec::new());
146 };
147 let port = self.url.port_or_known_default().unwrap_or(443);
148 self.resolve_socket_addrs(host, port).await
149 }
150}
151
#[cfg(feature = "hickory")]
impl Conn {
    /// Resolves `host` through the client's configured resolver.
    ///
    /// Returns `Ok(None)` when `host` is already an IP literal or when the
    /// client has no resolver configured.
    ///
    /// IP literals are recognized both bare (`127.0.0.1`, `::1`) and in the
    /// bracketed IPv6 form (`[::1]`). The bracketed form is how URL host
    /// strings render IPv6 addresses, and it does not parse as an [`IpAddr`]
    /// directly; without the bracket handling such hosts would be sent to the
    /// DNS resolver.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the resolver.
    pub(crate) async fn resolve(
        &self,
        host: &str,
        port: u16,
    ) -> Result<Option<crate::dns::Resolved>> {
        let bracketed = host
            .strip_prefix('[')
            .and_then(|rest| rest.strip_suffix(']'));
        let is_ip_literal = match bracketed {
            // Only IPv6 literals are written in brackets.
            Some(inner) => inner.parse::<std::net::Ipv6Addr>().is_ok(),
            None => host.parse::<IpAddr>().is_ok(),
        };
        if is_ip_literal {
            return Ok(None);
        }

        match &self.client.resolver {
            Some(resolver) => Ok(Some(
                resolver
                    .resolve(&self.client, host, port, self.timeout)
                    .await?,
            )),
            None => Ok(None),
        }
    }

    /// Resolves `host` and converts the result to socket addresses on `port`.
    ///
    /// Returns an empty list whenever [`Conn::resolve`] yields `None`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Conn::resolve`].
    pub(crate) async fn resolve_socket_addrs(
        &self,
        host: &str,
        port: u16,
    ) -> Result<SmallVec<[SocketAddr; 4]>> {
        Ok(self
            .resolve(host, port)
            .await?
            .map(|resolved| resolved.socket_addrs(port))
            .unwrap_or_default())
    }
}
196
197#[cfg(not(feature = "hickory"))]
198impl Conn {
199 pub(crate) async fn resolve_socket_addrs(
200 &self,
201 _host: &str,
202 _port: u16,
203 ) -> Result<SmallVec<[SocketAddr; 4]>> {
204 Ok(SmallVec::new())
205 }
206}
207
208impl Drop for Conn {
209 fn drop(&mut self) {
210 log::trace!("dropping client conn");
211 drop(self.take_response_body());
212 }
213}
214
215impl From<Conn> for Body {
216 fn from(mut conn: Conn) -> Body {
217 if let Some(body) = conn.body_override.take() {
220 return body;
221 }
222
223 match conn.take_received_body(true) {
224 Some(rb) => rb.into(),
225 None => Body::default(),
226 }
227 }
228}
229
230impl From<Conn> for Upgrade<Box<dyn Transport>> {
231 fn from(mut conn: Conn) -> Self {
240 let path = conn.path.take().unwrap_or_else(|| match conn.url.query() {
243 Some(q) => Cow::Owned(format!("{}?{q}", conn.url.path())),
244 None => Cow::Owned(conn.url.path().to_owned()),
245 });
246 let secure = conn.url.scheme() == "https";
247
248 Upgrade::from_parts(
249 mem::take(&mut conn.response_headers),
250 mem::take(&mut conn.request_headers),
251 path,
252 conn.method,
253 conn.transport
254 .take()
255 .expect("client conn has no transport — request not yet sent"),
256 mem::take(&mut conn.buffer),
257 mem::take(&mut conn.state),
258 conn.context.clone(),
259 None,
260 conn.authority.take(),
261 conn.scheme.take(),
262 mem::replace(&mut conn.protocol_session, ProtocolSession::Http1),
263 conn.protocol.take(),
264 conn.http_version(),
265 conn.status,
266 secure,
267 mem::take(&mut conn.response_body_state),
269 conn.response_trailers.take(),
272 )
273 }
274}
275
276impl IntoFuture for Conn {
277 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'static>>;
278 type Output = Result<Conn>;
279
280 fn into_future(mut self) -> Self::IntoFuture {
281 Box::pin(async move { (&mut self).await.map(|()| self) })
282 }
283}
284
285impl<'conn> IntoFuture for &'conn mut Conn {
286 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'conn>>;
287 type Output = Result<()>;
288
289 fn into_future(self) -> Self::IntoFuture {
290 Box::pin(async move {
291 loop {
294 let result = if let Some(duration) = self.timeout {
295 self.client
296 .connector()
297 .runtime()
298 .timeout(duration, self.exec())
299 .await
300 .unwrap_or(Err(Error::TimedOut("Conn", duration)))
301 } else {
302 self.exec().await
303 };
304
305 self.halted = false;
307
308 if let Err(e) = result {
309 self.followup = None;
312 return Err(e);
313 }
314
315 let Some(next) = self.take_followup() else {
316 break;
317 };
318
319 if let Some(body) = self.take_response_body() {
320 body.recycle().await;
321 }
322
323 let _displaced = mem::replace(self, next);
324 }
325 Ok(())
326 })
327 }
328}
329
330impl Debug for Conn {
331 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
332 f.debug_struct("Conn")
333 .field("authority", &self.authority)
334 .field("buffer", &String::from_utf8_lossy(&self.buffer))
335 .field("client", &self.client)
336 .field("protocol_session", &self.protocol_session)
337 .field("http_version", &self.http_version)
338 .field("method", &self.method)
339 .field("path", &self.path)
340 .field("request_body", &self.request_body)
341 .field("request_headers", &self.request_headers)
342 .field("request_target", &self.request_target)
343 .field("request_trailers", &self.request_trailers)
344 .field("response_body_state", &self.response_body_state)
345 .field("response_headers", &self.response_headers)
346 .field("response_trailers", &self.response_trailers)
347 .field("scheme", &self.scheme)
348 .field("state", &self.state)
349 .field("status", &self.status)
350 .field("url", &self.url)
351 .finish()
352 }
353}
354
355impl AsRef<TypeSet> for Conn {
356 fn as_ref(&self) -> &TypeSet {
357 &self.state
358 }
359}
360
361impl AsMut<TypeSet> for Conn {
362 fn as_mut(&mut self) -> &mut TypeSet {
363 &mut self.state
364 }
365}