Skip to main content

trillium_client/
conn.rs

1use crate::{Pool, ResponseBody, h3::H3ClientState, util::encoding};
2use encoding_rs::Encoding;
3use std::{borrow::Cow, net::SocketAddr, sync::Arc, time::Duration};
4use trillium_http::{
5    Body, Buffer, HeaderName, HeaderValues, Headers, HttpContext, Method, ReceivedBody,
6    ReceivedBodyState, Status, TypeSet, Version,
7};
8use trillium_server_common::{
9    ArcedConnector, Transport,
10    url::{Origin, Url},
11};
12
13mod h1;
14mod h3;
15mod shared;
16mod unexpected_status_error;
17
18#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
19pub use shared::ClientSerdeError;
20pub use unexpected_status_error::UnexpectedStatusError;
21
22/// a client connection, representing both an outbound http request and a
23/// http response
24#[must_use]
25#[derive(fieldwork::Fieldwork)]
26pub struct Conn {
27    pub(crate) pool: Option<Pool<Origin, Box<dyn Transport>>>,
28    pub(crate) h3: Option<H3ClientState>,
29    pub(crate) buffer: Buffer,
30    pub(crate) response_body_state: ReceivedBodyState,
31    pub(crate) config: ArcedConnector,
32    pub(crate) headers_finalized: bool,
33    pub(crate) max_head_length: usize,
34    pub(crate) state: TypeSet,
35    pub(crate) context: Arc<HttpContext>,
36
37    /// the transport for this conn
38    ///
39    /// This should only be used to call your own custom methods on the transport that do not read
40    /// or write any data. Calling any method that reads from or writes to the transport will
41    /// disrupt the HTTP protocol.
42    #[field(get, get_mut)]
43    pub(crate) transport: Option<Box<dyn Transport>>,
44
45    /// the url for this conn.
46    ///
47    /// ```
48    /// use trillium_client::{Client, Method};
49    /// use trillium_testing::client_config;
50    ///
51    /// let client = Client::from(client_config());
52    ///
53    /// let conn = client.get("http://localhost:9080");
54    ///
55    /// let url = conn.url(); //<-
56    ///
57    /// assert_eq!(url.host_str().unwrap(), "localhost");
58    /// ```
59    #[field(get, set, get_mut)]
60    pub(crate) url: Url,
61
62    /// the method for this conn.
63    ///
64    /// ```
65    /// use trillium_client::{Client, Method};
66    /// use trillium_testing::client_config;
67    ///
68    /// let client = Client::from(client_config());
69    /// let conn = client.get("http://localhost:9080");
70    ///
71    /// let method = conn.method(); //<-
72    ///
73    /// assert_eq!(method, Method::Get);
74    /// ```
75    #[field(get, set, copy)]
76    pub(crate) method: Method,
77
78    /// the request headers
79    #[field(get, get_mut)]
80    pub(crate) request_headers: Headers,
81
82    #[field(get, get_mut)]
83    /// the response headers
84    pub(crate) response_headers: Headers,
85
86    /// the status code for this conn.
87    ///
88    /// If the conn has not yet been sent, this will be None.
89    ///
90    /// ```
91    /// use trillium_client::{Client, Status};
92    /// use trillium_testing::{client_config, with_server};
93    ///
94    /// async fn handler(conn: trillium::Conn) -> trillium::Conn {
95    ///     conn.with_status(418)
96    /// }
97    ///
98    /// with_server(handler, |url| async move {
99    ///     let client = Client::new(client_config());
100    ///     let conn = client.get(url).await?;
101    ///     assert_eq!(Status::ImATeapot, conn.status().unwrap());
102    ///     Ok(())
103    /// });
104    /// ```
105    #[field(get, copy)]
106    pub(crate) status: Option<Status>,
107
108    /// the request body
109    ///
110    /// ```
111    /// env_logger::init();
112    /// use trillium_client::Client;
113    /// use trillium_testing::{client_config, with_server};
114    ///
115    /// let handler = |mut conn: trillium::Conn| async move {
116    ///     let body = conn.request_body_string().await.unwrap();
117    ///     conn.ok(format!("request body was: {}", body))
118    /// };
119    ///
120    /// with_server(handler, |url| async move {
121    ///     let client = Client::from(client_config());
122    ///     let mut conn = client
123    ///         .post(url)
124    ///         .with_body("body") //<-
125    ///         .await?;
126    ///
127    ///     assert_eq!(
128    ///         conn.response_body().read_string().await?,
129    ///         "request body was: body"
130    ///     );
131    ///     Ok(())
132    /// });
133    /// ```
134    #[field(with = with_body, argument = body, set, into, take, option_set_some)]
135    pub(crate) request_body: Option<Body>,
136
137    /// the timeout for this conn
138    ///
139    /// this can also be set on the client with [`Client::set_timeout`](crate::Client::set_timeout)
140    /// and [`Client::with_timeout`](crate::Client::with_timeout)
141    #[field(with, set, get, get_mut, take, copy, option_set_some)]
142    pub(crate) timeout: Option<Duration>,
143
144    /// the http version for this conn
145    ///
146    /// prior to conn execution, this reflects the intended http version that will be sent, and
147    /// after execution this reflects the server-indicated http version
148    #[field(get, set, with, copy)]
149    pub(crate) http_version: Version,
150
151    /// the :authority pseudo-header, populated during h3 header finalization
152    #[field(get)]
153    pub(crate) authority: Option<Cow<'static, str>>,
154    /// the :scheme pseudo-header, populated during h3 header finalization
155
156    #[field(get)]
157    pub(crate) scheme: Option<Cow<'static, str>>,
158
159    /// the :path pseudo-header, populated during h3 header finalization
160    #[field(get)]
161    pub(crate) path: Option<Cow<'static, str>>,
162
163    /// an explicit request target override, used only for `OPTIONS *` and `CONNECT host:port`
164    ///
165    /// When set and the method is OPTIONS or CONNECT, this value is used as the HTTP request
166    /// target instead of deriving it from the url. For all other methods, this field is ignored.
167    #[field(with, set, get, option_set_some, into)]
168    pub(crate) request_target: Option<Cow<'static, str>>,
169
170    /// trailers sent with the request body, populated after the body has been fully sent.
171    ///
172    /// Only present when the request body was constructed with [`Body::new_with_trailers`] and
173    /// the body has been fully sent. For H3, this is populated after `send_h3_request`; for H1,
174    /// after `send_body` with a chunked body.
175    #[field(get)]
176    pub(crate) request_trailers: Option<Headers>,
177
178    /// trailers received with the response body, populated after the response body has been fully
179    /// read.
180    ///
181    /// For H3, these are decoded from the trailing HEADERS frame. For H1, from chunked trailers
182    /// (once H1 trailer receive is implemented).
183    #[field(get)]
184    pub(crate) response_trailers: Option<Headers>,
185}
186
187/// default http user-agent header
188pub const USER_AGENT: &str = concat!("trillium-client/", env!("CARGO_PKG_VERSION"));
189
190impl Conn {
191    /// chainable setter for [`inserting`](Headers::insert) a request header
192    ///
193    /// ```
194    /// use trillium_client::Client;
195    /// use trillium_testing::{client_config, with_server};
196    ///
197    /// let handler = |conn: trillium::Conn| async move {
198    ///     let header = conn
199    ///         .request_headers()
200    ///         .get_str("some-request-header")
201    ///         .unwrap_or_default();
202    ///     let response = format!("some-request-header was {}", header);
203    ///     conn.ok(response)
204    /// };
205    ///
206    /// with_server(handler, |url| async move {
207    ///     let client = Client::new(client_config());
208    ///     let mut conn = client
209    ///         .get(url)
210    ///         .with_request_header("some-request-header", "header-value") // <--
211    ///         .await?;
212    ///     assert_eq!(
213    ///         conn.response_body().read_string().await?,
214    ///         "some-request-header was header-value"
215    ///     );
216    ///     Ok(())
217    /// })
218    /// ```
219    pub fn with_request_header(
220        mut self,
221        name: impl Into<HeaderName<'static>>,
222        value: impl Into<HeaderValues>,
223    ) -> Self {
224        self.request_headers.insert(name, value);
225        self
226    }
227
228    /// chainable setter for `extending` request headers
229    ///
230    /// ```
231    /// use trillium_client::Client;
232    /// use trillium_testing::{client_config, with_server};
233    ///
234    /// let handler = |conn: trillium::Conn| async move {
235    ///     let header = conn
236    ///         .request_headers()
237    ///         .get_str("some-request-header")
238    ///         .unwrap_or_default();
239    ///     let response = format!("some-request-header was {}", header);
240    ///     conn.ok(response)
241    /// };
242    ///
243    /// with_server(handler, move |url| async move {
244    ///     let client = Client::new(client_config());
245    ///     let mut conn = client
246    ///         .get(url)
247    ///         .with_request_headers([
248    ///             ("some-request-header", "header-value"),
249    ///             ("some-other-req-header", "other-header-value"),
250    ///         ])
251    ///         .await?;
252    ///
253    ///     assert_eq!(
254    ///         conn.response_body().read_string().await?,
255    ///         "some-request-header was header-value"
256    ///     );
257    ///     Ok(())
258    /// })
259    /// ```
260    pub fn with_request_headers<HN, HV, I>(mut self, headers: I) -> Self
261    where
262        I: IntoIterator<Item = (HN, HV)> + Send,
263        HN: Into<HeaderName<'static>>,
264        HV: Into<HeaderValues>,
265    {
266        self.request_headers.extend(headers);
267        self
268    }
269
270    /// Chainable method to remove a request header if present
271    pub fn without_request_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
272        self.request_headers.remove(name);
273        self
274    }
275
276    /// chainable setter for json body. this requires the `serde_json` crate feature to be enabled.
277    #[cfg(feature = "serde_json")]
278    pub fn with_json_body(self, body: &impl serde::Serialize) -> serde_json::Result<Self> {
279        use trillium_http::KnownHeaderName;
280
281        Ok(self
282            .with_body(serde_json::to_string(body)?)
283            .with_request_header(KnownHeaderName::ContentType, "application/json"))
284    }
285
286    /// chainable setter for json body. this requires the `sonic-rs` crate feature to be enabled.
287    #[cfg(feature = "sonic-rs")]
288    pub fn with_json_body(self, body: &impl serde::Serialize) -> sonic_rs::Result<Self> {
289        use trillium_http::KnownHeaderName;
290
291        Ok(self
292            .with_body(sonic_rs::to_string(body)?)
293            .with_request_header(KnownHeaderName::ContentType, "application/json"))
294    }
295
296    pub(crate) fn response_encoding(&self) -> &'static Encoding {
297        encoding(&self.response_headers)
298    }
299
300    /// returns a [`ResponseBody`](crate::ResponseBody) that borrows the connection inside this
301    /// conn.
302    /// ```
303    /// use trillium_client::Client;
304    /// use trillium_testing::{client_config, with_server};
305    ///
306    /// let handler = |mut conn: trillium::Conn| async move { conn.ok("hello from trillium") };
307    ///
308    /// with_server(handler, |url| async move {
309    ///     let client = Client::from(client_config());
310    ///     let mut conn = client.get(url).await?;
311    ///
312    ///     let response_body = conn.response_body(); //<-
313    ///
314    ///     assert_eq!(19, response_body.content_length().unwrap());
315    ///     let string = response_body.read_string().await?;
316    ///     assert_eq!("hello from trillium", string);
317    ///     Ok(())
318    /// });
319    /// ```
320    #[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)]
321    pub fn response_body(&mut self) -> ResponseBody<'_> {
322        ReceivedBody::new(
323            self.response_content_length(),
324            &mut self.buffer,
325            self.transport.as_mut().unwrap(),
326            &mut self.response_body_state,
327            None,
328            encoding(&self.response_headers),
329        )
330        .with_trailers(&mut self.response_trailers)
331        .into()
332    }
333
334    /// Attempt to deserialize the response body. Note that this consumes the body content.
335    #[cfg(feature = "serde_json")]
336    pub async fn response_json<T>(&mut self) -> Result<T, ClientSerdeError>
337    where
338        T: serde::de::DeserializeOwned,
339    {
340        let body = self.response_body().read_string().await?;
341        Ok(serde_json::from_str(&body)?)
342    }
343
344    /// Attempt to deserialize the response body. Note that this consumes the body content.
345    #[cfg(feature = "sonic-rs")]
346    pub async fn response_json<T>(&mut self) -> Result<T, ClientSerdeError>
347    where
348        T: serde::de::DeserializeOwned,
349    {
350        let body = self.response_body().read_string().await?;
351        Ok(sonic_rs::from_str(&body)?)
352    }
353
354    /// Returns the conn or an [`UnexpectedStatusError`] that contains the conn
355    ///
356    /// ```
357    /// use trillium_client::{Client, Status};
358    /// use trillium_testing::{client_config, with_server};
359    ///
360    /// with_server(Status::NotFound, |url| async move {
361    ///     let client = Client::new(client_config());
362    ///     assert_eq!(
363    ///         client.get(url).await?.success().unwrap_err().to_string(),
364    ///         "expected a success (2xx) status code, but got 404 Not Found"
365    ///     );
366    ///     Ok(())
367    /// });
368    ///
369    /// with_server(Status::Ok, |url| async move {
370    ///     let client = Client::new(client_config());
371    ///     assert!(client.get(url).await?.success().is_ok());
372    ///     Ok(())
373    /// });
374    /// ```
375    pub fn success(self) -> Result<Self, UnexpectedStatusError> {
376        match self.status() {
377            Some(status) if status.is_success() => Ok(self),
378            _ => Err(self.into()),
379        }
380    }
381
382    /// Returns this conn to the connection pool if it is keepalive, and
383    /// closes it otherwise. This will happen asynchronously as a spawned
384    /// task when the conn is dropped, but calling it explicitly allows
385    /// you to block on it and control where it happens.
386    pub async fn recycle(mut self) {
387        if self.is_keep_alive() && self.transport.is_some() && self.pool.is_some() {
388            self.finish_reading_body().await;
389        }
390    }
391
392    /// attempts to retrieve the connected peer address
393    pub fn peer_addr(&self) -> Option<SocketAddr> {
394        self.transport
395            .as_ref()
396            .and_then(|t| t.peer_addr().ok().flatten())
397    }
398
399    /// add state to the client conn and return self
400    pub fn with_state<T: Send + Sync + 'static>(mut self, state: T) -> Self {
401        self.insert_state(state);
402        self
403    }
404
405    /// add state to the client conn, returning any previously set state of this type
406    pub fn insert_state<T: Send + Sync + 'static>(&mut self, state: T) -> Option<T> {
407        self.state.insert(state)
408    }
409
410    /// borrow state
411    pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
412        self.state.get()
413    }
414
415    /// borrow state mutably
416    pub fn state_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
417        self.state.get_mut()
418    }
419
420    /// take state
421    pub fn take_state<T: Send + Sync + 'static>(&mut self) -> Option<T> {
422        self.state.take()
423    }
424}