Skip to main content

trillium_api/
api_conn_ext.rs

1use crate::{Error, Result};
2use mime::Mime;
3use serde::{Serialize, de::DeserializeOwned};
4use std::future::Future;
5#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
6use trillium::Status;
7use trillium::{
8    Conn,
9    KnownHeaderName::{Accept, ContentType},
10};
11
12/// Extension trait that adds api methods to [`trillium::Conn`]
13pub trait ApiConnExt {
14    /// Sends a json response body. This sets a status code of 200,
15    /// serializes the body with serde_json, sets the content-type to
16    /// application/json, and [halts](trillium::Conn::halt) the
17    /// conn. If serialization fails, a 500 status code is sent as per
18    /// [`trillium::conn_try`]
19    ///
20    ///
21    /// ## Examples
22    ///
23    /// ```
24    /// # if !cfg!(any(feature = "sonic-rs", feature = "serde_json")) { return }
25    /// use trillium_api::{json, ApiConnExt};
26    /// use trillium_testing::TestServer;
27    ///
28    /// async fn handler(conn: trillium::Conn) -> trillium::Conn {
29    /// conn.with_json(&json!({ "json macro": "is reexported" }))
30    /// }
31    ///
32    /// # trillium_testing::block_on(async {
33    /// let app = TestServer::new(handler).await;
34    /// app.get("/")
35    ///     .await
36    ///     .assert_ok()
37    ///     .assert_body(r#"{"json macro":"is reexported"}"#)
38    ///     .assert_header("content-type", "application/json");
39    /// # });
40    /// ```
41    ///
42    /// ### overriding status code
43    /// ```
44    /// use serde::Serialize;
45    /// use trillium_api::ApiConnExt;
46    /// use trillium_testing::TestServer;
47    ///
48    /// #[derive(Serialize)]
49    /// struct ApiResponse {
50    ///     string: &'static str,
51    ///     number: usize,
52    /// }
53    ///
54    /// async fn handler(conn: trillium::Conn) -> trillium::Conn {
55    ///     conn.with_json(&ApiResponse {
56    ///         string: "not the most creative example",
57    ///         number: 100,
58    ///     })
59    ///     .with_status(201)
60    /// }
61    ///
62    /// # trillium_testing::block_on(async {
63    /// let app = TestServer::new(handler).await;
64    /// app.get("/")
65    ///     .await
66    ///     .assert_status(201)
67    ///     .assert_body(r#"{"string":"not the most creative example","number":100}"#)
68    ///     .assert_header("content-type", "application/json");
69    /// # });
70    /// ```
71    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
72    #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
73    fn with_json(self, response: &impl Serialize) -> Self;
74
75    /// Attempts to deserialize a type from the request body, based on the
76    /// request content type.
77    ///
78    /// By default, both application/json and
79    /// application/x-www-form-urlencoded are supported, and future
80    /// versions may add accepted request content types. Please open an
81    /// issue if you need to accept another content type.
82    ///
83    ///
84    /// To exclusively accept application/json, disable default features
85    /// on this crate.
86    ///
87    /// This sets a status code of Status::Ok if and only if no status
88    /// code has been explicitly set.
89    ///
90    /// ## Examples
91    ///
92    /// ### Deserializing to `Value`
93    ///
94    /// ```no_run
95    /// # if !cfg!(any(feature = "sonic-rs", feature = "serde_json")) { return }
96    /// use trillium_api::{ApiConnExt, Value};
97    ///
98    /// async fn handler(mut conn: trillium::Conn) -> trillium::Conn {
99    ///     let value: Value = match conn.deserialize().await {
100    ///         Ok(v) => v,
101    ///         Err(_) => return conn.with_status(400),
102    ///     };
103    ///     conn.with_json(&value)
104    /// }
105    ///
106    /// # use trillium_testing::TestServer;
107    /// # trillium_testing::block_on(async {
108    /// let app = TestServer::new(handler).await;
109    /// app.post("/")
110    ///     .with_body(r#"key=value"#)
111    ///     .with_request_header("content-type", "application/x-www-form-urlencoded")
112    ///     .await
113    ///     .assert_ok()
114    ///     .assert_body(r#"{"key":"value"}"#)
115    ///     .assert_header("content-type", "application/json");
116    /// # });
117    /// ```
118    ///
119    /// ### Deserializing a concrete type
120    ///
121    /// ```
122    /// use trillium_api::ApiConnExt;
123    /// use trillium_testing::TestServer;
124    ///
125    /// #[derive(serde::Deserialize)]
126    /// struct KvPair {
127    ///     key: String,
128    ///     value: String,
129    /// }
130    ///
131    /// async fn handler(mut conn: trillium::Conn) -> trillium::Conn {
132    ///     match conn.deserialize().await {
133    ///         Ok(KvPair { key, value }) => conn
134    ///             .with_status(201)
135    ///             .with_body(format!("{} is {}", key, value))
136    ///             .halt(),
137    ///
138    ///         Err(_) => conn.with_status(422).with_body("nope").halt(),
139    ///     }
140    /// }
141    ///
142    /// # trillium_testing::block_on(async {
143    /// let app = TestServer::new(handler).await;
144    ///
145    /// app.post("/")
146    ///     .with_body(r#"key=name&value=trillium"#)
147    ///     .with_request_header("content-type", "application/x-www-form-urlencoded")
148    ///     .await
149    ///     .assert_status(201)
150    ///     .assert_body(r#"name is trillium"#);
151    ///
152    /// app.post("/")
153    ///     .with_body(r#"name=trillium"#)
154    ///     .with_request_header("content-type", "application/x-www-form-urlencoded")
155    ///     .await
156    ///     .assert_status(422)
157    ///     .assert_body(r#"nope"#);
158    /// # });
159    /// ```
160    fn deserialize<T>(&mut self) -> impl Future<Output = Result<T>> + Send
161    where
162        T: DeserializeOwned;
163
164    /// Deserializes json without any Accepts header content negotiation
165    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
166    #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
167    fn deserialize_json<T>(&mut self) -> impl Future<Output = Result<T>> + Send
168    where
169        T: DeserializeOwned;
170
171    /// Serializes the provided body using Accepts header content negotiation
172    fn serialize<T>(&mut self, body: &T) -> impl Future<Output = Result<()>> + Send
173    where
174        T: Serialize + Sync;
175
176    /// Returns a parsed content type for this conn.
177    ///
178    /// Note that this function considers a missing content type an error of variant
179    /// [`Error::MissingContentType`].
180    fn content_type(&self) -> Result<Mime>;
181}
182
183impl ApiConnExt for Conn {
184    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
185    #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
186    fn with_json(mut self, response: &impl Serialize) -> Self {
187        #[cfg(feature = "serde_json")]
188        let as_string = serde_json::to_string(&response);
189
190        #[cfg(feature = "sonic-rs")]
191        let as_string = sonic_rs::to_string(&response);
192
193        match as_string {
194            Ok(body) => {
195                if self.status().is_none() {
196                    self.set_status(Status::Ok);
197                }
198
199                self.response_headers_mut()
200                    .try_insert(ContentType, "application/json");
201
202                self.with_body(body)
203            }
204
205            Err(error) => self.with_state(Error::from(error)),
206        }
207    }
208
209    async fn deserialize<T>(&mut self) -> Result<T>
210    where
211        T: DeserializeOwned,
212    {
213        let body = self.request_body_string().await?;
214        let content_type = self.content_type()?;
215        let suffix_or_subtype = content_type
216            .suffix()
217            .unwrap_or_else(|| content_type.subtype())
218            .as_str();
219        match suffix_or_subtype {
220            #[cfg(feature = "serde_json")]
221            "json" => {
222                let json_deserializer = &mut serde_json::Deserializer::from_str(&body);
223                Ok(serde_path_to_error::deserialize::<_, T>(json_deserializer)?)
224            }
225
226            #[cfg(feature = "sonic-rs")]
227            "json" => {
228                let json_deserializer = &mut sonic_rs::serde::Deserializer::from_str(&body);
229                Ok(serde_path_to_error::deserialize::<_, T>(json_deserializer)?)
230            }
231
232            #[cfg(feature = "forms")]
233            "x-www-form-urlencoded" => {
234                let body = form_urlencoded::parse(body.as_bytes());
235                let deserializer = serde_urlencoded::Deserializer::new(body);
236                Ok(serde_path_to_error::deserialize::<_, T>(deserializer)?)
237            }
238
239            _ => {
240                drop(body);
241                Err(Error::UnsupportedMimeType {
242                    mime_type: content_type.to_string(),
243                })
244            }
245        }
246    }
247
248    fn content_type(&self) -> Result<Mime> {
249        let header_str = self
250            .request_headers()
251            .get_str(ContentType)
252            .ok_or(Error::MissingContentType)?;
253
254        header_str.parse().map_err(|_| Error::UnsupportedMimeType {
255            mime_type: header_str.into(),
256        })
257    }
258
259    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
260    #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
261    async fn deserialize_json<T>(&mut self) -> Result<T>
262    where
263        T: DeserializeOwned,
264    {
265        let content_type = self.content_type()?;
266        let suffix_or_subtype = content_type
267            .suffix()
268            .unwrap_or_else(|| content_type.subtype())
269            .as_str();
270        if suffix_or_subtype != "json" {
271            return Err(Error::UnsupportedMimeType {
272                mime_type: content_type.to_string(),
273            });
274        }
275
276        log::debug!("extracting json");
277        let body = self.request_body_string().await?;
278
279        #[cfg(feature = "serde_json")]
280        let json_deserializer = &mut serde_json::Deserializer::from_str(&body);
281
282        #[cfg(feature = "sonic-rs")]
283        let json_deserializer = &mut sonic_rs::serde::Deserializer::from_str(&body);
284
285        Ok(serde_path_to_error::deserialize::<_, T>(json_deserializer)?)
286    }
287
288    async fn serialize<T>(&mut self, body: &T) -> Result<()>
289    where
290        T: Serialize + Sync,
291    {
292        let accept = self
293            .request_headers()
294            .get_str(Accept)
295            .unwrap_or("*/*")
296            .split(',')
297            .map(|s| s.trim())
298            .find_map(acceptable_mime_type);
299
300        match accept {
301            #[cfg(feature = "serde_json")]
302            Some(AcceptableMime::Json) => {
303                self.set_body(serde_json::to_string(body)?);
304                self.insert_response_header(ContentType, "application/json");
305                Ok(())
306            }
307
308            #[cfg(feature = "sonic-rs")]
309            Some(AcceptableMime::Json) => {
310                self.set_body(sonic_rs::to_string(body)?);
311                self.insert_response_header(ContentType, "application/json");
312                Ok(())
313            }
314
315            #[cfg(feature = "forms")]
316            Some(AcceptableMime::Form) => {
317                self.set_body(serde_urlencoded::to_string(body)?);
318                self.insert_response_header(ContentType, "application/x-www-form-urlencoded");
319                Ok(())
320            }
321
322            None => {
323                let _ = body;
324                Err(Error::FailureToNegotiateContent)
325            }
326        }
327    }
328}
329
/// A response format that [`ApiConnExt::serialize`] can produce, selected
/// from the request's `Accept` header.
///
/// Each variant exists only when the corresponding serializer feature is
/// enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AcceptableMime {
    /// `application/json`, also selected for media types with a `+json`
    /// suffix.
    #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
    Json,

    /// `application/x-www-form-urlencoded`.
    #[cfg(feature = "forms")]
    #[cfg_attr(docsrs, doc(cfg(feature = "forms")))]
    Form,
}
339
340fn acceptable_mime_type(mime: &str) -> Option<AcceptableMime> {
341    let mime: Mime = mime.parse().ok()?;
342    let suffix_or_subtype = mime.suffix().unwrap_or_else(|| mime.subtype()).as_str();
343    match suffix_or_subtype {
344        #[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
345        "*" | "json" => Some(AcceptableMime::Json),
346
347        #[cfg(feature = "forms")]
348        "x-www-form-urlencoded" => Some(AcceptableMime::Form),
349
350        _ => None,
351    }
352}