1use crate::{Runtime, RuntimeTrait, ServerConnector, TestTransport, runtime};
2use async_channel::Sender;
3use std::{
4 any::{Any, type_name},
5 fmt::{self, Debug, Formatter},
6 future::{Future, IntoFuture},
7 net::IpAddr,
8 pin::Pin,
9 str,
10 sync::Arc,
11};
12use trillium::{Handler, Info, KnownHeaderName};
13use trillium_client::{Client, IntoUrl};
14use trillium_http::{HeaderName, HeaderValues, Headers, HttpContext, Method, Status};
15#[allow(clippy::test_attr_in_doctest, reason = "demonstrating test usage")]
16#[derive(Clone, Debug)]
57pub struct TestServer<H> {
58 client: Client,
59 peer_ip_sender: Sender<IpAddr>,
60 connector: ServerConnector<H>,
61}
62
63impl<H: Handler> TestServer<H> {
64 pub async fn new(handler: H) -> Self {
68 Self::new_with_runtime(handler, runtime()).await
69 }
70
71 async fn new_with_runtime(mut handler: H, rt: impl RuntimeTrait) -> Self {
72 let url = "http://trillium.test".into_url(None).unwrap();
73 let mut info = Info::from(HttpContext::default());
74 info.insert_shared_state(rt.clone());
75 info.insert_shared_state(Runtime::new(rt.clone()));
76 info.insert_shared_state(url.clone());
77 handler.init(&mut info).await;
78 let context: Arc<HttpContext> = Arc::new(info.into());
79 let mut connector = ServerConnector::new(handler)
80 .with_context(context.clone())
81 .with_runtime(rt);
82 let (peer_ip_sender, receive) = async_channel::unbounded();
83 connector.server_peer_ips_receiver = Some(receive);
84 let client = Client::new(connector.clone())
85 .without_keepalive()
86 .with_base(url);
87
88 Self {
89 client,
90 peer_ip_sender,
91 connector,
92 }
93 }
94
95 pub fn new_blocking(handler: H) -> Self {
97 let rt = runtime();
102 rt.clone().block_on(Self::new_with_runtime(handler, rt))
103 }
104
105 pub fn build<M: TryInto<Method>>(&self, method: M, path: &str) -> ConnTest
107 where
108 M::Error: Debug,
109 {
110 ConnTest {
111 inner: self.client.build_conn(method, path),
112 body: None,
113 peer_ip_sender: self.peer_ip_sender.clone(),
114 peer_ip: None,
115 runtime: self.connector.runtime().clone(),
116 }
117 }
118
119 pub fn shared_state<T: Send + Sync + 'static>(&self) -> Option<&T> {
121 self.connector.context().shared_state().get()
122 }
123
124 #[track_caller]
126 pub fn assert_shared_state<T>(&self, expected: T) -> &Self
127 where
128 T: Send + Sync + Debug + PartialEq + 'static,
129 {
130 match self.shared_state::<T>() {
131 Some(actual) => assert_eq!(*actual, expected),
132 None => panic!(
133 "expected handler state of type {}, but none was found",
134 type_name::<T>()
135 ),
136 };
137 self
138 }
139
140 pub fn assert_shared_state_with<T, F>(&self, f: F) -> &Self
142 where
143 T: Send + Sync + 'static,
144 F: FnOnce(&T),
145 {
146 match self.shared_state::<T>() {
147 Some(state) => f(state),
148 None => panic!(
149 "expected handler state of type {}, but none was found",
150 type_name::<T>()
151 ),
152 };
153 self
154 }
155
156 pub fn handler(&self) -> &H {
158 self.connector.handler()
159 }
160
161 pub fn with_host(mut self, host: &str) -> Self {
164 self.set_host(host);
165 self
166 }
167
168 pub fn set_host(&mut self, host: &str) -> &mut Self {
171 let _ = self.client.base_mut().unwrap().set_host(Some(host));
172 self
173 }
174
175 pub fn with_base(mut self, base: impl IntoUrl) -> Self {
178 self.set_base(base);
179 self
180 }
181
182 pub fn set_base(&mut self, base: impl IntoUrl) -> &mut Self {
185 self.client
186 .set_base(base)
187 .expect("unable to build a base url");
188 self
189 }
190
191 pub fn get(&self, path: &str) -> ConnTest {
193 self.build(Method::Get, path)
194 }
195
196 pub fn post(&self, path: &str) -> ConnTest {
198 self.build(Method::Post, path)
199 }
200
201 pub fn put(&self, path: &str) -> ConnTest {
203 self.build(Method::Put, path)
204 }
205
206 pub fn delete(&self, path: &str) -> ConnTest {
208 self.build(Method::Delete, path)
209 }
210
211 pub fn patch(&self, path: &str) -> ConnTest {
213 self.build(Method::Patch, path)
214 }
215}
216
217pub struct ConnTest {
224 inner: trillium_client::Conn,
225 body: Option<Vec<u8>>,
226 peer_ip_sender: Sender<IpAddr>,
227 peer_ip: Option<IpAddr>,
228 runtime: Runtime,
229}
230
231impl Debug for ConnTest {
232 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
233 f.debug_struct("ConnTest")
234 .field("inner", &self.inner)
235 .field("body", &self.body.as_deref().map(String::from_utf8_lossy))
236 .finish()
237 }
238}
239
240impl ConnTest {
242 pub fn with_request_header(
244 mut self,
245 name: impl Into<HeaderName<'static>>,
246 value: impl Into<HeaderValues>,
247 ) -> Self {
248 self.inner.request_headers_mut().insert(name, value);
249 self
250 }
251
252 pub fn with_request_headers<HN, HV, I>(mut self, headers: I) -> Self
254 where
255 I: IntoIterator<Item = (HN, HV)> + Send,
256 HN: Into<HeaderName<'static>>,
257 HV: Into<HeaderValues>,
258 {
259 self.inner.request_headers_mut().extend(headers);
260 self
261 }
262
263 pub fn without_request_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
265 self.inner.request_headers_mut().remove(name);
266 self
267 }
268
269 pub fn with_body(mut self, body: impl Into<trillium_http::Body>) -> Self {
271 self.inner.set_request_body(body);
272 self
273 }
274
275 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
278 pub fn with_json_body(mut self, body: &impl serde::Serialize) -> Self {
279 self.inner
280 .request_headers_mut()
281 .try_insert(KnownHeaderName::ContentType, "application/json");
282
283 self.with_body(crate::to_json_string(body).unwrap())
284 }
285
286 pub fn with_peer_ip(mut self, peer_ip: impl Into<IpAddr>) -> Self {
288 self.peer_ip = Some(peer_ip.into());
289 self
290 }
291
292 pub fn block(self) -> Self {
294 self.runtime.clone().block_on(self.into_future())
295 }
296}
297
298impl ConnTest {
300 pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
302 self.inner.state::<T>()
303 }
304
305 #[track_caller]
307 pub fn assert_state<T>(&self, expected: T) -> &Self
308 where
309 T: PartialEq + Debug + Send + Sync + 'static,
310 {
311 match self.state::<T>() {
312 Some(actual) => assert_eq!(*actual, expected),
313 None => panic!(
314 "expected handler state of type {}, but none was found",
315 type_name::<T>()
316 ),
317 }
318 self
319 }
320
321 #[track_caller]
323 pub fn assert_no_state<T>(&self) -> &Self
324 where
325 T: Debug + Send + Sync + 'static,
326 {
327 if let Some(value) = self.state::<T>() {
328 panic!(
329 "expected no handler state of type {}, but found {:?}",
330 type_name::<T>(),
331 value
332 );
333 }
334 self
335 }
336
337 pub fn status(&self) -> Status {
341 self.inner
342 .status()
343 .expect("response not yet received — did you .await this ConnTest?")
344 }
345
346 pub fn body(&self) -> &str {
350 str::from_utf8(self.body_bytes()).expect("body was not utf-8")
351 }
352
353 pub fn body_bytes(&self) -> &[u8] {
357 self.body.as_deref().expect("body was not set")
358 }
359
360 pub fn response_headers(&self) -> &Headers {
362 self.inner.response_headers()
363 }
364
365 pub fn response_trailers(&self) -> Option<&Headers> {
367 self.inner.response_trailers()
368 }
369
370 pub fn request_trailers(&self) -> Option<&Headers> {
372 self.inner.request_trailers()
373 }
374
375 pub fn header<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
377 self.inner.response_headers().get_str(name)
378 }
379
380 pub fn trailer<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
382 self.inner
383 .response_trailers()
384 .and_then(|trailers| trailers.get_str(name))
385 }
386
387 #[track_caller]
389 pub fn assert_status(&self, status: impl TryInto<Status>) -> &Self {
390 let expected: Status = status
391 .try_into()
392 .ok()
393 .expect("expected a valid status code");
394 let actual = self.status();
395 assert_eq!(actual, expected, "expected status {expected}, got {actual}");
396 self
397 }
398
399 #[track_caller]
401 pub fn assert_ok(&self) -> &Self {
402 self.assert_status(200)
403 }
404
405 #[track_caller]
408 pub fn assert_body(&self, expected: &str) -> &Self {
409 assert_eq!(self.body().trim_end(), expected.trim_end());
410 self
411 }
412
413 #[track_caller]
415 pub fn assert_body_contains(&self, pattern: &str) -> &Self {
416 let body = self.body();
417 assert!(
418 body.contains(pattern),
419 "expected body to contain {pattern:?}, but body was:\n{body}"
420 );
421 self
422 }
423
424 #[track_caller]
426 pub fn assert_header<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
427 where
428 HeaderValues: PartialEq<HV>,
429 HV: Debug,
430 HN: Into<HeaderName<'a>>,
431 {
432 let name = name.into();
433
434 match self.inner.response_headers().get_values(name.clone()) {
435 Some(actual) => assert_eq!(*actual, expected, "for header {name:?}"),
436 None => panic!("header {name} not set"),
437 };
438
439 self
440 }
441
442 #[track_caller]
444 pub fn assert_headers<'a, I, HN, HV>(&self, headers: I) -> &Self
445 where
446 I: IntoIterator<Item = (HN, HV)> + Send,
447 HN: Into<HeaderName<'a>>,
448 HV: Debug,
449 HeaderValues: PartialEq<HV>,
450 {
451 for (name, expected) in headers {
452 self.assert_header(name, expected);
453 }
454
455 self
456 }
457
458 #[track_caller]
460 pub fn assert_no_header(&self, name: &str) -> &Self {
461 let actual = self.header(name);
462 assert!(
463 actual.is_none(),
464 "expected no header {name:?}, but found {actual:?}"
465 );
466 self
467 }
468
469 #[track_caller]
472 pub fn assert_header_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
473 where
474 F: FnOnce(&HeaderValues),
475 {
476 let name = name.into();
477 match self.response_headers().get_values(name.clone()) {
478 Some(values) => f(values),
479 None => panic!("expected header {name:?}, but it was not found"),
480 }
481
482 self
483 }
484
485 #[track_caller]
487 pub fn assert_state_with<T, F>(&self, f: F) -> &Self
488 where
489 T: Send + Sync + 'static,
490 F: FnOnce(&T),
491 {
492 match self.state::<T>() {
493 Some(state) => f(state),
494 None => panic!(
495 "expected handler state of type {}, but none was found",
496 type_name::<T>()
497 ),
498 };
499 self
500 }
501
502 #[track_caller]
504 pub fn assert_body_with<F>(&self, f: F) -> &Self
505 where
506 F: FnOnce(&str),
507 {
508 f(self.body());
509 self
510 }
511
512 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
514 #[track_caller]
515 pub fn assert_json_body_with<T, F>(&self, f: F) -> &Self
516 where
517 T: serde::de::DeserializeOwned,
518 F: FnOnce(&T),
519 {
520 let parsed: T =
521 crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
522 f(&parsed);
523 self
524 }
525
526 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
528 #[track_caller]
529 pub fn assert_json_body<T>(&self, body: &T) -> &Self
530 where
531 T: serde::de::DeserializeOwned + PartialEq + Debug,
532 {
533 let parsed: T =
534 crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
535 assert_eq!(&parsed, body);
536 self
537 }
538
539 #[track_caller]
541 pub fn assert_trailer<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
542 where
543 HeaderValues: PartialEq<HV>,
544 HV: Debug,
545 HN: Into<HeaderName<'a>>,
546 {
547 let name = name.into();
548
549 match self
550 .inner
551 .response_trailers()
552 .and_then(|trailers| trailers.get_values(name.clone()))
553 {
554 Some(actual) => assert_eq!(*actual, expected, "for trailer {name:?}"),
555 None => panic!("trailer {name} not set"),
556 };
557
558 self
559 }
560
561 #[track_caller]
563 pub fn assert_trailers<'a, I, HN, HV>(&self, trailers: I) -> &Self
564 where
565 I: IntoIterator<Item = (HN, HV)> + Send,
566 HN: Into<HeaderName<'a>>,
567 HV: Debug,
568 HeaderValues: PartialEq<HV>,
569 {
570 for (name, expected) in trailers {
571 self.assert_trailer(name, expected);
572 }
573
574 self
575 }
576
577 #[track_caller]
579 pub fn assert_no_trailer(&self, name: &str) -> &Self {
580 let actual = self.trailer(name);
581 assert!(
582 actual.is_none(),
583 "expected no trailer {name:?}, but found {actual:?}"
584 );
585 self
586 }
587
588 #[track_caller]
591 pub fn assert_trailer_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
592 where
593 F: FnOnce(&HeaderValues),
594 {
595 let name = name.into();
596 match self
597 .response_trailers()
598 .and_then(|trailers| trailers.get_values(name.clone()))
599 {
600 Some(values) => f(values),
601 None => panic!("expected trailer {name:?}, but it was not found"),
602 }
603
604 self
605 }
606}
607
608impl IntoFuture for ConnTest {
609 type IntoFuture = Pin<Box<dyn Future<Output = ConnTest> + Send + 'static>>;
610 type Output = ConnTest;
611
612 fn into_future(mut self) -> Self::IntoFuture {
613 Box::pin(async move {
614 if let Some(peer_ip) = self.peer_ip.take() {
615 let _ = self.peer_ip_sender.send(peer_ip).await;
616 }
617
618 if let Some(host) = self
619 .inner
620 .request_headers()
621 .get_str(KnownHeaderName::Host)
622 .map(ToString::to_string)
623 {
624 let _ = self.inner.url_mut().set_host(Some(&host));
625 }
626
627 let inner = &mut self.inner;
628
629 inner.await.expect("request to virtual server failed");
630
631 let inner = &mut self.inner;
632
633 if let Some(transport) = inner.transport_mut() {
634 let state = std::mem::take(
635 &mut *((transport as &dyn Any)
636 .downcast_ref::<TestTransport>()
637 .unwrap()
638 .state()
639 .write()
640 .unwrap()),
641 );
642
643 *inner.as_mut() = state;
644 }
645
646 self.body = Some(
647 self.inner
648 .response_body()
649 .read_bytes()
650 .await
651 .expect("failed to read response body"),
652 );
653
654 self
655 })
656 }
657}