Skip to main content

trillium_testing/
test_server.rs

1use crate::{Runtime, RuntimeTrait, ServerConnector, TestTransport, runtime};
2use async_channel::Sender;
3use std::{
4    any::{Any, type_name},
5    fmt::{self, Debug, Formatter},
6    future::{Future, IntoFuture},
7    net::IpAddr,
8    pin::Pin,
9    str,
10    sync::Arc,
11};
12use trillium::{Handler, Info, KnownHeaderName};
13use trillium_client::{Client, IntoUrl};
14use trillium_http::{HeaderName, HeaderValues, Headers, HttpContext, Method, Status};
15#[allow(clippy::test_attr_in_doctest, reason = "demonstrating test usage")]
16/// A testing interface that wraps a trillium handler, providing a high-level API for making
17/// requests and asserting on responses.
18///
19/// This runs a full request-response cycle against an in-memory virtual transport using
20/// [`trillium-client`](https://docs.rs/trillium-client). No ports are bound and the tests are fully
21/// parallelizable.
22///
23/// A fluent set of assertions are provided that chain off of a borrow.
24///
25/// ```
26/// use test_harness::test;
27/// use trillium::{Conn, Status};
28/// use trillium_testing::{TestResult, TestServer, harness};
29///
30/// #[test(harness)]
31/// async fn basic_test() {
32///     let app = TestServer::new(|conn: Conn| async move { conn.ok("hello") }).await;
33///
34///     app.get("/").await.assert_ok().assert_body("hello");
35///
36///     // or if you prefer:
37///
38///     let conn = app.post("/").with_body("body").await;
39///     conn.assert_ok();
40///     conn.assert_body("hello");
41/// }
42///
43/// // also an option, but not preferred:
44///
45/// #[test]
46/// fn sync_test() {
47///     let app = TestServer::new_blocking(|conn: Conn| async move { conn.ok("hello") });
48///
49///     app.get("/").block().assert_ok().assert_body("hello");
50///
51///     let conn = app.post("/").with_body("body").block();
52///     conn.assert_ok();
53///     conn.assert_body("hello");
54/// }
55/// ```
56#[derive(Clone, Debug)]
57pub struct TestServer<H> {
58    client: Client,
59    peer_ip_sender: Sender<IpAddr>,
60    connector: ServerConnector<H>,
61}
62
63impl<H: Handler> TestServer<H> {
64    /// Creates a new [`TestServer`].
65    ///
66    /// Note that this is **async** because it initializes the handler with [`Handler::init`].
67    pub async fn new(handler: H) -> Self {
68        Self::new_with_runtime(handler, runtime()).await
69    }
70
71    async fn new_with_runtime(mut handler: H, rt: impl RuntimeTrait) -> Self {
72        let url = "http://trillium.test".into_url(None).unwrap();
73        let mut info = Info::from(HttpContext::default());
74        info.insert_shared_state(rt.clone());
75        info.insert_shared_state(Runtime::new(rt.clone()));
76        info.insert_shared_state(url.clone());
77        handler.init(&mut info).await;
78        let context: Arc<HttpContext> = Arc::new(info.into());
79        let mut connector = ServerConnector::new(handler)
80            .with_context(context.clone())
81            .with_runtime(rt);
82        let (peer_ip_sender, receive) = async_channel::unbounded();
83        connector.server_peer_ips_receiver = Some(receive);
84        let client = Client::new(connector.clone())
85            .without_keepalive()
86            .with_base(url);
87
88        Self {
89            client,
90            peer_ip_sender,
91            connector,
92        }
93    }
94
95    /// construct a new TestServer and block on initialization
96    pub fn new_blocking(handler: H) -> Self {
97        // Create the runtime before block_on so it is stored as an owned (not borrowed) runtime
98        // in the connector. If we used crate::block_on here, runtime() inside new_with_runtime
99        // would detect the current tokio handle and store AlreadyRunning — pointing at the
100        // temporary block_on runtime — which is shut down before block() is ever called.
101        let rt = runtime();
102        rt.clone().block_on(Self::new_with_runtime(handler, rt))
103    }
104
105    /// Build a new [`ConnTest`]
106    pub fn build<M: TryInto<Method>>(&self, method: M, path: &str) -> ConnTest
107    where
108        M::Error: Debug,
109    {
110        ConnTest {
111            inner: self.client.build_conn(method, path),
112            body: None,
113            peer_ip_sender: self.peer_ip_sender.clone(),
114            peer_ip: None,
115            runtime: self.connector.runtime().clone(),
116        }
117    }
118
119    /// borrow from shared state configured by the handler
120    pub fn shared_state<T: Send + Sync + 'static>(&self) -> Option<&T> {
121        self.connector.context().shared_state().get()
122    }
123
124    /// assert that a given type is in shared state
125    #[track_caller]
126    pub fn assert_shared_state<T>(&self, expected: T) -> &Self
127    where
128        T: Send + Sync + Debug + PartialEq + 'static,
129    {
130        match self.shared_state::<T>() {
131            Some(actual) => assert_eq!(*actual, expected),
132            None => panic!(
133                "expected handler state of type {}, but none was found",
134                type_name::<T>()
135            ),
136        };
137        self
138    }
139
140    /// assert that a given type is in shared and make further assertions on it
141    pub fn assert_shared_state_with<T, F>(&self, f: F) -> &Self
142    where
143        T: Send + Sync + 'static,
144        F: FnOnce(&T),
145    {
146        match self.shared_state::<T>() {
147            Some(state) => f(state),
148            None => panic!(
149                "expected handler state of type {}, but none was found",
150                type_name::<T>()
151            ),
152        };
153        self
154    }
155
156    /// Borrow the handler
157    pub fn handler(&self) -> &H {
158        self.connector.handler()
159    }
160
161    /// Add a default host/authority for this virtual server (eg pretend this server is running at
162    /// `example.com` with `.with_host("example.com")`
163    pub fn with_host(mut self, host: &str) -> Self {
164        self.set_host(host);
165        self
166    }
167
168    /// Set the default host/authority for this virtual server (eg pretend this server is running at
169    /// `example.com` with `.set_host("example.com")`
170    pub fn set_host(&mut self, host: &str) -> &mut Self {
171        let _ = self.client.base_mut().unwrap().set_host(Some(host));
172        self
173    }
174
175    /// Set the url for this virtual server (eg pretend this server is running at
176    /// `https://example.com` with `.with_base("https://example.com")`
177    pub fn with_base(mut self, base: impl IntoUrl) -> Self {
178        self.set_base(base);
179        self
180    }
181
182    /// Set the url for this virtual server (eg pretend this server is running at
183    /// `https://example.com` with `.set_base("https://example.com")`
184    pub fn set_base(&mut self, base: impl IntoUrl) -> &mut Self {
185        self.client
186            .set_base(base)
187            .expect("unable to build a base url");
188        self
189    }
190
191    /// Builds a GET [`ConnTest`] for the given path.
192    pub fn get(&self, path: &str) -> ConnTest {
193        self.build(Method::Get, path)
194    }
195
196    /// Builds a POST [`ConnTest`] for the given path.
197    pub fn post(&self, path: &str) -> ConnTest {
198        self.build(Method::Post, path)
199    }
200
201    /// Builds a PUT [`ConnTest`] for the given path.
202    pub fn put(&self, path: &str) -> ConnTest {
203        self.build(Method::Put, path)
204    }
205
206    /// Builds a DELETE [`ConnTest`] for the given path.
207    pub fn delete(&self, path: &str) -> ConnTest {
208        self.build(Method::Delete, path)
209    }
210
211    /// Builds a PATCH [`ConnTest`] for the given path.
212    pub fn patch(&self, path: &str) -> ConnTest {
213        self.build(Method::Patch, path)
214    }
215}
216
217/// Represents both an outbound HTTP request being built and the received HTTP response.
218///
219/// Before `.await`, use the request-building methods to configure the request.
220/// After `.await`, use the accessor and assertion methods to inspect the response.
221///
222/// The response body is read eagerly on `.await`, so all accessors are synchronous.
223pub struct ConnTest {
224    inner: trillium_client::Conn,
225    body: Option<Vec<u8>>,
226    peer_ip_sender: Sender<IpAddr>,
227    peer_ip: Option<IpAddr>,
228    runtime: Runtime,
229}
230
231impl Debug for ConnTest {
232    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
233        f.debug_struct("ConnTest")
234            .field("inner", &self.inner)
235            .field("body", &self.body.as_deref().map(String::from_utf8_lossy))
236            .finish()
237    }
238}
239
240/// Request-building methods (use before `.await`)
241impl ConnTest {
242    /// Inserts a request header, replacing any existing value for that header name.
243    pub fn with_request_header(
244        mut self,
245        name: impl Into<HeaderName<'static>>,
246        value: impl Into<HeaderValues>,
247    ) -> Self {
248        self.inner.request_headers_mut().insert(name, value);
249        self
250    }
251
252    /// Extends the request headers from an iterable of `(name, value)` pairs.
253    pub fn with_request_headers<HN, HV, I>(mut self, headers: I) -> Self
254    where
255        I: IntoIterator<Item = (HN, HV)> + Send,
256        HN: Into<HeaderName<'static>>,
257        HV: Into<HeaderValues>,
258    {
259        self.inner.request_headers_mut().extend(headers);
260        self
261    }
262
263    /// Removes a request header if present.
264    pub fn without_request_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
265        self.inner.request_headers_mut().remove(name);
266        self
267    }
268
269    /// Sets the request body.
270    pub fn with_body(mut self, body: impl Into<trillium_http::Body>) -> Self {
271        self.inner.set_request_body(body);
272        self
273    }
274
275    /// Sets the request body to the given serializable, as well as setting content-type:
276    /// application/json if not already set
277    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
278    pub fn with_json_body(mut self, body: &impl serde::Serialize) -> Self {
279        self.inner
280            .request_headers_mut()
281            .try_insert(KnownHeaderName::ContentType, "application/json");
282
283        self.with_body(crate::to_json_string(body).unwrap())
284    }
285
286    /// Sets a test-double ip that represents the *client's* ip, available to the server as peer ip.
287    pub fn with_peer_ip(mut self, peer_ip: impl Into<IpAddr>) -> Self {
288        self.peer_ip = Some(peer_ip.into());
289        self
290    }
291
292    /// Perform a blocking request
293    pub fn block(self) -> Self {
294        self.runtime.clone().block_on(self.into_future())
295    }
296}
297
298/// Response accessors and assertions (use after `.await`)
299impl ConnTest {
300    /// Returns handler state of type `T` set on the conn during the request, if any.
301    pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
302        self.inner.state::<T>()
303    }
304
305    /// Asserts that handler state of type `T` was set and equals `expected`.
306    #[track_caller]
307    pub fn assert_state<T>(&self, expected: T) -> &Self
308    where
309        T: PartialEq + Debug + Send + Sync + 'static,
310    {
311        match self.state::<T>() {
312            Some(actual) => assert_eq!(*actual, expected),
313            None => panic!(
314                "expected handler state of type {}, but none was found",
315                type_name::<T>()
316            ),
317        }
318        self
319    }
320
321    /// Asserts that no handler state of type `T` was set on the conn during the request.
322    #[track_caller]
323    pub fn assert_no_state<T>(&self) -> &Self
324    where
325        T: Debug + Send + Sync + 'static,
326    {
327        if let Some(value) = self.state::<T>() {
328            panic!(
329                "expected no handler state of type {}, but found {:?}",
330                type_name::<T>(),
331                value
332            );
333        }
334        self
335    }
336
337    /// Returns the response status code.
338    ///
339    /// Panics if called before the request has been sent (i.e., before `.await`).
340    pub fn status(&self) -> Status {
341        self.inner
342            .status()
343            .expect("response not yet received — did you .await this ConnTest?")
344    }
345
346    /// Returns the response body as a `&str`.
347    ///
348    /// Panics if no body was received from the server, or if the body is not a valid utf-8 string.
349    pub fn body(&self) -> &str {
350        str::from_utf8(self.body_bytes()).expect("body was not utf-8")
351    }
352
353    /// Returns the response body as a `&str`.
354    ///
355    /// Panics if no body was received from the server
356    pub fn body_bytes(&self) -> &[u8] {
357        self.body.as_deref().expect("body was not set")
358    }
359
360    /// Returns the response headers.
361    pub fn response_headers(&self) -> &Headers {
362        self.inner.response_headers()
363    }
364
365    /// Returns the response headers.
366    pub fn response_trailers(&self) -> Option<&Headers> {
367        self.inner.response_trailers()
368    }
369
370    /// Returns the response headers.
371    pub fn request_trailers(&self) -> Option<&Headers> {
372        self.inner.request_trailers()
373    }
374
375    /// Returns the value of a response header by name, if present.
376    pub fn header<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
377        self.inner.response_headers().get_str(name)
378    }
379
380    /// Returns the value of a response trailer by name, if present.
381    pub fn trailer<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
382        self.inner
383            .response_trailers()
384            .and_then(|trailers| trailers.get_str(name))
385    }
386
387    /// Asserts that the response status equals `expected`.
388    #[track_caller]
389    pub fn assert_status(&self, status: impl TryInto<Status>) -> &Self {
390        let expected: Status = status
391            .try_into()
392            .ok()
393            .expect("expected a valid status code");
394        let actual = self.status();
395        assert_eq!(actual, expected, "expected status {expected}, got {actual}");
396        self
397    }
398
399    /// Asserts that the response status is 200 OK.
400    #[track_caller]
401    pub fn assert_ok(&self) -> &Self {
402        self.assert_status(200)
403    }
404
405    /// Asserts that the response body is a string that equals `expected`, ignoring trailing
406    /// whitespace
407    #[track_caller]
408    pub fn assert_body(&self, expected: &str) -> &Self {
409        assert_eq!(self.body().trim_end(), expected.trim_end());
410        self
411    }
412
413    /// Asserts that the response body contains `pattern`.
414    #[track_caller]
415    pub fn assert_body_contains(&self, pattern: &str) -> &Self {
416        let body = self.body();
417        assert!(
418            body.contains(pattern),
419            "expected body to contain {pattern:?}, but body was:\n{body}"
420        );
421        self
422    }
423
424    /// Asserts that the response has a header `name` with value `value`.
425    #[track_caller]
426    pub fn assert_header<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
427    where
428        HeaderValues: PartialEq<HV>,
429        HV: Debug,
430        HN: Into<HeaderName<'a>>,
431    {
432        let name = name.into();
433
434        match self.inner.response_headers().get_values(name.clone()) {
435            Some(actual) => assert_eq!(*actual, expected, "for header {name:?}"),
436            None => panic!("header {name} not set"),
437        };
438
439        self
440    }
441
442    /// Asserts that the response has a header `name` with value `value`.
443    #[track_caller]
444    pub fn assert_headers<'a, I, HN, HV>(&self, headers: I) -> &Self
445    where
446        I: IntoIterator<Item = (HN, HV)> + Send,
447        HN: Into<HeaderName<'a>>,
448        HV: Debug,
449        HeaderValues: PartialEq<HV>,
450    {
451        for (name, expected) in headers {
452            self.assert_header(name, expected);
453        }
454
455        self
456    }
457
458    /// Asserts that the response has no header named `name`.
459    #[track_caller]
460    pub fn assert_no_header(&self, name: &str) -> &Self {
461        let actual = self.header(name);
462        assert!(
463            actual.is_none(),
464            "expected no header {name:?}, but found {actual:?}"
465        );
466        self
467    }
468
469    /// Asserts that a header with the given name exists and runs the provided closure with its
470    /// value.
471    #[track_caller]
472    pub fn assert_header_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
473    where
474        F: FnOnce(&HeaderValues),
475    {
476        let name = name.into();
477        match self.response_headers().get_values(name.clone()) {
478            Some(values) => f(values),
479            None => panic!("expected header {name:?}, but it was not found"),
480        }
481
482        self
483    }
484
485    /// Asserts that handler state of type `T` was set and runs the provided closure with it.
486    #[track_caller]
487    pub fn assert_state_with<T, F>(&self, f: F) -> &Self
488    where
489        T: Send + Sync + 'static,
490        F: FnOnce(&T),
491    {
492        match self.state::<T>() {
493            Some(state) => f(state),
494            None => panic!(
495                "expected handler state of type {}, but none was found",
496                type_name::<T>()
497            ),
498        };
499        self
500    }
501
502    /// Runs the provided closure with the response body.
503    #[track_caller]
504    pub fn assert_body_with<F>(&self, f: F) -> &Self
505    where
506        F: FnOnce(&str),
507    {
508        f(self.body());
509        self
510    }
511
512    /// Parses the response body as JSON and runs the provided closure with the parsed value.
513    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
514    #[track_caller]
515    pub fn assert_json_body_with<T, F>(&self, f: F) -> &Self
516    where
517        T: serde::de::DeserializeOwned,
518        F: FnOnce(&T),
519    {
520        let parsed: T =
521            crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
522        f(&parsed);
523        self
524    }
525
526    /// Parses the response body as JSON and runs the provided closure with the parsed value.
527    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
528    #[track_caller]
529    pub fn assert_json_body<T>(&self, body: &T) -> &Self
530    where
531        T: serde::de::DeserializeOwned + PartialEq + Debug,
532    {
533        let parsed: T =
534            crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
535        assert_eq!(&parsed, body);
536        self
537    }
538
539    /// Asserts that the response has a trailer `name` with value `value`.
540    #[track_caller]
541    pub fn assert_trailer<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
542    where
543        HeaderValues: PartialEq<HV>,
544        HV: Debug,
545        HN: Into<HeaderName<'a>>,
546    {
547        let name = name.into();
548
549        match self
550            .inner
551            .response_trailers()
552            .and_then(|trailers| trailers.get_values(name.clone()))
553        {
554            Some(actual) => assert_eq!(*actual, expected, "for trailer {name:?}"),
555            None => panic!("trailer {name} not set"),
556        };
557
558        self
559    }
560
561    /// Asserts that the response has a trailer `name` with value `value`.
562    #[track_caller]
563    pub fn assert_trailers<'a, I, HN, HV>(&self, trailers: I) -> &Self
564    where
565        I: IntoIterator<Item = (HN, HV)> + Send,
566        HN: Into<HeaderName<'a>>,
567        HV: Debug,
568        HeaderValues: PartialEq<HV>,
569    {
570        for (name, expected) in trailers {
571            self.assert_trailer(name, expected);
572        }
573
574        self
575    }
576
577    /// Asserts that the response has no trailer named `name`.
578    #[track_caller]
579    pub fn assert_no_trailer(&self, name: &str) -> &Self {
580        let actual = self.trailer(name);
581        assert!(
582            actual.is_none(),
583            "expected no trailer {name:?}, but found {actual:?}"
584        );
585        self
586    }
587
588    /// Asserts that a trailer with the given name exists and runs the provided closure with its
589    /// value.
590    #[track_caller]
591    pub fn assert_trailer_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
592    where
593        F: FnOnce(&HeaderValues),
594    {
595        let name = name.into();
596        match self
597            .response_trailers()
598            .and_then(|trailers| trailers.get_values(name.clone()))
599        {
600            Some(values) => f(values),
601            None => panic!("expected trailer {name:?}, but it was not found"),
602        }
603
604        self
605    }
606}
607
608impl IntoFuture for ConnTest {
609    type IntoFuture = Pin<Box<dyn Future<Output = ConnTest> + Send + 'static>>;
610    type Output = ConnTest;
611
612    fn into_future(mut self) -> Self::IntoFuture {
613        Box::pin(async move {
614            if let Some(peer_ip) = self.peer_ip.take() {
615                let _ = self.peer_ip_sender.send(peer_ip).await;
616            }
617
618            if let Some(host) = self
619                .inner
620                .request_headers()
621                .get_str(KnownHeaderName::Host)
622                .map(ToString::to_string)
623            {
624                let _ = self.inner.url_mut().set_host(Some(&host));
625            }
626
627            let inner = &mut self.inner;
628
629            inner.await.expect("request to virtual server failed");
630
631            let inner = &mut self.inner;
632
633            if let Some(transport) = inner.transport_mut() {
634                let state = std::mem::take(
635                    &mut *((transport as &dyn Any)
636                        .downcast_ref::<TestTransport>()
637                        .unwrap()
638                        .state()
639                        .write()
640                        .unwrap()),
641                );
642
643                *inner.as_mut() = state;
644            }
645
646            self.body = Some(
647                self.inner
648                    .response_body()
649                    .read_bytes()
650                    .await
651                    .expect("failed to read response body"),
652            );
653
654            self
655        })
656    }
657}