Skip to main content

trillium_client/conn/
shared.rs

1use super::{Body, Conn, Transport, TypeSet};
2use crate::{ClientHandler, ConnExt, Error, Result, Version};
3use std::{
4    fmt::{self, Debug, Formatter},
5    future::{Future, IntoFuture},
6    mem,
7    pin::Pin,
8};
9use trillium_http::Upgrade;
10
/// Error type pairing a [`trillium_http::Error`] with the error type of the enabled JSON
/// backend: `sonic_rs::Error` under the `sonic-rs` cargo feature, `serde_json::Error` under the
/// `serde_json` cargo feature. The type exists only when at least one of those features is
/// enabled.
///
/// NOTE(review): both backends declare a variant named `JsonError`, so enabling `sonic-rs` and
/// `serde_json` at the same time would not compile here; the features appear to be intended as
/// mutually exclusive — confirm.
#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
#[derive(Debug, thiserror::Error)]
pub enum ClientSerdeError {
    /// An HTTP-level [`trillium_http::Error`]. `#[error(transparent)]` forwards `Display` and
    /// `source` to the wrapped error; `#[from]` lets `?` convert into this variant.
    #[error(transparent)]
    HttpError(#[from] Error),

    /// A JSON error from the `sonic-rs` backend ([`sonic_rs::Error`]), forwarded transparently.
    #[cfg(feature = "sonic-rs")]
    #[error(transparent)]
    JsonError(#[from] sonic_rs::Error),

    /// A JSON error from the `serde_json` backend ([`serde_json::Error`]), forwarded
    /// transparently.
    #[cfg(feature = "serde_json")]
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),
}
31
32impl Conn {
33    pub(crate) async fn exec(&mut self) -> Result<()> {
34        // Arc-clone to dodge conflict with the `&mut self` we pass to `run`.
35        let handler = self.client.handler().clone();
36        handler.run(self).await?;
37
38        if !self.halted {
39            // Stash, don't return: `after_response` runs unconditionally so recovery handlers
40            // (stale-if-error, retry-with-fallback) get a chance to clear it.
41            if let Err(e) = self.exec_network().await {
42                self.error = Some(e);
43            }
44        } else {
45            log::trace!("conn is halted, skipping network round-trip");
46        }
47
48        // Reverse order, regardless of halt/error — mirrors server-side `before_send`.
49        handler.after_response(self).await?;
50
51        if let Some(e) = self.error.take() {
52            Err(e)
53        } else {
54            Ok(())
55        }
56    }
57
58    async fn exec_network(&mut self) -> Result<()> {
59        if matches!(self.http_version, Version::Http0_9) {
60            return Err(Error::UnsupportedVersion(self.http_version));
61        }
62
63        if self.try_exec_h3().await? {
64            return Ok(());
65        }
66        if self.try_exec_h2_pooled().await? {
67            return Ok(());
68        }
69
70        // Prior-knowledge h2: caller asserted h2, skip h1/ALPN. Useful for TLS connectors
71        // that don't expose `negotiated_alpn` (e.g. native-tls). No fallback — a non-h2
72        // server here surfaces as a plain IO error.
73        if self.http_version == Version::Http2 {
74            return self.exec_h2_prior_knowledge().await;
75        }
76
77        self.exec_h1_or_promote_h2().await
78    }
79
80    pub(crate) fn body_len(&self) -> Option<u64> {
81        if let Some(ref body) = self.request_body {
82            body.len()
83        } else {
84            Some(0)
85        }
86    }
87
88    pub(crate) fn finalize_headers(&mut self) -> Result<()> {
89        match self.http_version {
90            Version::Http1_0 | Version::Http1_1 => self.finalize_headers_h1(),
91            Version::Http2 => self.finalize_headers_h2(),
92            Version::Http3 if self.client.h3().is_some() => self.finalize_headers_h3(),
93            other => Err(Error::UnsupportedVersion(other)),
94        }
95    }
96}
97
98impl Drop for Conn {
99    fn drop(&mut self) {
100        log::trace!("dropping client conn");
101        drop(self.take_response_body());
102    }
103}
104
105impl From<Conn> for Body {
106    fn from(mut conn: Conn) -> Body {
107        // body_override (e.g. cache hit, set via `set_response_body`) bypasses the transport;
108        // transport pooling is left to `Drop`.
109        if let Some(body) = conn.body_override.take() {
110            return body;
111        }
112
113        match conn.take_received_body(true) {
114            Some(rb) => rb.into(),
115            None => Body::default(),
116        }
117    }
118}
119
120impl From<Conn> for Upgrade<Box<dyn Transport>> {
121    fn from(mut conn: Conn) -> Self {
122        Upgrade::new(
123            mem::take(&mut conn.request_headers),
124            conn.url.path().to_string(),
125            conn.method,
126            conn.transport.take().unwrap(),
127            mem::take(&mut conn.buffer),
128            conn.http_version(),
129        )
130    }
131}
132
133impl IntoFuture for Conn {
134    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'static>>;
135    type Output = Result<Conn>;
136
137    fn into_future(mut self) -> Self::IntoFuture {
138        Box::pin(async move { (&mut self).await.map(|()| self) })
139    }
140}
141
142impl<'conn> IntoFuture for &'conn mut Conn {
143    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'conn>>;
144    type Output = Result<()>;
145
146    fn into_future(self) -> Self::IntoFuture {
147        Box::pin(async move {
148            // Re-issuing handlers (FollowRedirects, retry, auth-refresh) queue a follow-up
149            // via `set_followup` in `after_response`; we recycle, swap, re-exec.
150            loop {
151                let result = if let Some(duration) = self.timeout {
152                    self.client
153                        .connector()
154                        .runtime()
155                        .timeout(duration, self.exec())
156                        .await
157                        .unwrap_or(Err(Error::TimedOut("Conn", duration)))
158                } else {
159                    self.exec().await
160                };
161
162                // `halted` is handler-internal; don't leak it out to the caller.
163                self.halted = false;
164
165                if let Err(e) = result {
166                    // Unrecovered error wins over any queued follow-up. Recovery handlers
167                    // that want the follow-up to run must `take_error()` in `after_response`.
168                    self.followup = None;
169                    return Err(e);
170                }
171
172                let Some(next) = self.take_followup() else {
173                    break;
174                };
175
176                if let Some(body) = self.take_response_body() {
177                    body.recycle().await;
178                }
179
180                let _displaced = mem::replace(self, next);
181            }
182            Ok(())
183        })
184    }
185}
186
187impl Debug for Conn {
188    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
189        f.debug_struct("Conn")
190            .field("authority", &self.authority)
191            .field("buffer", &String::from_utf8_lossy(&self.buffer))
192            .field("client", &self.client)
193            .field("protocol_session", &self.protocol_session)
194            .field("http_version", &self.http_version)
195            .field("method", &self.method)
196            .field("path", &self.path)
197            .field("request_body", &self.request_body)
198            .field("request_headers", &self.request_headers)
199            .field("request_target", &self.request_target)
200            .field("request_trailers", &self.request_trailers)
201            .field("response_body_state", &self.response_body_state)
202            .field("response_headers", &self.response_headers)
203            .field("response_trailers", &self.response_trailers)
204            .field("scheme", &self.scheme)
205            .field("state", &self.state)
206            .field("status", &self.status)
207            .field("url", &self.url)
208            .finish()
209    }
210}
211
212impl AsRef<TypeSet> for Conn {
213    fn as_ref(&self) -> &TypeSet {
214        &self.state
215    }
216}
217
218impl AsMut<TypeSet> for Conn {
219    fn as_mut(&mut self) -> &mut TypeSet {
220        &mut self.state
221    }
222}