trillium_api/api_conn_ext.rs
1use crate::{Error, Result};
2use mime::Mime;
3use serde::{Serialize, de::DeserializeOwned};
4use std::future::Future;
5#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
6use trillium::Status;
7use trillium::{
8 Conn,
9 KnownHeaderName::{Accept, ContentType},
10};
11
12/// Extension trait that adds api methods to [`trillium::Conn`]
13pub trait ApiConnExt {
14 /// Sends a json response body. This sets a status code of 200,
15 /// serializes the body with serde_json, sets the content-type to
16 /// application/json, and [halts](trillium::Conn::halt) the
17 /// conn. If serialization fails, a 500 status code is sent as per
18 /// [`trillium::conn_try`]
19 ///
20 ///
21 /// ## Examples
22 ///
23 /// ```
24 /// # if !cfg!(any(feature = "sonic-rs", feature = "serde_json")) { return }
25 /// use trillium_api::{json, ApiConnExt};
26 /// use trillium_testing::TestServer;
27 ///
28 /// async fn handler(conn: trillium::Conn) -> trillium::Conn {
29 /// conn.with_json(&json!({ "json macro": "is reexported" }))
30 /// }
31 ///
32 /// # trillium_testing::block_on(async {
33 /// let app = TestServer::new(handler).await;
34 /// app.get("/")
35 /// .await
36 /// .assert_ok()
37 /// .assert_body(r#"{"json macro":"is reexported"}"#)
38 /// .assert_header("content-type", "application/json");
39 /// # });
40 /// ```
41 ///
42 /// ### overriding status code
43 /// ```
44 /// use serde::Serialize;
45 /// use trillium_api::ApiConnExt;
46 /// use trillium_testing::TestServer;
47 ///
48 /// #[derive(Serialize)]
49 /// struct ApiResponse {
50 /// string: &'static str,
51 /// number: usize,
52 /// }
53 ///
54 /// async fn handler(conn: trillium::Conn) -> trillium::Conn {
55 /// conn.with_json(&ApiResponse {
56 /// string: "not the most creative example",
57 /// number: 100,
58 /// })
59 /// .with_status(201)
60 /// }
61 ///
62 /// # trillium_testing::block_on(async {
63 /// let app = TestServer::new(handler).await;
64 /// app.get("/")
65 /// .await
66 /// .assert_status(201)
67 /// .assert_body(r#"{"string":"not the most creative example","number":100}"#)
68 /// .assert_header("content-type", "application/json");
69 /// # });
70 /// ```
71 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
72 #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
73 fn with_json(self, response: &impl Serialize) -> Self;
74
75 /// Attempts to deserialize a type from the request body, based on the
76 /// request content type.
77 ///
78 /// By default, both application/json and
79 /// application/x-www-form-urlencoded are supported, and future
80 /// versions may add accepted request content types. Please open an
81 /// issue if you need to accept another content type.
82 ///
83 ///
84 /// To exclusively accept application/json, disable default features
85 /// on this crate.
86 ///
87 /// This sets a status code of Status::Ok if and only if no status
88 /// code has been explicitly set.
89 ///
90 /// ## Examples
91 ///
92 /// ### Deserializing to `Value`
93 ///
94 /// ```no_run
95 /// # if !cfg!(any(feature = "sonic-rs", feature = "serde_json")) { return }
96 /// use trillium_api::{ApiConnExt, Value};
97 ///
98 /// async fn handler(mut conn: trillium::Conn) -> trillium::Conn {
99 /// let value: Value = match conn.deserialize().await {
100 /// Ok(v) => v,
101 /// Err(_) => return conn.with_status(400),
102 /// };
103 /// conn.with_json(&value)
104 /// }
105 ///
106 /// # use trillium_testing::TestServer;
107 /// # trillium_testing::block_on(async {
108 /// let app = TestServer::new(handler).await;
109 /// app.post("/")
110 /// .with_body(r#"key=value"#)
111 /// .with_request_header("content-type", "application/x-www-form-urlencoded")
112 /// .await
113 /// .assert_ok()
114 /// .assert_body(r#"{"key":"value"}"#)
115 /// .assert_header("content-type", "application/json");
116 /// # });
117 /// ```
118 ///
119 /// ### Deserializing a concrete type
120 ///
121 /// ```
122 /// use trillium_api::ApiConnExt;
123 /// use trillium_testing::TestServer;
124 ///
125 /// #[derive(serde::Deserialize)]
126 /// struct KvPair {
127 /// key: String,
128 /// value: String,
129 /// }
130 ///
131 /// async fn handler(mut conn: trillium::Conn) -> trillium::Conn {
132 /// match conn.deserialize().await {
133 /// Ok(KvPair { key, value }) => conn
134 /// .with_status(201)
135 /// .with_body(format!("{} is {}", key, value))
136 /// .halt(),
137 ///
138 /// Err(_) => conn.with_status(422).with_body("nope").halt(),
139 /// }
140 /// }
141 ///
142 /// # trillium_testing::block_on(async {
143 /// let app = TestServer::new(handler).await;
144 ///
145 /// app.post("/")
146 /// .with_body(r#"key=name&value=trillium"#)
147 /// .with_request_header("content-type", "application/x-www-form-urlencoded")
148 /// .await
149 /// .assert_status(201)
150 /// .assert_body(r#"name is trillium"#);
151 ///
152 /// app.post("/")
153 /// .with_body(r#"name=trillium"#)
154 /// .with_request_header("content-type", "application/x-www-form-urlencoded")
155 /// .await
156 /// .assert_status(422)
157 /// .assert_body(r#"nope"#);
158 /// # });
159 /// ```
160 fn deserialize<T>(&mut self) -> impl Future<Output = Result<T>> + Send
161 where
162 T: DeserializeOwned;
163
164 /// Deserializes json without any Accepts header content negotiation
165 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
166 #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
167 fn deserialize_json<T>(&mut self) -> impl Future<Output = Result<T>> + Send
168 where
169 T: DeserializeOwned;
170
171 /// Serializes the provided body using Accepts header content negotiation
172 fn serialize<T>(&mut self, body: &T) -> impl Future<Output = Result<()>> + Send
173 where
174 T: Serialize + Sync;
175
176 /// Returns a parsed content type for this conn.
177 ///
178 /// Note that this function considers a missing content type an error of variant
179 /// [`Error::MissingContentType`].
180 fn content_type(&self) -> Result<Mime>;
181}
182
183impl ApiConnExt for Conn {
184 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
185 #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
186 fn with_json(mut self, response: &impl Serialize) -> Self {
187 #[cfg(feature = "serde_json")]
188 let as_string = serde_json::to_string(&response);
189
190 #[cfg(feature = "sonic-rs")]
191 let as_string = sonic_rs::to_string(&response);
192
193 match as_string {
194 Ok(body) => {
195 if self.status().is_none() {
196 self.set_status(Status::Ok);
197 }
198
199 self.response_headers_mut()
200 .try_insert(ContentType, "application/json");
201
202 self.with_body(body)
203 }
204
205 Err(error) => self.with_state(Error::from(error)),
206 }
207 }
208
209 async fn deserialize<T>(&mut self) -> Result<T>
210 where
211 T: DeserializeOwned,
212 {
213 let body = self.request_body_string().await?;
214 let content_type = self.content_type()?;
215 let suffix_or_subtype = content_type
216 .suffix()
217 .unwrap_or_else(|| content_type.subtype())
218 .as_str();
219 match suffix_or_subtype {
220 #[cfg(feature = "serde_json")]
221 "json" => {
222 let json_deserializer = &mut serde_json::Deserializer::from_str(&body);
223 Ok(serde_path_to_error::deserialize::<_, T>(json_deserializer)?)
224 }
225
226 #[cfg(feature = "sonic-rs")]
227 "json" => {
228 let json_deserializer = &mut sonic_rs::serde::Deserializer::from_str(&body);
229 Ok(serde_path_to_error::deserialize::<_, T>(json_deserializer)?)
230 }
231
232 #[cfg(feature = "forms")]
233 "x-www-form-urlencoded" => {
234 let body = form_urlencoded::parse(body.as_bytes());
235 let deserializer = serde_urlencoded::Deserializer::new(body);
236 Ok(serde_path_to_error::deserialize::<_, T>(deserializer)?)
237 }
238
239 _ => {
240 drop(body);
241 Err(Error::UnsupportedMimeType {
242 mime_type: content_type.to_string(),
243 })
244 }
245 }
246 }
247
248 fn content_type(&self) -> Result<Mime> {
249 let header_str = self
250 .request_headers()
251 .get_str(ContentType)
252 .ok_or(Error::MissingContentType)?;
253
254 header_str.parse().map_err(|_| Error::UnsupportedMimeType {
255 mime_type: header_str.into(),
256 })
257 }
258
259 #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
260 #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
261 async fn deserialize_json<T>(&mut self) -> Result<T>
262 where
263 T: DeserializeOwned,
264 {
265 let content_type = self.content_type()?;
266 let suffix_or_subtype = content_type
267 .suffix()
268 .unwrap_or_else(|| content_type.subtype())
269 .as_str();
270 if suffix_or_subtype != "json" {
271 return Err(Error::UnsupportedMimeType {
272 mime_type: content_type.to_string(),
273 });
274 }
275
276 log::debug!("extracting json");
277 let body = self.request_body_string().await?;
278
279 #[cfg(feature = "serde_json")]
280 let json_deserializer = &mut serde_json::Deserializer::from_str(&body);
281
282 #[cfg(feature = "sonic-rs")]
283 let json_deserializer = &mut sonic_rs::serde::Deserializer::from_str(&body);
284
285 Ok(serde_path_to_error::deserialize::<_, T>(json_deserializer)?)
286 }
287
288 async fn serialize<T>(&mut self, body: &T) -> Result<()>
289 where
290 T: Serialize + Sync,
291 {
292 let accept = self
293 .request_headers()
294 .get_str(Accept)
295 .unwrap_or("*/*")
296 .split(',')
297 .map(|s| s.trim())
298 .find_map(acceptable_mime_type);
299
300 match accept {
301 #[cfg(feature = "serde_json")]
302 Some(AcceptableMime::Json) => {
303 self.set_body(serde_json::to_string(body)?);
304 self.insert_response_header(ContentType, "application/json");
305 Ok(())
306 }
307
308 #[cfg(feature = "sonic-rs")]
309 Some(AcceptableMime::Json) => {
310 self.set_body(sonic_rs::to_string(body)?);
311 self.insert_response_header(ContentType, "application/json");
312 Ok(())
313 }
314
315 #[cfg(feature = "forms")]
316 Some(AcceptableMime::Form) => {
317 self.set_body(serde_urlencoded::to_string(body)?);
318 self.insert_response_header(ContentType, "application/x-www-form-urlencoded");
319 Ok(())
320 }
321
322 None => {
323 let _ = body;
324 Err(Error::FailureToNegotiateContent)
325 }
326 }
327 }
328}
329
/// A response format that `serialize` can select from the request's
/// `Accept` header.
///
/// Every variant is feature-gated, so with no serialization features enabled
/// this enum has no variants at all.
enum AcceptableMime {
    /// Serialize as `application/json`.
    #[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
    #[cfg_attr(docsrs, doc(cfg(any(feature = "sonic-rs", feature = "serde_json"))))]
    Json,

    /// Serialize as `application/x-www-form-urlencoded`.
    #[cfg_attr(docsrs, doc(cfg(feature = "forms")))]
    #[cfg(feature = "forms")]
    Form,
}
339
340fn acceptable_mime_type(mime: &str) -> Option<AcceptableMime> {
341 let mime: Mime = mime.parse().ok()?;
342 let suffix_or_subtype = mime.suffix().unwrap_or_else(|| mime.subtype()).as_str();
343 match suffix_or_subtype {
344 #[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
345 "*" | "json" => Some(AcceptableMime::Json),
346
347 #[cfg(feature = "forms")]
348 "x-www-form-urlencoded" => Some(AcceptableMime::Form),
349
350 _ => None,
351 }
352}