1use crate::{
2 ClientHandler, Conn, IntoUrl, Pool, USER_AGENT, client_handler::ArcedClientHandler,
3 conn::H2Pooled, h3::H3ClientState,
4};
5use std::{any::Any, fmt::Debug, sync::Arc, time::Duration};
6use trillium_http::{
7 HeaderName, HeaderValues, Headers, HttpContext, KnownHeaderName, Method, ProtocolSession,
8 ReceivedBodyState, TypeSet, Version::Http1_1,
9};
10use trillium_server_common::{
11 ArcedConnector, ArcedQuicClientConfig, Connector, QuicClientConfig, Transport,
12 url::{Origin, Url},
13};
14
/// Default value of `Client::h2_idle_timeout` used by `Client::new` and
/// `Client::new_with_quic`: five minutes (300 seconds).
const DEFAULT_H2_IDLE_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Default value of `Client::h2_idle_ping_threshold` used by `Client::new` and
/// `Client::new_with_quic`: ten seconds.
const DEFAULT_H2_IDLE_PING_THRESHOLD: Duration = Duration::from_millis(10_000);

/// Default value of `Client::h2_idle_ping_timeout` used by `Client::new` and
/// `Client::new_with_quic`: twenty seconds.
const DEFAULT_H2_IDLE_PING_TIMEOUT: Duration = Duration::from_millis(20_000);
27
/// An HTTP client.
///
/// A `Client` holds a connector, optional connection pools, default request headers and
/// other per-request defaults. It builds [`Conn`]s for individual requests through
/// [`Client::build_conn`] and the `get`/`post`/`put`/`delete`/`patch` shorthands. Every
/// `Conn` it builds holds its own clone of the client.
#[derive(Clone, Debug, fieldwork::Fieldwork)]
pub struct Client {
    /// The connector this client was constructed with, returned by [`Client::connector`].
    config: ArcedConnector,

    /// HTTP/3 client state. Within this file it is only set to `Some` by
    /// [`Client::new_with_quic`]; [`Client::new`] leaves it `None`.
    #[field(vis = "pub(crate)", get)]
    h3: Option<H3ClientState>,

    /// Pool of transports keyed by [`Origin`]. `None` once [`Client::without_keepalive`]
    /// has been called.
    // NOTE(review): presumably the HTTP/1.x keep-alive pool, since HTTP/2 has its own
    // `h2_pool` below -- confirm.
    #[field(vis = "pub(crate)", get)]
    pool: Option<Pool<Origin, Box<dyn Transport>>>,

    /// Pool of HTTP/2 connections keyed by [`Origin`]. `None` once
    /// [`Client::without_keepalive`] has been called.
    #[field(vis = "pub(crate)", get)]
    h2_pool: Option<Pool<Origin, H2Pooled>>,

    /// HTTP/2 idle timeout. Defaults to 300 seconds.
    // NOTE(review): presumably how long an idle pooled HTTP/2 connection is retained; the
    // consumer of this value is outside this file -- confirm.
    #[field(get, set, with, without, copy)]
    h2_idle_timeout: Option<Duration>,

    /// HTTP/2 idle ping threshold. Defaults to 10 seconds.
    // NOTE(review): presumably the idle period after which a liveness ping is sent; the
    // consumer of this value is outside this file -- confirm.
    #[field(get, set, with, copy, without)]
    h2_idle_ping_threshold: Option<Duration>,

    /// HTTP/2 idle ping timeout. Defaults to 20 seconds.
    // NOTE(review): presumably how long to wait for a ping acknowledgement -- confirm.
    #[field(get, set, with, copy)]
    h2_idle_ping_timeout: Duration,

    /// Optional base url. When it is set via [`Client::set_base`] or [`Client::with_base`],
    /// a trailing `/` is guaranteed.
    #[field(get)]
    base: Option<Arc<Url>>,

    /// Headers copied into every [`Conn`] this client builds. By default these are
    /// `User-Agent` and `Accept: */*`.
    #[field(get)]
    default_headers: Arc<Headers>,

    /// Timeout copied into every [`Conn`] this client builds. `None` by default.
    #[field(get, set, with, copy, without, option_set_some)]
    timeout: Option<Duration>,

    /// Shared [`HttpContext`], cloned into every [`Conn`] this client builds.
    #[field(get, get_mut, set, with, into)]
    context: Arc<HttpContext>,

    /// The [`ClientHandler`] for this client. Defaults to `()`; replace it with
    /// [`Client::with_handler`] or [`Client::set_handler`].
    #[field(vis = "pub(crate)", get)]
    handler: ArcedClientHandler,
}
88
// Generates an inherent `Client` method named `$fn_name` that builds a `Conn` for
// `Method::$method` by delegating to `Client::build_conn`.
//
// The two-argument form synthesizes a doc comment for the generated method. That doc
// includes a doctest asserting the resulting conn's method and url. It then forwards to
// the three-argument form, which emits the method itself with the provided doc comment.
macro_rules! method {
    ($fn_name:ident, $method:ident) => {
        // The string pieces below become rustdoc text and a doctest, so their exact
        // content and line breaks matter.
        method!(
            $fn_name,
            $method,
            concat!(
                "Builds a new client conn with the ",
                stringify!($fn_name),
                " http method and the provided url.

```
use trillium_client::{Client, Method};
use trillium_testing::client_config;

let client = Client::new(client_config());
let conn = client.",
                stringify!($fn_name),
                "(\"http://localhost:8080/some/route\"); //<-

assert_eq!(conn.method(), Method::",
                stringify!($method),
                ");
assert_eq!(conn.url().to_string(), \"http://localhost:8080/some/route\");
```
"
            )
        );
    };

    ($fn_name:ident, $method:ident, $doc_comment:expr_2021) => {
        // Emits `pub fn $fn_name(&self, url) -> Conn`, documented with `$doc_comment`.
        #[doc = $doc_comment]
        pub fn $fn_name(&self, url: impl IntoUrl) -> Conn {
            self.build_conn(Method::$method, url)
        }
    };
}
126
127pub(crate) fn default_request_headers() -> Headers {
128 Headers::new()
129 .with_inserted_header(KnownHeaderName::UserAgent, USER_AGENT)
130 .with_inserted_header(KnownHeaderName::Accept, "*/*")
131}
132
133impl Client {
134 method!(get, Get);
135
136 method!(post, Post);
137
138 method!(put, Put);
139
140 method!(delete, Delete);
141
142 method!(patch, Patch);
143
144 pub fn new(connector: impl Connector) -> Self {
146 Self {
147 config: ArcedConnector::new(connector),
148 h3: None,
149 pool: Some(Pool::default()),
150 h2_pool: Some(Pool::default()),
151 h2_idle_timeout: Some(DEFAULT_H2_IDLE_TIMEOUT),
152 h2_idle_ping_threshold: Some(DEFAULT_H2_IDLE_PING_THRESHOLD),
153 h2_idle_ping_timeout: DEFAULT_H2_IDLE_PING_TIMEOUT,
154 base: None,
155 default_headers: Arc::new(default_request_headers()),
156 timeout: None,
157 context: Default::default(),
158 handler: ArcedClientHandler::new(()),
159 }
160 }
161
162 pub fn new_with_quic<C: Connector, Q: QuicClientConfig<C>>(connector: C, quic: Q) -> Self {
172 let arced_quic = ArcedQuicClientConfig::new(&connector, quic);
174
175 #[cfg_attr(not(feature = "webtransport"), allow(unused_mut))]
176 let mut context = HttpContext::default();
177 #[cfg(feature = "webtransport")]
178 {
179 context
184 .config_mut()
185 .set_h3_datagrams_enabled(true)
186 .set_webtransport_enabled(true)
187 .set_extended_connect_enabled(true);
188 }
189
190 Self {
191 config: ArcedConnector::new(connector),
192 h3: Some(H3ClientState::new(arced_quic)),
193 pool: Some(Pool::default()),
194 h2_pool: Some(Pool::default()),
195 h2_idle_timeout: Some(DEFAULT_H2_IDLE_TIMEOUT),
196 h2_idle_ping_threshold: Some(DEFAULT_H2_IDLE_PING_THRESHOLD),
197 h2_idle_ping_timeout: DEFAULT_H2_IDLE_PING_TIMEOUT,
198 base: None,
199 default_headers: Arc::new(default_request_headers()),
200 timeout: None,
201 context: Arc::new(context),
202 handler: ArcedClientHandler::new(()),
203 }
204 }
205
206 #[must_use]
215 pub fn with_handler<H: ClientHandler>(mut self, handler: H) -> Self {
216 self.set_handler(handler);
217 self
218 }
219
220 pub fn set_handler<H: ClientHandler>(&mut self, handler: H) -> &mut Self {
223 self.handler = ArcedClientHandler::new(handler);
224 self
225 }
226
227 pub fn downcast_handler<T: Any + 'static>(&self) -> Option<&T> {
233 self.handler.downcast_ref()
234 }
235
236 pub fn without_default_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
238 self.default_headers_mut().remove(name);
239 self
240 }
241
242 pub fn with_default_header(
244 mut self,
245 name: impl Into<HeaderName<'static>>,
246 value: impl Into<HeaderValues>,
247 ) -> Self {
248 self.default_headers_mut().insert(name, value);
249 self
250 }
251
252 pub fn default_headers_mut(&mut self) -> &mut Headers {
256 Arc::make_mut(&mut self.default_headers)
257 }
258
259 pub fn without_keepalive(mut self) -> Self {
268 self.pool = None;
269 self.h2_pool = None;
270 self
271 }
272
273 pub fn build_conn<M>(&self, method: M, url: impl IntoUrl) -> Conn
289 where
290 M: TryInto<Method>,
291 <M as TryInto<Method>>::Error: Debug,
292 {
293 let method = method.try_into().unwrap();
294 let (url, request_target) = if let Some(base) = &self.base
295 && let Some(request_target) = url.request_target(method)
296 {
297 ((**base).clone(), Some(request_target))
298 } else {
299 (self.build_url(url).unwrap(), None)
300 };
301
302 Conn {
303 url,
304 method,
305 request_headers: Headers::clone(&self.default_headers),
306 response_headers: Headers::new(),
307 transport: None,
308 status: None,
309 request_body: None,
310 protocol_session: ProtocolSession::Http1,
311 #[cfg(feature = "webtransport")]
312 wt_pool_entry: None,
313 buffer: Vec::with_capacity(128).into(),
314 response_body_state: ReceivedBodyState::Start,
315 headers_finalized: false,
316 halted: false,
317 error: None,
318 body_override: None,
319 timeout: self.timeout,
320 http_version: Http1_1,
321 max_head_length: 8 * 1024,
322 state: TypeSet::new(),
323 context: self.context.clone(),
324 authority: None,
325 scheme: None,
326 path: None,
327 request_target,
328 protocol: None,
329 request_trailers: None,
330 response_trailers: None,
331 client: self.clone(),
332 followup: None,
333 }
334 }
335
336 pub fn connector(&self) -> &ArcedConnector {
338 &self.config
339 }
340
341 pub fn clean_up_pool(&self) {
345 if let Some(pool) = &self.pool {
346 pool.cleanup();
347 }
348 if let Some(h2_pool) = &self.h2_pool {
349 h2_pool.cleanup();
350 }
351 }
352
353 pub fn with_base(mut self, base: impl IntoUrl) -> Self {
355 self.set_base(base).unwrap();
356 self
357 }
358
359 pub fn build_url(&self, url: impl IntoUrl) -> crate::Result<Url> {
361 url.into_url(self.base())
362 }
363
364 pub fn set_base(&mut self, base: impl IntoUrl) -> crate::Result<()> {
366 let mut base = base.into_url(None)?;
367
368 if !base.path().ends_with('/') {
369 log::warn!("appending a trailing / to {base}");
370 base.set_path(&format!("{}/", base.path()));
371 }
372
373 self.base = Some(Arc::new(base));
374 Ok(())
375 }
376
377 pub fn base_mut(&mut self) -> Option<&mut Url> {
382 let base = self.base.as_mut()?;
383 Some(Arc::make_mut(base))
384 }
385}
386
387impl<T: Connector> From<T> for Client {
388 fn from(connector: T) -> Self {
389 Self::new(connector)
390 }
391}