Skip to main content

trillium_client/
conn.rs

1use crate::{Pool, ResponseBody, h3::H3ClientState, util::encoding};
2use encoding_rs::Encoding;
3use std::{borrow::Cow, net::SocketAddr, sync::Arc, time::Duration};
4use trillium_http::{
5    Body, Buffer, HeaderName, HeaderValues, Headers, HttpContext, Method, ReceivedBody,
6    ReceivedBodyState, Status, TypeSet, Version, h3::H3Connection,
7};
8use trillium_server_common::{
9    ArcedConnector, Transport,
10    url::{Origin, Url},
11};
12
13mod h1;
14mod h3;
15mod shared;
16mod unexpected_status_error;
17
18#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
19pub use shared::ClientSerdeError;
20pub use unexpected_status_error::UnexpectedStatusError;
21
22/// a client connection, representing both an outbound http request and a
23/// http response
24#[must_use]
25#[derive(fieldwork::Fieldwork)]
26pub struct Conn {
27    pub(crate) pool: Option<Pool<Origin, Box<dyn Transport>>>,
28    pub(crate) h3_client_state: Option<H3ClientState>,
29    pub(crate) h3_connection: Option<(Arc<H3Connection>, u64)>,
30    pub(crate) buffer: Buffer,
31    pub(crate) response_body_state: ReceivedBodyState,
32    pub(crate) config: ArcedConnector,
33    pub(crate) headers_finalized: bool,
34    pub(crate) max_head_length: usize,
35    pub(crate) state: TypeSet,
36    pub(crate) context: Arc<HttpContext>,
37
38    /// the transport for this conn
39    ///
40    /// This should only be used to call your own custom methods on the transport that do not read
41    /// or write any data. Calling any method that reads from or writes to the transport will
42    /// disrupt the HTTP protocol.
43    #[field(get, get_mut)]
44    pub(crate) transport: Option<Box<dyn Transport>>,
45
46    /// the url for this conn.
47    ///
48    /// ```
49    /// use trillium_client::{Client, Method};
50    /// use trillium_testing::client_config;
51    ///
52    /// let client = Client::from(client_config());
53    ///
54    /// let conn = client.get("http://localhost:9080");
55    ///
56    /// let url = conn.url(); //<-
57    ///
58    /// assert_eq!(url.host_str().unwrap(), "localhost");
59    /// ```
60    #[field(get, set, get_mut)]
61    pub(crate) url: Url,
62
63    /// the method for this conn.
64    ///
65    /// ```
66    /// use trillium_client::{Client, Method};
67    /// use trillium_testing::client_config;
68    ///
69    /// let client = Client::from(client_config());
70    /// let conn = client.get("http://localhost:9080");
71    ///
72    /// let method = conn.method(); //<-
73    ///
74    /// assert_eq!(method, Method::Get);
75    /// ```
76    #[field(get, set, copy)]
77    pub(crate) method: Method,
78
79    /// the request headers
80    #[field(get, get_mut)]
81    pub(crate) request_headers: Headers,
82
83    #[field(get, get_mut)]
84    /// the response headers
85    pub(crate) response_headers: Headers,
86
87    /// the status code for this conn.
88    ///
89    /// If the conn has not yet been sent, this will be None.
90    ///
91    /// ```
92    /// use trillium_client::{Client, Status};
93    /// use trillium_testing::{client_config, with_server};
94    ///
95    /// async fn handler(conn: trillium::Conn) -> trillium::Conn {
96    ///     conn.with_status(418)
97    /// }
98    ///
99    /// with_server(handler, |url| async move {
100    ///     let client = Client::new(client_config());
101    ///     let conn = client.get(url).await?;
102    ///     assert_eq!(Status::ImATeapot, conn.status().unwrap());
103    ///     Ok(())
104    /// });
105    /// ```
106    #[field(get, copy)]
107    pub(crate) status: Option<Status>,
108
109    /// the request body
110    ///
111    /// ```
112    /// env_logger::init();
113    /// use trillium_client::Client;
114    /// use trillium_testing::{client_config, with_server};
115    ///
116    /// let handler = |mut conn: trillium::Conn| async move {
117    ///     let body = conn.request_body_string().await.unwrap();
118    ///     conn.ok(format!("request body was: {}", body))
119    /// };
120    ///
121    /// with_server(handler, |url| async move {
122    ///     let client = Client::from(client_config());
123    ///     let mut conn = client
124    ///         .post(url)
125    ///         .with_body("body") //<-
126    ///         .await?;
127    ///
128    ///     assert_eq!(
129    ///         conn.response_body().read_string().await?,
130    ///         "request body was: body"
131    ///     );
132    ///     Ok(())
133    /// });
134    /// ```
135    #[field(with = with_body, argument = body, set, into, take, option_set_some)]
136    pub(crate) request_body: Option<Body>,
137
138    /// the timeout for this conn
139    ///
140    /// this can also be set on the client with [`Client::set_timeout`](crate::Client::set_timeout)
141    /// and [`Client::with_timeout`](crate::Client::with_timeout)
142    #[field(with, set, get, get_mut, take, copy, option_set_some)]
143    pub(crate) timeout: Option<Duration>,
144
145    /// the http version for this conn
146    ///
147    /// prior to conn execution, this reflects the intended http version that will be sent, and
148    /// after execution this reflects the server-indicated http version
149    #[field(get, set, with, copy)]
150    pub(crate) http_version: Version,
151
152    /// the :authority pseudo-header, populated during h3 header finalization
153    #[field(get)]
154    pub(crate) authority: Option<Cow<'static, str>>,
155    /// the :scheme pseudo-header, populated during h3 header finalization
156
157    #[field(get)]
158    pub(crate) scheme: Option<Cow<'static, str>>,
159
160    /// the :path pseudo-header, populated during h3 header finalization
161    #[field(get)]
162    pub(crate) path: Option<Cow<'static, str>>,
163
164    /// an explicit request target override, used only for `OPTIONS *` and `CONNECT host:port`
165    ///
166    /// When set and the method is OPTIONS or CONNECT, this value is used as the HTTP request
167    /// target instead of deriving it from the url. For all other methods, this field is ignored.
168    #[field(with, set, get, option_set_some, into)]
169    pub(crate) request_target: Option<Cow<'static, str>>,
170
171    /// trailers sent with the request body, populated after the body has been fully sent.
172    ///
173    /// Only present when the request body was constructed with [`Body::new_with_trailers`] and
174    /// the body has been fully sent. For H3, this is populated after `send_h3_request`; for H1,
175    /// after `send_body` with a chunked body.
176    #[field(get)]
177    pub(crate) request_trailers: Option<Headers>,
178
179    /// trailers received with the response body, populated after the response body has been fully
180    /// read.
181    ///
182    /// For H3, these are decoded from the trailing HEADERS frame. For H1, from chunked trailers
183    /// (once H1 trailer receive is implemented).
184    #[field(get)]
185    pub(crate) response_trailers: Option<Headers>,
186}
187
188/// default http user-agent header
189pub const USER_AGENT: &str = concat!("trillium-client/", env!("CARGO_PKG_VERSION"));
190
191impl Conn {
192    /// chainable setter for [`inserting`](Headers::insert) a request header
193    ///
194    /// ```
195    /// use trillium_client::Client;
196    /// use trillium_testing::{client_config, with_server};
197    ///
198    /// let handler = |conn: trillium::Conn| async move {
199    ///     let header = conn
200    ///         .request_headers()
201    ///         .get_str("some-request-header")
202    ///         .unwrap_or_default();
203    ///     let response = format!("some-request-header was {}", header);
204    ///     conn.ok(response)
205    /// };
206    ///
207    /// with_server(handler, |url| async move {
208    ///     let client = Client::new(client_config());
209    ///     let mut conn = client
210    ///         .get(url)
211    ///         .with_request_header("some-request-header", "header-value") // <--
212    ///         .await?;
213    ///     assert_eq!(
214    ///         conn.response_body().read_string().await?,
215    ///         "some-request-header was header-value"
216    ///     );
217    ///     Ok(())
218    /// })
219    /// ```
220    pub fn with_request_header(
221        mut self,
222        name: impl Into<HeaderName<'static>>,
223        value: impl Into<HeaderValues>,
224    ) -> Self {
225        self.request_headers.insert(name, value);
226        self
227    }
228
229    /// chainable setter for `extending` request headers
230    ///
231    /// ```
232    /// use trillium_client::Client;
233    /// use trillium_testing::{client_config, with_server};
234    ///
235    /// let handler = |conn: trillium::Conn| async move {
236    ///     let header = conn
237    ///         .request_headers()
238    ///         .get_str("some-request-header")
239    ///         .unwrap_or_default();
240    ///     let response = format!("some-request-header was {}", header);
241    ///     conn.ok(response)
242    /// };
243    ///
244    /// with_server(handler, move |url| async move {
245    ///     let client = Client::new(client_config());
246    ///     let mut conn = client
247    ///         .get(url)
248    ///         .with_request_headers([
249    ///             ("some-request-header", "header-value"),
250    ///             ("some-other-req-header", "other-header-value"),
251    ///         ])
252    ///         .await?;
253    ///
254    ///     assert_eq!(
255    ///         conn.response_body().read_string().await?,
256    ///         "some-request-header was header-value"
257    ///     );
258    ///     Ok(())
259    /// })
260    /// ```
261    pub fn with_request_headers<HN, HV, I>(mut self, headers: I) -> Self
262    where
263        I: IntoIterator<Item = (HN, HV)> + Send,
264        HN: Into<HeaderName<'static>>,
265        HV: Into<HeaderValues>,
266    {
267        self.request_headers.extend(headers);
268        self
269    }
270
271    /// Chainable method to remove a request header if present
272    pub fn without_request_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
273        self.request_headers.remove(name);
274        self
275    }
276
277    /// chainable setter for json body. this requires the `serde_json` crate feature to be enabled.
278    #[cfg(feature = "serde_json")]
279    pub fn with_json_body(self, body: &impl serde::Serialize) -> serde_json::Result<Self> {
280        use trillium_http::KnownHeaderName;
281
282        Ok(self
283            .with_body(serde_json::to_string(body)?)
284            .with_request_header(KnownHeaderName::ContentType, "application/json"))
285    }
286
287    /// chainable setter for json body. this requires the `sonic-rs` crate feature to be enabled.
288    #[cfg(feature = "sonic-rs")]
289    pub fn with_json_body(self, body: &impl serde::Serialize) -> sonic_rs::Result<Self> {
290        use trillium_http::KnownHeaderName;
291
292        Ok(self
293            .with_body(sonic_rs::to_string(body)?)
294            .with_request_header(KnownHeaderName::ContentType, "application/json"))
295    }
296
297    pub(crate) fn response_encoding(&self) -> &'static Encoding {
298        encoding(&self.response_headers)
299    }
300
301    /// returns a [`ResponseBody`](crate::ResponseBody) that borrows the connection inside this
302    /// conn.
303    /// ```
304    /// use trillium_client::Client;
305    /// use trillium_testing::{client_config, with_server};
306    ///
307    /// let handler = |mut conn: trillium::Conn| async move { conn.ok("hello from trillium") };
308    ///
309    /// with_server(handler, |url| async move {
310    ///     let client = Client::from(client_config());
311    ///     let mut conn = client.get(url).await?;
312    ///
313    ///     let response_body = conn.response_body(); //<-
314    ///
315    ///     assert_eq!(19, response_body.content_length().unwrap());
316    ///     let string = response_body.read_string().await?;
317    ///     assert_eq!("hello from trillium", string);
318    ///     Ok(())
319    /// });
320    /// ```
321    #[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)]
322    pub fn response_body(&mut self) -> ResponseBody<'_> {
323        ReceivedBody::new(
324            self.response_content_length(),
325            &mut self.buffer,
326            self.transport.as_mut().unwrap(),
327            &mut self.response_body_state,
328            None,
329            encoding(&self.response_headers),
330        )
331        .with_trailers(&mut self.response_trailers)
332        .with_h3_connection(self.h3_connection.clone())
333        .into()
334    }
335
336    /// Attempt to deserialize the response body. Note that this consumes the body content.
337    #[cfg(feature = "serde_json")]
338    pub async fn response_json<T>(&mut self) -> Result<T, ClientSerdeError>
339    where
340        T: serde::de::DeserializeOwned,
341    {
342        let body = self.response_body().read_string().await?;
343        Ok(serde_json::from_str(&body)?)
344    }
345
346    /// Attempt to deserialize the response body. Note that this consumes the body content.
347    #[cfg(feature = "sonic-rs")]
348    pub async fn response_json<T>(&mut self) -> Result<T, ClientSerdeError>
349    where
350        T: serde::de::DeserializeOwned,
351    {
352        let body = self.response_body().read_string().await?;
353        Ok(sonic_rs::from_str(&body)?)
354    }
355
356    /// Returns the conn or an [`UnexpectedStatusError`] that contains the conn
357    ///
358    /// ```
359    /// use trillium_client::{Client, Status};
360    /// use trillium_testing::{client_config, with_server};
361    ///
362    /// with_server(Status::NotFound, |url| async move {
363    ///     let client = Client::new(client_config());
364    ///     assert_eq!(
365    ///         client.get(url).await?.success().unwrap_err().to_string(),
366    ///         "expected a success (2xx) status code, but got 404 Not Found"
367    ///     );
368    ///     Ok(())
369    /// });
370    ///
371    /// with_server(Status::Ok, |url| async move {
372    ///     let client = Client::new(client_config());
373    ///     assert!(client.get(url).await?.success().is_ok());
374    ///     Ok(())
375    /// });
376    /// ```
377    pub fn success(self) -> Result<Self, UnexpectedStatusError> {
378        match self.status() {
379            Some(status) if status.is_success() => Ok(self),
380            _ => Err(self.into()),
381        }
382    }
383
384    /// Returns this conn to the connection pool if it is keepalive, and
385    /// closes it otherwise. This will happen asynchronously as a spawned
386    /// task when the conn is dropped, but calling it explicitly allows
387    /// you to block on it and control where it happens.
388    pub async fn recycle(mut self) {
389        if self.is_keep_alive() && self.transport.is_some() && self.pool.is_some() {
390            self.finish_reading_body().await;
391        }
392    }
393
394    /// attempts to retrieve the connected peer address
395    pub fn peer_addr(&self) -> Option<SocketAddr> {
396        self.transport
397            .as_ref()
398            .and_then(|t| t.peer_addr().ok().flatten())
399    }
400
401    /// add state to the client conn and return self
402    pub fn with_state<T: Send + Sync + 'static>(mut self, state: T) -> Self {
403        self.insert_state(state);
404        self
405    }
406
407    /// add state to the client conn, returning any previously set state of this type
408    pub fn insert_state<T: Send + Sync + 'static>(&mut self, state: T) -> Option<T> {
409        self.state.insert(state)
410    }
411
412    /// borrow state
413    pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
414        self.state.get()
415    }
416
417    /// borrow state mutably
418    pub fn state_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
419        self.state.get_mut()
420    }
421
422    /// take state
423    pub fn take_state<T: Send + Sync + 'static>(&mut self) -> Option<T> {
424        self.state.take()
425    }
426}