1use crate::{Runtime, RuntimeTrait, ServerConnector, TestTransport, runtime};
2use async_channel::Sender;
3use std::{
4 any::{Any, type_name},
5 fmt::{self, Debug, Formatter},
6 future::{Future, IntoFuture},
7 net::IpAddr,
8 pin::Pin,
9 str,
10 sync::Arc,
11};
12use trillium::{Handler, Info, KnownHeaderName};
13use trillium_client::{Client, IntoUrl};
14use trillium_http::{HeaderName, HeaderValues, Headers, HttpContext, Method, Status};
#[allow(clippy::test_attr_in_doctest, reason = "demonstrating test usage")]
#[derive(Clone, Debug)]
/// A test harness that drives a trillium [`Handler`] through a `Client` whose
/// connector is a `ServerConnector` wrapping that handler.
///
/// Construct one with [`TestServer::new`] (async) or [`TestServer::new_blocking`],
/// then build requests with [`TestServer::build`] or the verb shorthands
/// ([`TestServer::get`], [`TestServer::post`], [`TestServer::put`],
/// [`TestServer::delete`], [`TestServer::patch`]). Each request is a [`ConnTest`].
pub struct TestServer<H> {
    // Client built from a clone of `connector`, with the base url
    // (initially `http://trillium.test`); every `ConnTest` is built from it.
    client: Client,
    // Sending half of an unbounded channel whose receiving half is installed on
    // the connector as `server_peer_ips_receiver`; cloned into each `ConnTest`.
    peer_ip_sender: Sender<IpAddr>,
    // Connector holding the initialized handler, the server context and the
    // runtime.
    connector: ServerConnector<H>,
}
62
63impl<H: Handler> TestServer<H> {
64 pub async fn new(handler: H) -> Self {
68 Self::new_with_runtime(handler, runtime()).await
69 }
70
71 async fn new_with_runtime(mut handler: H, rt: impl RuntimeTrait) -> Self {
72 let url = "http://trillium.test".into_url(None).unwrap();
73 let mut info = Info::from(HttpContext::default());
74 info.insert_shared_state(rt.clone());
75 info.insert_shared_state(Runtime::new(rt.clone()));
76 info.insert_shared_state(url.clone());
77 handler.init(&mut info).await;
78 let context: Arc<HttpContext> = Arc::new(info.into());
79 let mut connector = ServerConnector::new(handler)
80 .with_context(context.clone())
81 .with_runtime(rt);
82 let (peer_ip_sender, receive) = async_channel::unbounded();
83 connector.server_peer_ips_receiver = Some(receive);
84 let client = Client::new(connector.clone()).with_base(url);
85
86 Self {
87 client,
88 peer_ip_sender,
89 connector,
90 }
91 }
92
93 pub fn new_blocking(handler: H) -> Self {
95 let rt = runtime();
100 rt.clone().block_on(Self::new_with_runtime(handler, rt))
101 }
102
103 pub fn build<M: TryInto<Method>>(&self, method: M, path: &str) -> ConnTest
105 where
106 M::Error: Debug,
107 {
108 ConnTest {
109 inner: self.client.build_conn(method, path),
110 body: None,
111 peer_ip_sender: self.peer_ip_sender.clone(),
112 peer_ip: None,
113 runtime: self.connector.runtime().clone(),
114 }
115 }
116
117 pub fn shared_state<T: Send + Sync + 'static>(&self) -> Option<&T> {
119 self.connector.context().shared_state().get()
120 }
121
122 #[track_caller]
124 pub fn assert_shared_state<T>(&self, expected: T) -> &Self
125 where
126 T: Send + Sync + Debug + PartialEq + 'static,
127 {
128 match self.shared_state::<T>() {
129 Some(actual) => assert_eq!(*actual, expected),
130 None => panic!(
131 "expected handler state of type {}, but none was found",
132 type_name::<T>()
133 ),
134 };
135 self
136 }
137
138 pub fn assert_shared_state_with<T, F>(&self, f: F) -> &Self
140 where
141 T: Send + Sync + 'static,
142 F: FnOnce(&T),
143 {
144 match self.shared_state::<T>() {
145 Some(state) => f(state),
146 None => panic!(
147 "expected handler state of type {}, but none was found",
148 type_name::<T>()
149 ),
150 };
151 self
152 }
153
154 pub fn handler(&self) -> &H {
156 self.connector.handler()
157 }
158
159 pub fn with_host(mut self, host: &str) -> Self {
162 self.set_host(host);
163 self
164 }
165
166 pub fn set_host(&mut self, host: &str) -> &mut Self {
169 let _ = self.client.base_mut().unwrap().set_host(Some(host));
170 self
171 }
172
173 pub fn with_base(mut self, base: impl IntoUrl) -> Self {
176 self.set_base(base);
177 self
178 }
179
180 pub fn set_base(&mut self, base: impl IntoUrl) -> &mut Self {
183 self.client
184 .set_base(base)
185 .expect("unable to build a base url");
186 self
187 }
188
189 pub fn get(&self, path: &str) -> ConnTest {
191 self.build(Method::Get, path)
192 }
193
194 pub fn post(&self, path: &str) -> ConnTest {
196 self.build(Method::Post, path)
197 }
198
199 pub fn put(&self, path: &str) -> ConnTest {
201 self.build(Method::Put, path)
202 }
203
204 pub fn delete(&self, path: &str) -> ConnTest {
206 self.build(Method::Delete, path)
207 }
208
209 pub fn patch(&self, path: &str) -> ConnTest {
211 self.build(Method::Patch, path)
212 }
213}
214
/// A single request against a [`TestServer`], together with its response once it
/// has been awaited.
///
/// Build one with [`TestServer::build`] (or a verb shorthand such as
/// [`TestServer::get`]) and configure it with the `with_*` methods. Then `.await`
/// it (or call [`ConnTest::block`]) before using the response accessors and
/// `assert_*` methods.
pub struct ConnTest {
    /// Client conn carrying the request and, after awaiting, the response.
    inner: trillium_client::Conn,
    /// Response body bytes; `None` until this conn has been awaited.
    body: Option<Vec<u8>>,
    /// Clone of the server's peer-ip sender, used to deliver `peer_ip`.
    peer_ip_sender: Sender<IpAddr>,
    /// Simulated peer ip to send to the server before the request, if one was set.
    peer_ip: Option<IpAddr>,
    /// Runtime used by [`ConnTest::block`] to drive the request synchronously.
    runtime: Runtime,
}
228
229impl Debug for ConnTest {
230 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
231 f.debug_struct("ConnTest")
232 .field("inner", &self.inner)
233 .field("body", &self.body.as_deref().map(String::from_utf8_lossy))
234 .finish()
235 }
236}
237
238impl ConnTest {
240 pub fn with_request_header(
242 mut self,
243 name: impl Into<HeaderName<'static>>,
244 value: impl Into<HeaderValues>,
245 ) -> Self {
246 self.inner.request_headers_mut().insert(name, value);
247 self
248 }
249
250 pub fn with_request_headers<HN, HV, I>(mut self, headers: I) -> Self
252 where
253 I: IntoIterator<Item = (HN, HV)> + Send,
254 HN: Into<HeaderName<'static>>,
255 HV: Into<HeaderValues>,
256 {
257 self.inner.request_headers_mut().extend(headers);
258 self
259 }
260
261 pub fn without_request_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
263 self.inner.request_headers_mut().remove(name);
264 self
265 }
266
267 pub fn with_body(mut self, body: impl Into<trillium_http::Body>) -> Self {
269 self.inner.set_request_body(body);
270 self
271 }
272
273 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
276 pub fn with_json_body(mut self, body: &impl serde::Serialize) -> Self {
277 self.inner
278 .request_headers_mut()
279 .try_insert(KnownHeaderName::ContentType, "application/json");
280
281 self.with_body(crate::to_json_string(body).unwrap())
282 }
283
284 pub fn with_peer_ip(mut self, peer_ip: impl Into<IpAddr>) -> Self {
286 self.peer_ip = Some(peer_ip.into());
287 self
288 }
289
290 pub fn block(self) -> Self {
292 self.runtime.clone().block_on(self.into_future())
293 }
294}
295
296impl ConnTest {
298 pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
300 self.inner.state::<T>()
301 }
302
303 #[track_caller]
305 pub fn assert_state<T>(&self, expected: T) -> &Self
306 where
307 T: PartialEq + Debug + Send + Sync + 'static,
308 {
309 match self.state::<T>() {
310 Some(actual) => assert_eq!(*actual, expected),
311 None => panic!(
312 "expected handler state of type {}, but none was found",
313 type_name::<T>()
314 ),
315 }
316 self
317 }
318
319 #[track_caller]
321 pub fn assert_no_state<T>(&self) -> &Self
322 where
323 T: Debug + Send + Sync + 'static,
324 {
325 if let Some(value) = self.state::<T>() {
326 panic!(
327 "expected no handler state of type {}, but found {:?}",
328 type_name::<T>(),
329 value
330 );
331 }
332 self
333 }
334
335 pub fn status(&self) -> Status {
339 self.inner
340 .status()
341 .expect("response not yet received — did you .await this ConnTest?")
342 }
343
344 pub fn body(&self) -> &str {
348 str::from_utf8(self.body_bytes()).expect("body was not utf-8")
349 }
350
351 pub fn body_bytes(&self) -> &[u8] {
355 self.body.as_deref().expect("body was not set")
356 }
357
358 pub fn response_headers(&self) -> &Headers {
360 self.inner.response_headers()
361 }
362
363 pub fn response_trailers(&self) -> Option<&Headers> {
365 self.inner.response_trailers()
366 }
367
368 pub fn request_trailers(&self) -> Option<&Headers> {
370 self.inner.request_trailers()
371 }
372
373 pub fn header<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
375 self.inner.response_headers().get_str(name)
376 }
377
378 pub fn trailer<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
380 self.inner
381 .response_trailers()
382 .and_then(|trailers| trailers.get_str(name))
383 }
384
385 #[track_caller]
387 pub fn assert_status(&self, status: impl TryInto<Status>) -> &Self {
388 let expected: Status = status
389 .try_into()
390 .ok()
391 .expect("expected a valid status code");
392 let actual = self.status();
393 assert_eq!(actual, expected, "expected status {expected}, got {actual}");
394 self
395 }
396
397 #[track_caller]
399 pub fn assert_ok(&self) -> &Self {
400 self.assert_status(200)
401 }
402
403 #[track_caller]
406 pub fn assert_body(&self, expected: &str) -> &Self {
407 assert_eq!(self.body().trim_end(), expected.trim_end());
408 self
409 }
410
411 #[track_caller]
413 pub fn assert_body_contains(&self, pattern: &str) -> &Self {
414 let body = self.body();
415 assert!(
416 body.contains(pattern),
417 "expected body to contain {pattern:?}, but body was:\n{body}"
418 );
419 self
420 }
421
422 #[track_caller]
424 pub fn assert_header<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
425 where
426 HeaderValues: PartialEq<HV>,
427 HV: Debug,
428 HN: Into<HeaderName<'a>>,
429 {
430 let name = name.into();
431
432 match self.inner.response_headers().get_values(name.clone()) {
433 Some(actual) => assert_eq!(*actual, expected, "for header {name:?}"),
434 None => panic!("header {name} not set"),
435 };
436
437 self
438 }
439
440 #[track_caller]
442 pub fn assert_headers<'a, I, HN, HV>(&self, headers: I) -> &Self
443 where
444 I: IntoIterator<Item = (HN, HV)> + Send,
445 HN: Into<HeaderName<'a>>,
446 HV: Debug,
447 HeaderValues: PartialEq<HV>,
448 {
449 for (name, expected) in headers {
450 self.assert_header(name, expected);
451 }
452
453 self
454 }
455
456 #[track_caller]
458 pub fn assert_no_header(&self, name: &str) -> &Self {
459 let actual = self.header(name);
460 assert!(
461 actual.is_none(),
462 "expected no header {name:?}, but found {actual:?}"
463 );
464 self
465 }
466
467 #[track_caller]
470 pub fn assert_header_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
471 where
472 F: FnOnce(&HeaderValues),
473 {
474 let name = name.into();
475 match self.response_headers().get_values(name.clone()) {
476 Some(values) => f(values),
477 None => panic!("expected header {name:?}, but it was not found"),
478 }
479
480 self
481 }
482
483 #[track_caller]
485 pub fn assert_state_with<T, F>(&self, f: F) -> &Self
486 where
487 T: Send + Sync + 'static,
488 F: FnOnce(&T),
489 {
490 match self.state::<T>() {
491 Some(state) => f(state),
492 None => panic!(
493 "expected handler state of type {}, but none was found",
494 type_name::<T>()
495 ),
496 };
497 self
498 }
499
500 #[track_caller]
502 pub fn assert_body_with<F>(&self, f: F) -> &Self
503 where
504 F: FnOnce(&str),
505 {
506 f(self.body());
507 self
508 }
509
510 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
512 #[track_caller]
513 pub fn assert_json_body_with<T, F>(&self, f: F) -> &Self
514 where
515 T: serde::de::DeserializeOwned,
516 F: FnOnce(&T),
517 {
518 let parsed: T =
519 crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
520 f(&parsed);
521 self
522 }
523
524 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
526 #[track_caller]
527 pub fn assert_json_body<T>(&self, body: &T) -> &Self
528 where
529 T: serde::de::DeserializeOwned + PartialEq + Debug,
530 {
531 let parsed: T =
532 crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
533 assert_eq!(&parsed, body);
534 self
535 }
536
537 #[track_caller]
539 pub fn assert_trailer<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
540 where
541 HeaderValues: PartialEq<HV>,
542 HV: Debug,
543 HN: Into<HeaderName<'a>>,
544 {
545 let name = name.into();
546
547 match self
548 .inner
549 .response_trailers()
550 .and_then(|trailers| trailers.get_values(name.clone()))
551 {
552 Some(actual) => assert_eq!(*actual, expected, "for trailer {name:?}"),
553 None => panic!("trailer {name} not set"),
554 };
555
556 self
557 }
558
559 #[track_caller]
561 pub fn assert_trailers<'a, I, HN, HV>(&self, trailers: I) -> &Self
562 where
563 I: IntoIterator<Item = (HN, HV)> + Send,
564 HN: Into<HeaderName<'a>>,
565 HV: Debug,
566 HeaderValues: PartialEq<HV>,
567 {
568 for (name, expected) in trailers {
569 self.assert_trailer(name, expected);
570 }
571
572 self
573 }
574
575 #[track_caller]
577 pub fn assert_no_trailer(&self, name: &str) -> &Self {
578 let actual = self.trailer(name);
579 assert!(
580 actual.is_none(),
581 "expected no trailer {name:?}, but found {actual:?}"
582 );
583 self
584 }
585
586 #[track_caller]
589 pub fn assert_trailer_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
590 where
591 F: FnOnce(&HeaderValues),
592 {
593 let name = name.into();
594 match self
595 .response_trailers()
596 .and_then(|trailers| trailers.get_values(name.clone()))
597 {
598 Some(values) => f(values),
599 None => panic!("expected trailer {name:?}, but it was not found"),
600 }
601
602 self
603 }
604}
605
impl IntoFuture for ConnTest {
    // Boxed because the type of an `async` block cannot be written out as an
    // associated type.
    type IntoFuture = Pin<Box<dyn Future<Output = ConnTest> + Send + 'static>>;
    type Output = ConnTest;

    /// Sends the request to the test server and resolves to this `ConnTest` with
    /// the response populated: status/headers on `inner`, state transferred from
    /// the transport, and the fully-read body.
    ///
    /// # Panics
    ///
    /// Panics if the request fails, if reading the response body fails, if the
    /// conn's transport is not a `TestTransport`, or if acquiring that
    /// transport's state lock returns an error.
    fn into_future(mut self) -> Self::IntoFuture {
        Box::pin(async move {
            // Deliver the simulated peer ip (if any) to the server side before the
            // request is made; a send error is deliberately ignored.
            if let Some(peer_ip) = self.peer_ip.take() {
                let _ = self.peer_ip_sender.send(peer_ip).await;
            }

            // Mirror an explicit `Host` request header into the request url. The
            // owned copy ends the shared borrow of the headers before `url_mut`
            // borrows `inner` mutably; errors from `set_host` are discarded.
            // NOTE(review): a Host value with a `:port` suffix presumably fails
            // `set_host` and is silently skipped — confirm.
            if let Some(host) = self
                .inner
                .request_headers()
                .get_str(KnownHeaderName::Host)
                .map(ToString::to_string)
            {
                let _ = self.inner.url_mut().set_host(Some(&host));
            }

            let inner = &mut self.inner;

            inner.await.expect("request to virtual server failed");

            let inner = &mut self.inner;

            // Take the state held behind the TestTransport's lock (leaving a
            // default in its place) and store it on the client conn via `AsMut`.
            // `ConnTest::state` reads it from there afterwards.
            // Presumably this is the state the handler left on the server-side
            // conn — confirm against `TestTransport`.
            if let Some(transport) = inner.transport_mut() {
                let state = std::mem::take(
                    &mut *((transport as &dyn Any)
                        .downcast_ref::<TestTransport>()
                        .unwrap()
                        .state()
                        .write()
                        .unwrap()),
                );

                *inner.as_mut() = state;
            }

            // Buffer the whole response body so `body`/`body_bytes` can borrow it.
            self.body = Some(
                self.inner
                    .response_body()
                    .read_bytes()
                    .await
                    .expect("failed to read response body"),
            );

            self
        })
    }
}