Skip to main content

trillium_http/
conn.rs

1use crate::{
2    Body, Buffer, Headers, HttpContext, KnownHeaderName,
3    KnownHeaderName::Host,
4    Method, ProtocolSession, ReceivedBody, Status, Swansong, TypeSet, Version,
5    after_send::{AfterSend, SendStatus},
6    h2::H2Connection,
7    h3::H3Connection,
8    headers::hpack::FieldSection,
9    liveness::{CancelOnDisconnect, LivenessFut},
10    received_body::ReceivedBodyState,
11    util::encoding,
12};
13
14/// Header names whose semantics only apply at the HTTP/1 layer.
15///
16/// HTTP/2 (RFC 9113) and HTTP/3 (RFC 9114) call these "connection-specific"
17/// and forbid them in requests and responses.
18pub(super) const H1_ONLY_HEADERS: [KnownHeaderName; 5] = [
19    KnownHeaderName::Connection,
20    KnownHeaderName::KeepAlive,
21    KnownHeaderName::ProxyConnection,
22    KnownHeaderName::TransferEncoding,
23    KnownHeaderName::Upgrade,
24];
25
26/// Validated request pseudo-headers + headers, the common output of
27/// [`validate_h2h3_request`].
28pub(super) struct ValidatedRequest {
29    pub method: Method,
30    pub path: Cow<'static, str>,
31    pub authority: Option<Cow<'static, str>>,
32    pub scheme: Option<Cow<'static, str>>,
33    pub protocol: Option<Cow<'static, str>>,
34    pub request_headers: Headers,
35}
36
37/// Shared HTTP/2 + HTTP/3 request-validation per RFC 9113 and RFC 9114.
38///
39/// Both protocols apply the same malformed-message rules to incoming requests:
40/// no `:status` pseudo, required `:method`, non-empty `:path` (or CONNECT default),
41/// `:scheme` required for non-CONNECT, `:authority` required for CONNECT, `:authority`
42/// or `Host` required when `:scheme` is `http`/`https`, no `Host`/`:authority`
43/// mismatch, no [`H1_ONLY_HEADERS`], and `TE` restricted to `trailers`. Returns `None`
44/// on any violation; the caller maps to its protocol-specific error code.
45pub(super) fn validate_h2h3_request(
46    mut field_section: FieldSection<'static>,
47) -> Option<ValidatedRequest> {
48    let pseudo_headers = field_section.pseudo_headers_mut();
49
50    // `:status` is response-only; reject it on requests.
51    if pseudo_headers.status().is_some() {
52        return None;
53    }
54
55    let method = pseudo_headers.take_method();
56    let path = pseudo_headers.take_path();
57    let authority = pseudo_headers.take_authority();
58    let scheme = pseudo_headers.take_scheme();
59    let protocol = pseudo_headers.take_protocol();
60    let request_headers = field_section.into_headers().into_owned();
61
62    if let Some(host) = request_headers.get_str(Host)
63        && let Some(authority) = &authority
64        && host != authority.as_ref()
65    {
66        return None;
67    }
68
69    if H1_ONLY_HEADERS
70        .into_iter()
71        .any(|name| request_headers.has_header(name))
72    {
73        return None;
74    }
75
76    let method = method?;
77
78    if method != Method::Connect && scheme.is_none() {
79        return None;
80    }
81
82    let path = match (method, path) {
83        (_, Some(path)) if !path.is_empty() => path,
84        (Method::Connect, _) => Cow::Borrowed("/"),
85        _ => return None,
86    };
87
88    if method == Method::Connect && authority.is_none() {
89        return None;
90    }
91
92    // When :scheme names a scheme with a mandatory authority component, the request
93    // MUST carry either :authority or a Host header. The spec gives "http" and "https"
94    // as the canonical examples; we also include "ws" and "wss" (same
95    // hierarchical-with-mandatory-authority shape) so the rule applies consistently if
96    // a non-standard sender uses those. Exotic schemes without mandatory authority
97    // (file, data, mailto, urn) are exempt; CONNECT is handled above.
98    if method != Method::Connect
99        && matches!(scheme.as_deref(), Some("http" | "https" | "ws" | "wss"))
100        && authority.is_none()
101        && request_headers.get_str(Host).is_none()
102    {
103        return None;
104    }
105
106    match request_headers.get_str(KnownHeaderName::Te) {
107        None | Some("trailers") => {}
108        _ => return None,
109    }
110
111    Some(ValidatedRequest {
112        method,
113        path,
114        authority,
115        scheme,
116        protocol,
117        request_headers,
118    })
119}
120use encoding_rs::Encoding;
121use futures_lite::{
122    future,
123    io::{AsyncRead, AsyncWrite},
124};
125use std::{
126    borrow::Cow,
127    fmt::{self, Debug, Formatter},
128    future::Future,
129    net::IpAddr,
130    pin::pin,
131    str,
132    sync::Arc,
133    time::Instant,
134};
135mod h1;
136mod h2;
137mod h3;
138pub(crate) use h3::H3FirstFrame;
139
140/// An HTTP connection.
141///
142/// This struct represents both the request and the response, and holds the
143/// transport over which the response will be sent.
144#[derive(fieldwork::Fieldwork)]
145pub struct Conn<Transport> {
146    #[field(get)]
147    /// the shared [`HttpContext`]
148    pub(crate) context: Arc<HttpContext>,
149
150    /// request [headers](Headers)
151    #[field(get, get_mut)]
152    pub(crate) request_headers: Headers,
153
154    /// response [headers](Headers)
155    #[field(get, get_mut)]
156    pub(crate) response_headers: Headers,
157
158    pub(crate) path: Cow<'static, str>,
159
160    /// the http method for this conn's request
161    ///
162    /// ```
163    /// # use trillium_http::{Conn, Method};
164    /// let mut conn = Conn::new_synthetic(Method::Get, "/some/path?and&a=query", ());
165    /// assert_eq!(conn.method(), Method::Get);
166    /// ```
167    #[field(get, set, copy)]
168    pub(crate) method: Method,
169
170    /// the http status for this conn, if set
171    #[field(get, copy)]
172    pub(crate) status: Option<Status>,
173
174    /// The HTTP protocol version in use on this connection.
175    ///
176    /// ```
177    /// # use trillium_http::{Conn, Method, Version};
178    /// let conn = Conn::new_synthetic(Method::Get, "/", ());
179    /// assert_eq!(conn.http_version(), Version::Http1_1);
180    /// ```
181    #[field(get = http_version, copy)]
182    pub(crate) version: Version,
183
184    /// the [state typemap](TypeSet) for this conn
185    #[field(get, get_mut)]
186    pub(crate) state: TypeSet,
187
188    /// the response [body](Body)
189    ///
190    /// ```
191    /// # use trillium_testing::HttpTest;
192    /// HttpTest::new(|conn| async move { conn.with_response_body("hello") })
193    ///     .get("/")
194    ///     .block()
195    ///     .assert_body("hello");
196    ///
197    /// HttpTest::new(|conn| async move { conn.with_response_body(String::from("world")) })
198    ///     .get("/")
199    ///     .block()
200    ///     .assert_body("world");
201    ///
202    /// HttpTest::new(|conn| async move { conn.with_response_body(vec![99, 97, 116]) })
203    ///     .get("/")
204    ///     .block()
205    ///     .assert_body("cat");
206    /// ```
207    #[field(get, set, into, option_set_some, take, with)]
208    pub(crate) response_body: Option<Body>,
209
210    /// the transport
211    ///
212    /// This should only be used to call your own custom methods on the transport that do not read
213    /// or write any data. Calling any method that reads from or writes to the transport will
214    /// disrupt the HTTP protocol. If you're looking to transition from HTTP to another protocol,
215    /// use an HTTP upgrade.
216    #[field(get, get_mut)]
217    pub(crate) transport: Transport,
218
219    pub(crate) buffer: Buffer,
220
221    pub(crate) request_body_state: ReceivedBodyState,
222
223    pub(crate) after_send: AfterSend,
224
225    /// whether the connection is secure
226    ///
227    /// note that this does not necessarily indicate that the transport itself is secure, as it may
228    /// indicate that `trillium_http` is behind a trusted reverse proxy that has terminated tls and
229    /// provided appropriate headers to indicate this.
230    #[field(get, set, rename_predicates)]
231    pub(crate) secure: bool,
232
233    /// The [`Instant`] that the first header bytes for this conn were
234    /// received, before any processing or parsing has been performed.
235    #[field(get, copy)]
236    pub(crate) start_time: Instant,
237
238    /// The IP Address for the connection, if available
239    #[field(set, get, copy, into)]
240    pub(crate) peer_ip: Option<IpAddr>,
241
242    /// the `:authority` pseudo-header
243    #[field(set, get, into)]
244    pub(crate) authority: Option<Cow<'static, str>>,
245
246    /// the `:scheme` pseudo-header
247    #[field(set, get, into)]
248    pub(crate) scheme: Option<Cow<'static, str>>,
249
250    /// the [`ProtocolSession`] for this conn — the per-protocol session state
251    /// (h2/h3 connection driver and stream id) bundled into a single enum so the
252    /// "set together" invariant is enforced at the type level. `Http1` for
253    /// h1 / synthetic conns.
254    pub(crate) protocol_session: ProtocolSession,
255
256    /// the `:protocol` pseudo-header (extended CONNECT)
257    #[field(set, get, into)]
258    pub(crate) protocol: Option<Cow<'static, str>>,
259
260    /// request trailers, populated after the request body has been fully read
261    #[field(get, get_mut)]
262    pub(crate) request_trailers: Option<Headers>,
263}
264
265impl<Transport> Debug for Conn<Transport> {
266    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
267        f.debug_struct("Conn")
268            .field("context", &self.context)
269            .field("request_headers", &self.request_headers)
270            .field("response_headers", &self.response_headers)
271            .field("path", &self.path)
272            .field("method", &self.method)
273            .field("status", &self.status)
274            .field("version", &self.version)
275            .field("state", &self.state)
276            .field("response_body", &self.response_body)
277            .field("transport", &format_args!(".."))
278            .field("buffer", &format_args!(".."))
279            .field("request_body_state", &self.request_body_state)
280            .field("secure", &self.secure)
281            .field("after_send", &format_args!(".."))
282            .field("start_time", &self.start_time)
283            .field("peer_ip", &self.peer_ip)
284            .field("authority", &self.authority)
285            .field("scheme", &self.scheme)
286            .field("protocol", &self.protocol)
287            .field("protocol_session", &self.protocol_session)
288            .field("request_trailers", &self.request_trailers)
289            .finish()
290    }
291}
292
293impl<Transport> Conn<Transport>
294where
295    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
296{
297    /// Returns the shared state typemap for this conn.
298    pub fn shared_state(&self) -> &TypeSet {
299        &self.context.shared_state
300    }
301
302    /// sets the http status code from any `TryInto<Status>`.
303    ///
304    /// ```
305    /// # use trillium_http::Status;
306    /// # trillium_testing::HttpTest::new(|mut conn| async move {
307    /// assert!(conn.status().is_none());
308    ///
309    /// conn.set_status(200); // a status can be set as a u16
310    /// assert_eq!(conn.status().unwrap(), Status::Ok);
311    ///
312    /// conn.set_status(Status::ImATeapot); // or as a Status
313    /// assert_eq!(conn.status().unwrap(), Status::ImATeapot);
314    /// conn
315    /// # }).get("/").block().assert_status(Status::ImATeapot);
316    /// ```
317    pub fn set_status(&mut self, status: impl TryInto<Status>) -> &mut Self {
318        self.status = Some(status.try_into().unwrap_or_else(|_| {
319            log::error!("attempted to set an invalid status code");
320            Status::InternalServerError
321        }));
322        self
323    }
324
325    /// sets the http status code from any `TryInto<Status>`, returning Conn
326    #[must_use]
327    pub fn with_status(mut self, status: impl TryInto<Status>) -> Self {
328        self.set_status(status);
329        self
330    }
331
332    /// retrieves the path part of the request url, up to and excluding any query component
333    /// ```
334    /// # use trillium_testing::HttpTest;
335    /// HttpTest::new(|mut conn| async move {
336    ///     assert_eq!(conn.path(), "/some/path");
337    ///     conn.with_status(200)
338    /// })
339    /// .get("/some/path?and&a=query")
340    /// .block()
341    /// .assert_ok();
342    /// ```
343    pub fn path(&self) -> &str {
344        match self.path.split_once('?') {
345            Some((path, _)) => path,
346            None => &self.path,
347        }
348    }
349
350    /// retrieves the combined path and any query
351    pub fn path_and_query(&self) -> &str {
352        &self.path
353    }
354
355    /// retrieves the query component of the path, or an empty &str
356    ///
357    /// ```
358    /// # use trillium_testing::HttpTest;
359    /// let server = HttpTest::new(|conn| async move {
360    ///     let querystring = conn.querystring().to_string();
361    ///     conn.with_response_body(querystring).with_status(200)
362    /// });
363    ///
364    /// server
365    ///     .get("/some/path?and&a=query")
366    ///     .block()
367    ///     .assert_body("and&a=query");
368    ///
369    /// server.get("/some/path").block().assert_body("");
370    /// ```
371    pub fn querystring(&self) -> &str {
372        self.path
373            .split_once('?')
374            .map(|(_, query)| query)
375            .unwrap_or_default()
376    }
377
378    /// get the host for this conn, if it exists
379    pub fn host(&self) -> Option<&str> {
380        self.request_headers.get_str(Host)
381    }
382
383    /// set the host for this conn
384    pub fn set_host(&mut self, host: String) -> &mut Self {
385        self.request_headers.insert(Host, host);
386        self
387    }
388
389    /// Cancels and drops the future if reading from the transport results in an error or empty read
390    ///
391    /// The use of this method is not advised if your connected http client employs pipelining
392    /// (rarely seen in the wild), as it will buffer an unbounded number of requests one byte at a
393    /// time
394    ///
395    /// If the client disconnects from the conn's transport, this function will return None. If the
396    /// future completes without disconnection, this future will return Some containing the output
397    /// of the future.
398    ///
399    /// Note that the inner future cannot borrow conn, so you will need to clone or take any
400    /// information needed to execute the future prior to executing this method.
401    ///
402    /// # Example
403    ///
404    /// ```rust
405    /// # use futures_lite::{AsyncRead, AsyncWrite};
406    /// # use trillium_http::{Conn, Method};
407    /// async fn something_slow_and_cancel_safe() -> String {
408    ///     String::from("this was not actually slow")
409    /// }
410    /// async fn handler<T>(mut conn: Conn<T>) -> Conn<T>
411    /// where
412    ///     T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
413    /// {
414    ///     let Some(returned_body) = conn
415    ///         .cancel_on_disconnect(async { something_slow_and_cancel_safe().await })
416    ///         .await
417    ///     else {
418    ///         return conn;
419    ///     };
420    ///     conn.with_response_body(returned_body).with_status(200)
421    /// }
422    /// ```
423    pub async fn cancel_on_disconnect<'a, Fut>(&'a mut self, fut: Fut) -> Option<Fut::Output>
424    where
425        Fut: Future + Send + 'a,
426    {
427        CancelOnDisconnect(self, pin!(fut)).await
428    }
429
430    /// Check if the transport is connected by attempting to read from the transport
431    ///
432    /// # Example
433    ///
434    /// This is best to use at appropriate points in a long-running handler, like:
435    ///
436    /// ```rust
437    /// # use futures_lite::{AsyncRead, AsyncWrite};
438    /// # use trillium_http::{Conn, Method};
439    /// # async fn something_slow_but_not_cancel_safe() {}
440    /// async fn handler<T>(mut conn: Conn<T>) -> Conn<T>
441    /// where
442    ///     T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
443    /// {
444    ///     for _ in 0..100 {
445    ///         if conn.is_disconnected().await {
446    ///             return conn;
447    ///         }
448    ///         something_slow_but_not_cancel_safe().await;
449    ///     }
450    ///     conn.with_status(200)
451    /// }
452    /// ```
453    pub async fn is_disconnected(&mut self) -> bool {
454        future::poll_once(LivenessFut::new(self)).await.is_some()
455    }
456
457    /// returns the [`encoding_rs::Encoding`] for this request, as determined from the mime-type
458    /// charset, if available
459    ///
460    /// ```
461    /// # use trillium_testing::HttpTest;
462    /// HttpTest::new(|mut conn| async move {
463    ///     assert_eq!(conn.request_encoding(), encoding_rs::WINDOWS_1252); // the default
464    ///
465    ///     conn.request_headers_mut()
466    ///         .insert("content-type", "text/plain;charset=utf-16");
467    ///     assert_eq!(conn.request_encoding(), encoding_rs::UTF_16LE);
468    ///
469    ///     conn.with_status(200)
470    /// })
471    /// .get("/")
472    /// .block()
473    /// .assert_ok();
474    /// ```
475    pub fn request_encoding(&self) -> &'static Encoding {
476        encoding(&self.request_headers)
477    }
478
479    /// returns the [`encoding_rs::Encoding`] for this response, as
480    /// determined from the mime-type charset, if available
481    ///
482    /// ```
483    /// # use trillium_testing::HttpTest;
484    /// HttpTest::new(|mut conn| async move {
485    ///     assert_eq!(conn.response_encoding(), encoding_rs::WINDOWS_1252); // the default
486    ///     conn.response_headers_mut()
487    ///         .insert("content-type", "text/plain;charset=utf-16");
488    ///
489    ///     assert_eq!(conn.response_encoding(), encoding_rs::UTF_16LE);
490    ///
491    ///     conn.with_status(200)
492    /// })
493    /// .get("/")
494    /// .block()
495    /// .assert_ok();
496    /// ```
497    pub fn response_encoding(&self) -> &'static Encoding {
498        encoding(&self.response_headers)
499    }
500
501    /// returns a [`ReceivedBody`] that references this conn. the conn
502    /// retains all data and holds the singular transport, but the
503    /// `ReceivedBody` provides an interface to read body content.
504    ///
505    /// If the request included an `Expect: 100-continue` header, the 100 Continue response is sent
506    /// lazily on the first read from the returned [`ReceivedBody`].
507    /// ```
508    /// # use trillium_testing::HttpTest;
509    /// let server = HttpTest::new(|mut conn| async move {
510    ///     let request_body = conn.request_body();
511    ///     assert_eq!(request_body.content_length(), Some(5));
512    ///     assert_eq!(request_body.read_string().await.unwrap(), "hello");
513    ///     conn.with_status(200)
514    /// });
515    ///
516    /// server.post("/").with_body("hello").block().assert_ok();
517    /// ```
518    pub fn request_body(&mut self) -> ReceivedBody<'_, Transport> {
519        let needs_100_continue = self.needs_100_continue();
520        let body = self.build_request_body();
521        if needs_100_continue {
522            body.with_send_100_continue()
523        } else {
524            body
525        }
526    }
527
528    /// returns a clone of the [`swansong::Swansong`] for this Conn. use
529    /// this to gracefully stop long-running futures and streams
530    /// inside of handler functions
531    pub fn swansong(&self) -> Swansong {
532        self.protocol_session
533            .h3_connection()
534            .map_or_else(|| self.context.swansong.clone(), |h| h.swansong().clone())
535    }
536
537    /// Registers a function to call after the http response has been
538    /// completely transferred.
539    ///
540    /// The callback is guaranteed to fire **exactly once** before the conn is
541    /// dropped. Either the codec's send path invokes it with the real outcome,
542    /// or — if the conn is dropped before send completes (handler panic,
543    /// transport error, mid-write disconnect) — the drop fallback invokes it
544    /// with a `SendStatus` whose `is_success()` returns false. Multiple
545    /// registrations on the same conn chain in registration order.
546    ///
547    /// Because firing is ordered by send-completion rather than handler return,
548    /// this is the right hook for instrumentation that wants to report what the
549    /// peer actually observed.
550    ///
551    /// This is a sync function and should be computationally lightweight. If
552    /// your _application_ needs additional async processing, use your runtime's
553    /// task spawn within this hook. If your _library_ needs additional async
554    /// processing in an `after_send` hook, please open an issue.
555    pub fn after_send<F>(&mut self, after_send: F)
556    where
557        F: FnOnce(SendStatus) + Send + Sync + 'static,
558    {
559        self.after_send.append(after_send);
560    }
561
562    /// applies a mapping function from one transport to another. This
563    /// is particularly useful for boxing the transport. unless you're
564    /// sure this is what you're looking for, you probably don't want
565    /// to be using this
566    pub fn map_transport<NewTransport>(
567        self,
568        f: impl Fn(Transport) -> NewTransport,
569    ) -> Conn<NewTransport>
570    where
571        NewTransport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
572    {
573        // Manual respread: rustc treats `Conn<Transport>` and `Conn<NewTransport>` as
574        // disjoint types and rejects `..self` without the unstable
575        // `type_changing_struct_update` feature. If a new field is added to `Conn`,
576        // update this respread, `Upgrade::map_transport`, and `From<Conn> for Upgrade`
577        // (`upgrade.rs`) — they share this drift hazard.
578        Conn {
579            context: self.context,
580            request_headers: self.request_headers,
581            response_headers: self.response_headers,
582            method: self.method,
583            response_body: self.response_body,
584            path: self.path,
585            status: self.status,
586            version: self.version,
587            state: self.state,
588            transport: f(self.transport),
589            buffer: self.buffer,
590            request_body_state: self.request_body_state,
591            secure: self.secure,
592            after_send: self.after_send,
593            start_time: self.start_time,
594            peer_ip: self.peer_ip,
595            authority: self.authority,
596            scheme: self.scheme,
597            protocol: self.protocol,
598            protocol_session: self.protocol_session,
599            request_trailers: self.request_trailers,
600        }
601    }
602
603    /// whether this conn is suitable for an http upgrade to another protocol
604    pub fn should_upgrade(&self) -> bool {
605        (self.method() == Method::Connect && self.status == Some(Status::Ok))
606            || self.status == Some(Status::SwitchingProtocols)
607    }
608
609    #[doc(hidden)]
610    pub fn finalize_headers(&mut self) {
611        if self.version == Version::Http3 {
612            self.finalize_response_headers_h3();
613        } else {
614            self.finalize_response_headers_1x();
615        }
616    }
617
618    /// the [`H2Connection`] driver for this conn, if this is an HTTP/2 request
619    pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
620        self.protocol_session.h2_connection()
621    }
622
623    /// the h2 stream id for this conn, if this is an HTTP/2 request
624    pub fn h2_stream_id(&self) -> Option<u32> {
625        self.protocol_session.h2_stream_id()
626    }
627
628    /// the [`H3Connection`] driver for this conn, if this is an HTTP/3 request
629    pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
630        self.protocol_session.h3_connection()
631    }
632
633    /// the h3 stream id for this conn, if this is an HTTP/3 request
634    pub fn h3_stream_id(&self) -> Option<u64> {
635        self.protocol_session.h3_stream_id()
636    }
637}