trillium_client/conn.rs
1use crate::{Pool, ResponseBody, h3::H3ClientState, util::encoding};
2use encoding_rs::Encoding;
3use std::{borrow::Cow, net::SocketAddr, sync::Arc, time::Duration};
4use trillium_http::{
5 Body, Buffer, HeaderName, HeaderValues, Headers, HttpContext, Method, ReceivedBody,
6 ReceivedBodyState, Status, TypeSet, Version, h3::H3Connection,
7};
8use trillium_server_common::{
9 ArcedConnector, Transport,
10 url::{Origin, Url},
11};
12
13mod h1;
14mod h3;
15mod shared;
16mod unexpected_status_error;
17
18#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
19pub use shared::ClientSerdeError;
20pub use unexpected_status_error::UnexpectedStatusError;
21
/// a client connection, representing both an outbound http request and an
/// http response
#[must_use]
#[derive(fieldwork::Fieldwork)]
pub struct Conn {
    /// pool of reusable transports keyed by [`Origin`]. [`Conn::recycle`] only drains the
    /// response body when this is `Some`.
    pub(crate) pool: Option<Pool<Origin, Box<dyn Transport>>>,

    /// h3 client state, if any. NOTE(review): presumably `None` when h3 is not in use for this
    /// client — confirm.
    pub(crate) h3_client_state: Option<H3ClientState>,

    /// the h3 connection carrying this request, if any, paired with a `u64` that is presumably
    /// the request's stream id (TODO confirm). [`Conn::response_body`] hands a clone of this to
    /// the [`ReceivedBody`] it builds.
    pub(crate) h3_connection: Option<(Arc<H3Connection>, u64)>,

    /// buffer lent mutably to the [`ReceivedBody`] built by [`Conn::response_body`].
    /// NOTE(review): presumably holds bytes already read past the response head — confirm.
    pub(crate) buffer: Buffer,

    /// read progress of the response body, lent mutably to the [`ReceivedBody`] built by
    /// [`Conn::response_body`].
    pub(crate) response_body_state: ReceivedBodyState,

    /// connector configuration for this conn. NOTE(review): presumably used to open a new
    /// transport when none is available from the pool — confirm.
    pub(crate) config: ArcedConnector,

    /// NOTE(review): presumably records whether the request headers have already been finalized
    /// for sending, so that this only happens once — confirm against the h1/h3 modules.
    pub(crate) headers_finalized: bool,

    /// NOTE(review): presumably the maximum number of bytes accepted for a response head —
    /// confirm.
    pub(crate) max_head_length: usize,

    /// typed per-conn state, accessed through [`Conn::with_state`], [`Conn::insert_state`],
    /// [`Conn::state`], [`Conn::state_mut`], and [`Conn::take_state`].
    pub(crate) state: TypeSet,

    /// shared http context. TODO: document what this carries.
    pub(crate) context: Arc<HttpContext>,

    /// the transport for this conn
    ///
    /// This should only be used to call your own custom methods on the transport that do not read
    /// or write any data. Calling any method that reads from or writes to the transport will
    /// disrupt the HTTP protocol.
    #[field(get, get_mut)]
    pub(crate) transport: Option<Box<dyn Transport>>,

    /// the url for this conn.
    ///
    /// ```
    /// use trillium_client::{Client, Method};
    /// use trillium_testing::client_config;
    ///
    /// let client = Client::from(client_config());
    ///
    /// let conn = client.get("http://localhost:9080");
    ///
    /// let url = conn.url(); //<-
    ///
    /// assert_eq!(url.host_str().unwrap(), "localhost");
    /// ```
    #[field(get, set, get_mut)]
    pub(crate) url: Url,

    /// the method for this conn.
    ///
    /// ```
    /// use trillium_client::{Client, Method};
    /// use trillium_testing::client_config;
    ///
    /// let client = Client::from(client_config());
    /// let conn = client.get("http://localhost:9080");
    ///
    /// let method = conn.method(); //<-
    ///
    /// assert_eq!(method, Method::Get);
    /// ```
    #[field(get, set, copy)]
    pub(crate) method: Method,

    /// the request headers
    #[field(get, get_mut)]
    pub(crate) request_headers: Headers,

    /// the response headers
    #[field(get, get_mut)]
    pub(crate) response_headers: Headers,

    /// the status code for this conn.
    ///
    /// If the conn has not yet been sent, this will be None.
    ///
    /// ```
    /// use trillium_client::{Client, Status};
    /// use trillium_testing::{client_config, with_server};
    ///
    /// async fn handler(conn: trillium::Conn) -> trillium::Conn {
    ///     conn.with_status(418)
    /// }
    ///
    /// with_server(handler, |url| async move {
    ///     let client = Client::new(client_config());
    ///     let conn = client.get(url).await?;
    ///     assert_eq!(Status::ImATeapot, conn.status().unwrap());
    ///     Ok(())
    /// });
    /// ```
    #[field(get, copy)]
    pub(crate) status: Option<Status>,

    /// the request body
    ///
    /// ```
    /// env_logger::init();
    /// use trillium_client::Client;
    /// use trillium_testing::{client_config, with_server};
    ///
    /// let handler = |mut conn: trillium::Conn| async move {
    ///     let body = conn.request_body_string().await.unwrap();
    ///     conn.ok(format!("request body was: {}", body))
    /// };
    ///
    /// with_server(handler, |url| async move {
    ///     let client = Client::from(client_config());
    ///     let mut conn = client
    ///         .post(url)
    ///         .with_body("body") //<-
    ///         .await?;
    ///
    ///     assert_eq!(
    ///         conn.response_body().read_string().await?,
    ///         "request body was: body"
    ///     );
    ///     Ok(())
    /// });
    /// ```
    #[field(with = with_body, argument = body, set, into, take, option_set_some)]
    pub(crate) request_body: Option<Body>,

    /// the timeout for this conn
    ///
    /// this can also be set on the client with [`Client::set_timeout`](crate::Client::set_timeout)
    /// and [`Client::with_timeout`](crate::Client::with_timeout)
    #[field(with, set, get, get_mut, take, copy, option_set_some)]
    pub(crate) timeout: Option<Duration>,

    /// the http version for this conn
    ///
    /// prior to conn execution, this reflects the intended http version that will be sent, and
    /// after execution this reflects the server-indicated http version
    #[field(get, set, with, copy)]
    pub(crate) http_version: Version,

    /// the :authority pseudo-header, populated during h3 header finalization
    #[field(get)]
    pub(crate) authority: Option<Cow<'static, str>>,

    /// the :scheme pseudo-header, populated during h3 header finalization
    #[field(get)]
    pub(crate) scheme: Option<Cow<'static, str>>,

    /// the :path pseudo-header, populated during h3 header finalization
    #[field(get)]
    pub(crate) path: Option<Cow<'static, str>>,

    /// an explicit request target override, used only for `OPTIONS *` and `CONNECT host:port`
    ///
    /// When set and the method is OPTIONS or CONNECT, this value is used as the HTTP request
    /// target instead of deriving it from the url. For all other methods, this field is ignored.
    #[field(with, set, get, option_set_some, into)]
    pub(crate) request_target: Option<Cow<'static, str>>,

    /// trailers sent with the request body, populated after the body has been fully sent.
    ///
    /// Only present when the request body was constructed with [`Body::new_with_trailers`] and
    /// the body has been fully sent. For H3, this is populated after `send_h3_request`; for H1,
    /// after `send_body` with a chunked body.
    #[field(get)]
    pub(crate) request_trailers: Option<Headers>,

    /// trailers received with the response body, populated after the response body has been fully
    /// read.
    ///
    /// For H3, these are decoded from the trailing HEADERS frame. For H1, from chunked trailers
    /// (once H1 trailer receive is implemented).
    #[field(get)]
    pub(crate) response_trailers: Option<Headers>,
}
187
188/// default http user-agent header
189pub const USER_AGENT: &str = concat!("trillium-client/", env!("CARGO_PKG_VERSION"));
190
191impl Conn {
192 /// chainable setter for [`inserting`](Headers::insert) a request header
193 ///
194 /// ```
195 /// use trillium_client::Client;
196 /// use trillium_testing::{client_config, with_server};
197 ///
198 /// let handler = |conn: trillium::Conn| async move {
199 /// let header = conn
200 /// .request_headers()
201 /// .get_str("some-request-header")
202 /// .unwrap_or_default();
203 /// let response = format!("some-request-header was {}", header);
204 /// conn.ok(response)
205 /// };
206 ///
207 /// with_server(handler, |url| async move {
208 /// let client = Client::new(client_config());
209 /// let mut conn = client
210 /// .get(url)
211 /// .with_request_header("some-request-header", "header-value") // <--
212 /// .await?;
213 /// assert_eq!(
214 /// conn.response_body().read_string().await?,
215 /// "some-request-header was header-value"
216 /// );
217 /// Ok(())
218 /// })
219 /// ```
220 pub fn with_request_header(
221 mut self,
222 name: impl Into<HeaderName<'static>>,
223 value: impl Into<HeaderValues>,
224 ) -> Self {
225 self.request_headers.insert(name, value);
226 self
227 }
228
229 /// chainable setter for `extending` request headers
230 ///
231 /// ```
232 /// use trillium_client::Client;
233 /// use trillium_testing::{client_config, with_server};
234 ///
235 /// let handler = |conn: trillium::Conn| async move {
236 /// let header = conn
237 /// .request_headers()
238 /// .get_str("some-request-header")
239 /// .unwrap_or_default();
240 /// let response = format!("some-request-header was {}", header);
241 /// conn.ok(response)
242 /// };
243 ///
244 /// with_server(handler, move |url| async move {
245 /// let client = Client::new(client_config());
246 /// let mut conn = client
247 /// .get(url)
248 /// .with_request_headers([
249 /// ("some-request-header", "header-value"),
250 /// ("some-other-req-header", "other-header-value"),
251 /// ])
252 /// .await?;
253 ///
254 /// assert_eq!(
255 /// conn.response_body().read_string().await?,
256 /// "some-request-header was header-value"
257 /// );
258 /// Ok(())
259 /// })
260 /// ```
261 pub fn with_request_headers<HN, HV, I>(mut self, headers: I) -> Self
262 where
263 I: IntoIterator<Item = (HN, HV)> + Send,
264 HN: Into<HeaderName<'static>>,
265 HV: Into<HeaderValues>,
266 {
267 self.request_headers.extend(headers);
268 self
269 }
270
271 /// Chainable method to remove a request header if present
272 pub fn without_request_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
273 self.request_headers.remove(name);
274 self
275 }
276
277 /// chainable setter for json body. this requires the `serde_json` crate feature to be enabled.
278 #[cfg(feature = "serde_json")]
279 pub fn with_json_body(self, body: &impl serde::Serialize) -> serde_json::Result<Self> {
280 use trillium_http::KnownHeaderName;
281
282 Ok(self
283 .with_body(serde_json::to_string(body)?)
284 .with_request_header(KnownHeaderName::ContentType, "application/json"))
285 }
286
287 /// chainable setter for json body. this requires the `sonic-rs` crate feature to be enabled.
288 #[cfg(feature = "sonic-rs")]
289 pub fn with_json_body(self, body: &impl serde::Serialize) -> sonic_rs::Result<Self> {
290 use trillium_http::KnownHeaderName;
291
292 Ok(self
293 .with_body(sonic_rs::to_string(body)?)
294 .with_request_header(KnownHeaderName::ContentType, "application/json"))
295 }
296
297 pub(crate) fn response_encoding(&self) -> &'static Encoding {
298 encoding(&self.response_headers)
299 }
300
301 /// returns a [`ResponseBody`](crate::ResponseBody) that borrows the connection inside this
302 /// conn.
303 /// ```
304 /// use trillium_client::Client;
305 /// use trillium_testing::{client_config, with_server};
306 ///
307 /// let handler = |mut conn: trillium::Conn| async move { conn.ok("hello from trillium") };
308 ///
309 /// with_server(handler, |url| async move {
310 /// let client = Client::from(client_config());
311 /// let mut conn = client.get(url).await?;
312 ///
313 /// let response_body = conn.response_body(); //<-
314 ///
315 /// assert_eq!(19, response_body.content_length().unwrap());
316 /// let string = response_body.read_string().await?;
317 /// assert_eq!("hello from trillium", string);
318 /// Ok(())
319 /// });
320 /// ```
321 #[allow(clippy::needless_borrow, clippy::needless_borrows_for_generic_args)]
322 pub fn response_body(&mut self) -> ResponseBody<'_> {
323 ReceivedBody::new(
324 self.response_content_length(),
325 &mut self.buffer,
326 self.transport.as_mut().unwrap(),
327 &mut self.response_body_state,
328 None,
329 encoding(&self.response_headers),
330 )
331 .with_trailers(&mut self.response_trailers)
332 .with_h3_connection(self.h3_connection.clone())
333 .into()
334 }
335
336 /// Attempt to deserialize the response body. Note that this consumes the body content.
337 #[cfg(feature = "serde_json")]
338 pub async fn response_json<T>(&mut self) -> Result<T, ClientSerdeError>
339 where
340 T: serde::de::DeserializeOwned,
341 {
342 let body = self.response_body().read_string().await?;
343 Ok(serde_json::from_str(&body)?)
344 }
345
346 /// Attempt to deserialize the response body. Note that this consumes the body content.
347 #[cfg(feature = "sonic-rs")]
348 pub async fn response_json<T>(&mut self) -> Result<T, ClientSerdeError>
349 where
350 T: serde::de::DeserializeOwned,
351 {
352 let body = self.response_body().read_string().await?;
353 Ok(sonic_rs::from_str(&body)?)
354 }
355
356 /// Returns the conn or an [`UnexpectedStatusError`] that contains the conn
357 ///
358 /// ```
359 /// use trillium_client::{Client, Status};
360 /// use trillium_testing::{client_config, with_server};
361 ///
362 /// with_server(Status::NotFound, |url| async move {
363 /// let client = Client::new(client_config());
364 /// assert_eq!(
365 /// client.get(url).await?.success().unwrap_err().to_string(),
366 /// "expected a success (2xx) status code, but got 404 Not Found"
367 /// );
368 /// Ok(())
369 /// });
370 ///
371 /// with_server(Status::Ok, |url| async move {
372 /// let client = Client::new(client_config());
373 /// assert!(client.get(url).await?.success().is_ok());
374 /// Ok(())
375 /// });
376 /// ```
377 pub fn success(self) -> Result<Self, UnexpectedStatusError> {
378 match self.status() {
379 Some(status) if status.is_success() => Ok(self),
380 _ => Err(self.into()),
381 }
382 }
383
384 /// Returns this conn to the connection pool if it is keepalive, and
385 /// closes it otherwise. This will happen asynchronously as a spawned
386 /// task when the conn is dropped, but calling it explicitly allows
387 /// you to block on it and control where it happens.
388 pub async fn recycle(mut self) {
389 if self.is_keep_alive() && self.transport.is_some() && self.pool.is_some() {
390 self.finish_reading_body().await;
391 }
392 }
393
394 /// attempts to retrieve the connected peer address
395 pub fn peer_addr(&self) -> Option<SocketAddr> {
396 self.transport
397 .as_ref()
398 .and_then(|t| t.peer_addr().ok().flatten())
399 }
400
401 /// add state to the client conn and return self
402 pub fn with_state<T: Send + Sync + 'static>(mut self, state: T) -> Self {
403 self.insert_state(state);
404 self
405 }
406
407 /// add state to the client conn, returning any previously set state of this type
408 pub fn insert_state<T: Send + Sync + 'static>(&mut self, state: T) -> Option<T> {
409 self.state.insert(state)
410 }
411
412 /// borrow state
413 pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
414 self.state.get()
415 }
416
417 /// borrow state mutably
418 pub fn state_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
419 self.state.get_mut()
420 }
421
422 /// take state
423 pub fn take_state<T: Send + Sync + 'static>(&mut self) -> Option<T> {
424 self.state.take()
425 }
426}