Skip to main content

trillium_client/conn/
shared.rs

1use super::{Body, Conn, ReceivedBody, ReceivedBodyState, Transport, TypeSet, encoding};
2use crate::{Error, Result, Version, pool::PoolEntry};
3use futures_lite::{AsyncWriteExt, io};
4use std::{
5    fmt::{self, Debug, Formatter},
6    future::{Future, IntoFuture},
7    mem,
8    pin::Pin,
9};
10use trillium_http::Upgrade;
11
/// Error type that wraps either a [`trillium_http::Error`] or the error of the enabled JSON
/// serializer: `sonic_rs::Error` with the `sonic-rs` cargo feature, `serde_json::Error` with the
/// `serde_json` cargo feature. Only available when one of those features is enabled.
///
/// Both serializer variants are named `JsonError`, so this type expects the two features to be
/// mutually exclusive; enabling both would declare the `JsonError` variant twice.
#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
#[derive(Debug, thiserror::Error)]
pub enum ClientSerdeError {
    /// Wraps a [`trillium_http::Error`].
    #[error(transparent)]
    HttpError(#[from] Error),

    /// Wraps a [`sonic_rs::Error`].
    #[cfg(feature = "sonic-rs")]
    #[error(transparent)]
    JsonError(#[from] sonic_rs::Error),

    /// Wraps a [`serde_json::Error`].
    #[cfg(feature = "serde_json")]
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),
}
32
33impl Conn {
34    pub(crate) async fn exec(&mut self) -> Result<()> {
35        match self.http_version {
36            Version::Http0_9 | Version::Http2 => {
37                return Err(Error::UnsupportedVersion(self.http_version));
38            }
39            _ => {}
40        }
41
42        if !self.try_exec_h3().await? {
43            self.exec_h1().await?;
44        }
45
46        Ok(())
47    }
48
49    pub(crate) fn body_len(&self) -> Option<u64> {
50        if let Some(ref body) = self.request_body {
51            body.len()
52        } else {
53            Some(0)
54        }
55    }
56
57    pub(crate) fn finalize_headers(&mut self) -> Result<()> {
58        match self.http_version {
59            Version::Http1_0 | Version::Http1_1 => self.finalize_headers_h1(),
60            Version::Http3 if self.h3_client_state.is_some() => self.finalize_headers_h3(),
61            other => Err(Error::UnsupportedVersion(other)),
62        }
63    }
64}
65
impl Drop for Conn {
    /// Releases the conn's transport when the conn is dropped.
    ///
    /// - If the transport has already been taken (the conversions into a body or an upgrade in
    ///   this file take it), there is nothing to do.
    /// - If the conn is not keep-alive, the transport is closed on a task spawned on the
    ///   configured runtime.
    /// - Otherwise the transport is inserted into the pool under the url's origin. This happens
    ///   immediately when the response body state is already `End`. Otherwise a spawned task
    ///   first reads and discards the rest of the response body, then inserts it.
    fn drop(&mut self) {
        log::trace!("dropping client conn");
        let Some(mut transport) = self.transport.take() else {
            log::trace!("no transport, nothing to do");

            return;
        };

        if !self.is_keep_alive() {
            log::trace!("not keep alive, closing");

            self.config
                .runtime()
                .clone()
                .spawn(async move { transport.close().await });

            return;
        }

        // `peer_addr` is only used in the log lines below, but a transport that cannot report
        // one is never recycled. Returning here drops it without an explicit `close`.
        let Ok(Some(peer_addr)) = transport.peer_addr() else {
            return;
        };
        // With no pool to return it to, the transport is likewise dropped without `close`.
        // NOTE(review): `From<Conn> for ReceivedBody` closes the transport in the no-pool case
        // instead; confirm the difference is intended.
        let Some(pool) = self.pool.take() else { return };

        let origin = self.url.origin();

        if self.response_body_state == ReceivedBodyState::End {
            log::trace!(
                "response body has been read to completion, checking transport back into pool for \
                 {}",
                &peer_addr
            );
            pool.insert(origin, PoolEntry::new(transport, None));
        } else {
            // The response body was not read to completion. Everything needed to resume reading
            // it is moved out of `self`, because the spawned task outlives this `drop` call. The
            // remainder is read into a sink, and only then is the transport pooled.
            // NOTE(review): the remainder is drained with no size limit.
            let content_length = self.response_content_length();
            let buffer = mem::take(&mut self.buffer);
            let response_body_state = self.response_body_state;
            let encoding = encoding(&self.response_headers);
            self.config.runtime().spawn(async move {
                let mut response_body = ReceivedBody::new(
                    content_length,
                    buffer,
                    transport,
                    response_body_state,
                    // No completion callback: the transport is taken back explicitly below.
                    None,
                    encoding,
                );

                match io::copy(&mut response_body, io::sink()).await {
                    Ok(bytes) => {
                        // NOTE(review): assumes the body still holds its transport after a
                        // successful drain, since no completion callback was supplied.
                        let transport = response_body.take_transport().unwrap();
                        log::trace!(
                            "read {} bytes in order to recycle conn for {}",
                            bytes,
                            &peer_addr
                        );
                        pool.insert(origin, PoolEntry::new(transport, None));
                    }

                    // On failure the transport is dropped along with `response_body`.
                    Err(ioerror) => log::error!("unable to recycle conn due to {}", ioerror),
                };
            });
        }
    }
}
132
133impl From<Conn> for Body {
134    fn from(conn: Conn) -> Body {
135        let received_body: ReceivedBody<'static, _> = conn.into();
136        received_body.into()
137    }
138}
139
140impl From<Conn> for ReceivedBody<'static, Box<dyn Transport>> {
141    fn from(mut conn: Conn) -> Self {
142        let _ = conn.finalize_headers();
143        let runtime = conn.config.runtime();
144        let origin = conn.url.origin();
145
146        let on_completion = if conn.is_keep_alive()
147            && let Some(pool) = conn.pool.take()
148        {
149            Box::new(move |transport: Box<dyn Transport>| {
150                log::trace!("body transferred, returning to pool");
151                pool.insert(origin.clone(), PoolEntry::new(transport, None));
152            }) as Box<dyn FnOnce(Box<dyn Transport>) + Send + Sync + 'static>
153        } else {
154            Box::new(move |mut transport: Box<dyn Transport>| {
155                runtime.spawn(async move { transport.close().await });
156            }) as Box<dyn FnOnce(Box<dyn Transport>) + Send + Sync + 'static>
157        };
158
159        ReceivedBody::new(
160            conn.response_content_length(),
161            mem::take(&mut conn.buffer),
162            conn.transport.take().unwrap(),
163            conn.response_body_state,
164            Some(on_completion),
165            conn.response_encoding(),
166        )
167    }
168}
169
170impl From<Conn> for Upgrade<Box<dyn Transport>> {
171    fn from(mut conn: Conn) -> Self {
172        Upgrade::new(
173            mem::take(&mut conn.request_headers),
174            conn.url.path().to_string(),
175            conn.method,
176            conn.transport.take().unwrap(),
177            mem::take(&mut conn.buffer),
178            conn.http_version(),
179        )
180    }
181}
182
183impl IntoFuture for Conn {
184    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'static>>;
185    type Output = Result<Conn>;
186
187    fn into_future(mut self) -> Self::IntoFuture {
188        Box::pin(async move { (&mut self).await.map(|()| self) })
189    }
190}
191
192impl<'conn> IntoFuture for &'conn mut Conn {
193    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'conn>>;
194    type Output = Result<()>;
195
196    fn into_future(self) -> Self::IntoFuture {
197        Box::pin(async move {
198            if let Some(duration) = self.timeout {
199                self.config
200                    .runtime()
201                    .timeout(duration, self.exec())
202                    .await
203                    .unwrap_or(Err(Error::TimedOut("Conn", duration)))?;
204            } else {
205                self.exec().await?;
206            }
207            Ok(())
208        })
209    }
210}
211
212impl Debug for Conn {
213    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
214        f.debug_struct("Conn")
215            .field("authority", &self.authority)
216            .field("buffer", &String::from_utf8_lossy(&self.buffer))
217            .field("config", &self.config)
218            .field("h3_client_state", &self.h3_client_state)
219            .field("h3_connection", &self.h3_connection)
220            .field("http_version", &self.http_version)
221            .field("method", &self.method)
222            .field("path", &self.path)
223            .field("pool", &self.pool)
224            .field("request_body", &self.request_body)
225            .field("request_headers", &self.request_headers)
226            .field("request_target", &self.request_target)
227            .field("request_trailers", &self.request_trailers)
228            .field("response_body_state", &self.response_body_state)
229            .field("response_headers", &self.response_headers)
230            .field("response_trailers", &self.response_trailers)
231            .field("scheme", &self.scheme)
232            .field("state", &self.state)
233            .field("status", &self.status)
234            .field("url", &self.url)
235            .finish()
236    }
237}
238
239impl AsRef<TypeSet> for Conn {
240    fn as_ref(&self) -> &TypeSet {
241        &self.state
242    }
243}
244
245impl AsMut<TypeSet> for Conn {
246    fn as_mut(&mut self) -> &mut TypeSet {
247        &mut self.state
248    }
249}