Skip to main content

trillium_testing/
test_server.rs

1use crate::{Runtime, RuntimeTrait, ServerConnector, TestTransport, runtime};
2use async_channel::Sender;
3use std::{
4    any::{Any, type_name},
5    fmt::{self, Debug, Formatter},
6    future::{Future, IntoFuture},
7    net::IpAddr,
8    pin::Pin,
9    str,
10    sync::Arc,
11};
12use trillium::{Handler, Info, KnownHeaderName};
13use trillium_client::{Client, IntoUrl};
14use trillium_http::{HeaderName, HeaderValues, Headers, HttpContext, Method, Status};
15#[allow(clippy::test_attr_in_doctest, reason = "demonstrating test usage")]
16/// A testing interface that wraps a trillium handler, providing a high-level API for making
17/// requests and asserting on responses.
18///
19/// This runs a full request-response cycle against an in-memory virtual transport using
20/// [`trillium-client`](https://docs.rs/trillium-client). No ports are bound and the tests are fully
21/// parallelizable.
22///
23/// A fluent set of assertions are provided that chain off of a borrow.
24///
25/// ```
26/// use test_harness::test;
27/// use trillium::{Conn, Status};
28/// use trillium_testing::{TestResult, TestServer, harness};
29///
30/// #[test(harness)]
31/// async fn basic_test() {
32///     let app = TestServer::new(|conn: Conn| async move { conn.ok("hello") }).await;
33///
34///     app.get("/").await.assert_ok().assert_body("hello");
35///
36///     // or if you prefer:
37///
38///     let conn = app.post("/").with_body("body").await;
39///     conn.assert_ok();
40///     conn.assert_body("hello");
41/// }
42///
43/// // also an option, but not preferred:
44///
45/// #[test]
46/// fn sync_test() {
47///     let app = TestServer::new_blocking(|conn: Conn| async move { conn.ok("hello") });
48///
49///     app.get("/").block().assert_ok().assert_body("hello");
50///
51///     let conn = app.post("/").with_body("body").block();
52///     conn.assert_ok();
53///     conn.assert_body("hello");
54/// }
55/// ```
56#[derive(Clone, Debug)]
57pub struct TestServer<H> {
58    client: Client,
59    peer_ip_sender: Sender<IpAddr>,
60    connector: ServerConnector<H>,
61}
62
63impl<H: Handler> TestServer<H> {
64    /// Creates a new [`TestServer`].
65    ///
66    /// Note that this is **async** because it initializes the handler with [`Handler::init`].
67    pub async fn new(handler: H) -> Self {
68        Self::new_with_runtime(handler, runtime()).await
69    }
70
71    async fn new_with_runtime(mut handler: H, rt: impl RuntimeTrait) -> Self {
72        let url = "http://trillium.test".into_url(None).unwrap();
73        let mut info = Info::from(HttpContext::default());
74        info.insert_shared_state(rt.clone());
75        info.insert_shared_state(Runtime::new(rt.clone()));
76        info.insert_shared_state(url.clone());
77        handler.init(&mut info).await;
78        let context: Arc<HttpContext> = Arc::new(info.into());
79        let mut connector = ServerConnector::new(handler)
80            .with_context(context.clone())
81            .with_runtime(rt);
82        let (peer_ip_sender, receive) = async_channel::unbounded();
83        connector.server_peer_ips_receiver = Some(receive);
84        let client = Client::new(connector.clone()).with_base(url);
85
86        Self {
87            client,
88            peer_ip_sender,
89            connector,
90        }
91    }
92
93    /// construct a new TestServer and block on initialization
94    pub fn new_blocking(handler: H) -> Self {
95        // Create the runtime before block_on so it is stored as an owned (not borrowed) runtime
96        // in the connector. If we used crate::block_on here, runtime() inside new_with_runtime
97        // would detect the current tokio handle and store AlreadyRunning — pointing at the
98        // temporary block_on runtime — which is shut down before block() is ever called.
99        let rt = runtime();
100        rt.clone().block_on(Self::new_with_runtime(handler, rt))
101    }
102
103    /// Build a new [`ConnTest`]
104    pub fn build<M: TryInto<Method>>(&self, method: M, path: &str) -> ConnTest
105    where
106        M::Error: Debug,
107    {
108        ConnTest {
109            inner: self.client.build_conn(method, path),
110            body: None,
111            peer_ip_sender: self.peer_ip_sender.clone(),
112            peer_ip: None,
113            runtime: self.connector.runtime().clone(),
114        }
115    }
116
117    /// borrow from shared state configured by the handler
118    pub fn shared_state<T: Send + Sync + 'static>(&self) -> Option<&T> {
119        self.connector.context().shared_state().get()
120    }
121
122    /// assert that a given type is in shared state
123    #[track_caller]
124    pub fn assert_shared_state<T>(&self, expected: T) -> &Self
125    where
126        T: Send + Sync + Debug + PartialEq + 'static,
127    {
128        match self.shared_state::<T>() {
129            Some(actual) => assert_eq!(*actual, expected),
130            None => panic!(
131                "expected handler state of type {}, but none was found",
132                type_name::<T>()
133            ),
134        };
135        self
136    }
137
138    /// assert that a given type is in shared and make further assertions on it
139    pub fn assert_shared_state_with<T, F>(&self, f: F) -> &Self
140    where
141        T: Send + Sync + 'static,
142        F: FnOnce(&T),
143    {
144        match self.shared_state::<T>() {
145            Some(state) => f(state),
146            None => panic!(
147                "expected handler state of type {}, but none was found",
148                type_name::<T>()
149            ),
150        };
151        self
152    }
153
154    /// Borrow the handler
155    pub fn handler(&self) -> &H {
156        self.connector.handler()
157    }
158
159    /// Add a default host/authority for this virtual server (eg pretend this server is running at
160    /// `example.com` with `.with_host("example.com")`
161    pub fn with_host(mut self, host: &str) -> Self {
162        self.set_host(host);
163        self
164    }
165
166    /// Set the default host/authority for this virtual server (eg pretend this server is running at
167    /// `example.com` with `.set_host("example.com")`
168    pub fn set_host(&mut self, host: &str) -> &mut Self {
169        let _ = self.client.base_mut().unwrap().set_host(Some(host));
170        self
171    }
172
173    /// Set the url for this virtual server (eg pretend this server is running at
174    /// `https://example.com` with `.with_base("https://example.com")`
175    pub fn with_base(mut self, base: impl IntoUrl) -> Self {
176        self.set_base(base);
177        self
178    }
179
180    /// Set the url for this virtual server (eg pretend this server is running at
181    /// `https://example.com` with `.set_base("https://example.com")`
182    pub fn set_base(&mut self, base: impl IntoUrl) -> &mut Self {
183        self.client
184            .set_base(base)
185            .expect("unable to build a base url");
186        self
187    }
188
189    /// Builds a GET [`ConnTest`] for the given path.
190    pub fn get(&self, path: &str) -> ConnTest {
191        self.build(Method::Get, path)
192    }
193
194    /// Builds a POST [`ConnTest`] for the given path.
195    pub fn post(&self, path: &str) -> ConnTest {
196        self.build(Method::Post, path)
197    }
198
199    /// Builds a PUT [`ConnTest`] for the given path.
200    pub fn put(&self, path: &str) -> ConnTest {
201        self.build(Method::Put, path)
202    }
203
204    /// Builds a DELETE [`ConnTest`] for the given path.
205    pub fn delete(&self, path: &str) -> ConnTest {
206        self.build(Method::Delete, path)
207    }
208
209    /// Builds a PATCH [`ConnTest`] for the given path.
210    pub fn patch(&self, path: &str) -> ConnTest {
211        self.build(Method::Patch, path)
212    }
213}
214
215/// Represents both an outbound HTTP request being built and the received HTTP response.
216///
217/// Before `.await`, use the request-building methods to configure the request.
218/// After `.await`, use the accessor and assertion methods to inspect the response.
219///
220/// The response body is read eagerly on `.await`, so all accessors are synchronous.
221pub struct ConnTest {
222    inner: trillium_client::Conn,
223    body: Option<Vec<u8>>,
224    peer_ip_sender: Sender<IpAddr>,
225    peer_ip: Option<IpAddr>,
226    runtime: Runtime,
227}
228
229impl Debug for ConnTest {
230    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
231        f.debug_struct("ConnTest")
232            .field("inner", &self.inner)
233            .field("body", &self.body.as_deref().map(String::from_utf8_lossy))
234            .finish()
235    }
236}
237
238/// Request-building methods (use before `.await`)
239impl ConnTest {
240    /// Inserts a request header, replacing any existing value for that header name.
241    pub fn with_request_header(
242        mut self,
243        name: impl Into<HeaderName<'static>>,
244        value: impl Into<HeaderValues>,
245    ) -> Self {
246        self.inner.request_headers_mut().insert(name, value);
247        self
248    }
249
250    /// Extends the request headers from an iterable of `(name, value)` pairs.
251    pub fn with_request_headers<HN, HV, I>(mut self, headers: I) -> Self
252    where
253        I: IntoIterator<Item = (HN, HV)> + Send,
254        HN: Into<HeaderName<'static>>,
255        HV: Into<HeaderValues>,
256    {
257        self.inner.request_headers_mut().extend(headers);
258        self
259    }
260
261    /// Removes a request header if present.
262    pub fn without_request_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
263        self.inner.request_headers_mut().remove(name);
264        self
265    }
266
267    /// Sets the request body.
268    pub fn with_body(mut self, body: impl Into<trillium_http::Body>) -> Self {
269        self.inner.set_request_body(body);
270        self
271    }
272
273    /// Sets the request body to the given serializable, as well as setting content-type:
274    /// application/json if not already set
275    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
276    pub fn with_json_body(mut self, body: &impl serde::Serialize) -> Self {
277        self.inner
278            .request_headers_mut()
279            .try_insert(KnownHeaderName::ContentType, "application/json");
280
281        self.with_body(crate::to_json_string(body).unwrap())
282    }
283
284    /// Sets a test-double ip that represents the *client's* ip, available to the server as peer ip.
285    pub fn with_peer_ip(mut self, peer_ip: impl Into<IpAddr>) -> Self {
286        self.peer_ip = Some(peer_ip.into());
287        self
288    }
289
290    /// Perform a blocking request
291    pub fn block(self) -> Self {
292        self.runtime.clone().block_on(self.into_future())
293    }
294}
295
296/// Response accessors and assertions (use after `.await`)
297impl ConnTest {
298    /// Returns handler state of type `T` set on the conn during the request, if any.
299    pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
300        self.inner.state::<T>()
301    }
302
303    /// Asserts that handler state of type `T` was set and equals `expected`.
304    #[track_caller]
305    pub fn assert_state<T>(&self, expected: T) -> &Self
306    where
307        T: PartialEq + Debug + Send + Sync + 'static,
308    {
309        match self.state::<T>() {
310            Some(actual) => assert_eq!(*actual, expected),
311            None => panic!(
312                "expected handler state of type {}, but none was found",
313                type_name::<T>()
314            ),
315        }
316        self
317    }
318
319    /// Asserts that no handler state of type `T` was set on the conn during the request.
320    #[track_caller]
321    pub fn assert_no_state<T>(&self) -> &Self
322    where
323        T: Debug + Send + Sync + 'static,
324    {
325        if let Some(value) = self.state::<T>() {
326            panic!(
327                "expected no handler state of type {}, but found {:?}",
328                type_name::<T>(),
329                value
330            );
331        }
332        self
333    }
334
335    /// Returns the response status code.
336    ///
337    /// Panics if called before the request has been sent (i.e., before `.await`).
338    pub fn status(&self) -> Status {
339        self.inner
340            .status()
341            .expect("response not yet received — did you .await this ConnTest?")
342    }
343
344    /// Returns the response body as a `&str`.
345    ///
346    /// Panics if no body was received from the server, or if the body is not a valid utf-8 string.
347    pub fn body(&self) -> &str {
348        str::from_utf8(self.body_bytes()).expect("body was not utf-8")
349    }
350
351    /// Returns the response body as a `&str`.
352    ///
353    /// Panics if no body was received from the server
354    pub fn body_bytes(&self) -> &[u8] {
355        self.body.as_deref().expect("body was not set")
356    }
357
358    /// Returns the response headers.
359    pub fn response_headers(&self) -> &Headers {
360        self.inner.response_headers()
361    }
362
363    /// Returns the response headers.
364    pub fn response_trailers(&self) -> Option<&Headers> {
365        self.inner.response_trailers()
366    }
367
368    /// Returns the response headers.
369    pub fn request_trailers(&self) -> Option<&Headers> {
370        self.inner.request_trailers()
371    }
372
373    /// Returns the value of a response header by name, if present.
374    pub fn header<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
375        self.inner.response_headers().get_str(name)
376    }
377
378    /// Returns the value of a response trailer by name, if present.
379    pub fn trailer<'a>(&self, name: impl Into<HeaderName<'a>>) -> Option<&str> {
380        self.inner
381            .response_trailers()
382            .and_then(|trailers| trailers.get_str(name))
383    }
384
385    /// Asserts that the response status equals `expected`.
386    #[track_caller]
387    pub fn assert_status(&self, status: impl TryInto<Status>) -> &Self {
388        let expected: Status = status
389            .try_into()
390            .ok()
391            .expect("expected a valid status code");
392        let actual = self.status();
393        assert_eq!(actual, expected, "expected status {expected}, got {actual}");
394        self
395    }
396
397    /// Asserts that the response status is 200 OK.
398    #[track_caller]
399    pub fn assert_ok(&self) -> &Self {
400        self.assert_status(200)
401    }
402
403    /// Asserts that the response body is a string that equals `expected`, ignoring trailing
404    /// whitespace
405    #[track_caller]
406    pub fn assert_body(&self, expected: &str) -> &Self {
407        assert_eq!(self.body().trim_end(), expected.trim_end());
408        self
409    }
410
411    /// Asserts that the response body contains `pattern`.
412    #[track_caller]
413    pub fn assert_body_contains(&self, pattern: &str) -> &Self {
414        let body = self.body();
415        assert!(
416            body.contains(pattern),
417            "expected body to contain {pattern:?}, but body was:\n{body}"
418        );
419        self
420    }
421
422    /// Asserts that the response has a header `name` with value `value`.
423    #[track_caller]
424    pub fn assert_header<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
425    where
426        HeaderValues: PartialEq<HV>,
427        HV: Debug,
428        HN: Into<HeaderName<'a>>,
429    {
430        let name = name.into();
431
432        match self.inner.response_headers().get_values(name.clone()) {
433            Some(actual) => assert_eq!(*actual, expected, "for header {name:?}"),
434            None => panic!("header {name} not set"),
435        };
436
437        self
438    }
439
440    /// Asserts that the response has a header `name` with value `value`.
441    #[track_caller]
442    pub fn assert_headers<'a, I, HN, HV>(&self, headers: I) -> &Self
443    where
444        I: IntoIterator<Item = (HN, HV)> + Send,
445        HN: Into<HeaderName<'a>>,
446        HV: Debug,
447        HeaderValues: PartialEq<HV>,
448    {
449        for (name, expected) in headers {
450            self.assert_header(name, expected);
451        }
452
453        self
454    }
455
456    /// Asserts that the response has no header named `name`.
457    #[track_caller]
458    pub fn assert_no_header(&self, name: &str) -> &Self {
459        let actual = self.header(name);
460        assert!(
461            actual.is_none(),
462            "expected no header {name:?}, but found {actual:?}"
463        );
464        self
465    }
466
467    /// Asserts that a header with the given name exists and runs the provided closure with its
468    /// value.
469    #[track_caller]
470    pub fn assert_header_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
471    where
472        F: FnOnce(&HeaderValues),
473    {
474        let name = name.into();
475        match self.response_headers().get_values(name.clone()) {
476            Some(values) => f(values),
477            None => panic!("expected header {name:?}, but it was not found"),
478        }
479
480        self
481    }
482
483    /// Asserts that handler state of type `T` was set and runs the provided closure with it.
484    #[track_caller]
485    pub fn assert_state_with<T, F>(&self, f: F) -> &Self
486    where
487        T: Send + Sync + 'static,
488        F: FnOnce(&T),
489    {
490        match self.state::<T>() {
491            Some(state) => f(state),
492            None => panic!(
493                "expected handler state of type {}, but none was found",
494                type_name::<T>()
495            ),
496        };
497        self
498    }
499
500    /// Runs the provided closure with the response body.
501    #[track_caller]
502    pub fn assert_body_with<F>(&self, f: F) -> &Self
503    where
504        F: FnOnce(&str),
505    {
506        f(self.body());
507        self
508    }
509
510    /// Parses the response body as JSON and runs the provided closure with the parsed value.
511    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
512    #[track_caller]
513    pub fn assert_json_body_with<T, F>(&self, f: F) -> &Self
514    where
515        T: serde::de::DeserializeOwned,
516        F: FnOnce(&T),
517    {
518        let parsed: T =
519            crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
520        f(&parsed);
521        self
522    }
523
524    /// Parses the response body as JSON and runs the provided closure with the parsed value.
525    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
526    #[track_caller]
527    pub fn assert_json_body<T>(&self, body: &T) -> &Self
528    where
529        T: serde::de::DeserializeOwned + PartialEq + Debug,
530    {
531        let parsed: T =
532            crate::from_json_str(self.body()).expect("failed to parse response body as JSON");
533        assert_eq!(&parsed, body);
534        self
535    }
536
537    /// Asserts that the response has a trailer `name` with value `value`.
538    #[track_caller]
539    pub fn assert_trailer<'a, HV, HN>(&self, name: HN, expected: HV) -> &Self
540    where
541        HeaderValues: PartialEq<HV>,
542        HV: Debug,
543        HN: Into<HeaderName<'a>>,
544    {
545        let name = name.into();
546
547        match self
548            .inner
549            .response_trailers()
550            .and_then(|trailers| trailers.get_values(name.clone()))
551        {
552            Some(actual) => assert_eq!(*actual, expected, "for trailer {name:?}"),
553            None => panic!("trailer {name} not set"),
554        };
555
556        self
557    }
558
559    /// Asserts that the response has a trailer `name` with value `value`.
560    #[track_caller]
561    pub fn assert_trailers<'a, I, HN, HV>(&self, trailers: I) -> &Self
562    where
563        I: IntoIterator<Item = (HN, HV)> + Send,
564        HN: Into<HeaderName<'a>>,
565        HV: Debug,
566        HeaderValues: PartialEq<HV>,
567    {
568        for (name, expected) in trailers {
569            self.assert_trailer(name, expected);
570        }
571
572        self
573    }
574
575    /// Asserts that the response has no trailer named `name`.
576    #[track_caller]
577    pub fn assert_no_trailer(&self, name: &str) -> &Self {
578        let actual = self.trailer(name);
579        assert!(
580            actual.is_none(),
581            "expected no trailer {name:?}, but found {actual:?}"
582        );
583        self
584    }
585
586    /// Asserts that a trailer with the given name exists and runs the provided closure with its
587    /// value.
588    #[track_caller]
589    pub fn assert_trailer_with<'a, F>(&self, name: impl Into<HeaderName<'a>>, f: F) -> &Self
590    where
591        F: FnOnce(&HeaderValues),
592    {
593        let name = name.into();
594        match self
595            .response_trailers()
596            .and_then(|trailers| trailers.get_values(name.clone()))
597        {
598            Some(values) => f(values),
599            None => panic!("expected trailer {name:?}, but it was not found"),
600        }
601
602        self
603    }
604}
605
606impl IntoFuture for ConnTest {
607    type IntoFuture = Pin<Box<dyn Future<Output = ConnTest> + Send + 'static>>;
608    type Output = ConnTest;
609
610    fn into_future(mut self) -> Self::IntoFuture {
611        Box::pin(async move {
612            if let Some(peer_ip) = self.peer_ip.take() {
613                let _ = self.peer_ip_sender.send(peer_ip).await;
614            }
615
616            if let Some(host) = self
617                .inner
618                .request_headers()
619                .get_str(KnownHeaderName::Host)
620                .map(ToString::to_string)
621            {
622                let _ = self.inner.url_mut().set_host(Some(&host));
623            }
624
625            let inner = &mut self.inner;
626
627            inner.await.expect("request to virtual server failed");
628
629            let inner = &mut self.inner;
630
631            if let Some(transport) = inner.transport_mut() {
632                let state = std::mem::take(
633                    &mut *((transport as &dyn Any)
634                        .downcast_ref::<TestTransport>()
635                        .unwrap()
636                        .state()
637                        .write()
638                        .unwrap()),
639                );
640
641                *inner.as_mut() = state;
642            }
643
644            self.body = Some(
645                self.inner
646                    .response_body()
647                    .read_bytes()
648                    .await
649                    .expect("failed to read response body"),
650            );
651
652            self
653        })
654    }
655}