Skip to main content

trillium_webtransport/
lib.rs

1//! WebTransport support for Trillium.
2//!
3//! This crate provides a [`WebTransport`] handler that accepts WebTransport sessions over
4//! HTTP/3, and a [`WebTransportConnection`] handle for sending and receiving streams and
5//! datagrams within each session.
6//!
7//! WebTransport requires an HTTP/3-capable server adapter configured with a QUIC endpoint
8//! and TLS.
9//!
10//! # Client
11//!
12//! [`WebTransportConnection`] is also the type returned by `trillium_client`'s
13//! `Client::webtransport(url).into_webtransport().await` (cargo feature `webtransport` on
14//! [trillium-client](https://docs.rs/trillium-client)). The client API is symmetric with the
15//! server: the same `accept_bidi` / `accept_uni` / `recv_datagram` / `open_bidi` /
16//! `open_uni` / `send_datagram` methods work on either side. Multiple client sessions to the
17//! same origin coalesce onto a single underlying QUIC connection.
18
// Pull README.md in as documentation so its code samples are compiled and run as doctests,
// keeping the README in sync with the crate's API. Rustdoc sets `cfg(doctest)` (not
// `cfg(test)`) while collecting doctests, so this module must be gated on `doctest` for
// the README examples to actually be picked up.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
22
23mod session_router;
24mod stream;
25
26/// Internal multiplexing primitives shared with `trillium-client` to coalesce multiple
27/// WebTransport sessions over a single QUIC connection. Not part of the stable API; users
28/// should not depend on these.
29#[doc(hidden)]
30pub use crate::session_router::{Router, SessionRouter};
31pub use crate::stream::{
32    Datagram, InboundBidiStream, InboundStream, InboundUniStream, OutboundBidiStream,
33    OutboundUniStream,
34};
35use async_channel::Receiver;
36use futures_lite::AsyncWriteExt;
37use std::{
38    borrow::Cow,
39    io,
40    net::SocketAddr,
41    sync::{Arc, OnceLock},
42};
43use swansong::Swansong;
44use trillium::{Conn, Handler, Info, Method, Status, Transport, Upgrade};
45use trillium_http::{
46    Headers, TypeSet,
47    h3::{H3Connection, quic_varint},
48};
49use trillium_server_common::{
50    QuicConnection, Runtime,
51    h3::{StreamId, web_transport::WebTransportDispatcher},
52};
53
54/// A handle to an active WebTransport session.
55///
56/// Passed to your [`WebTransportHandler`] when a client opens a WebTransport session.
57/// Use it to accept streams from the client, open server-initiated streams, and exchange
58/// datagrams.
59pub struct WebTransportConnection {
60    session_id: u64,
61    bidi_rx: Receiver<InboundBidiStream>,
62    uni_rx: Receiver<InboundUniStream>,
63    datagram_rx: Receiver<Datagram>,
64    swansong: Swansong,
65    request_headers: Headers,
66    response_headers: Headers,
67    state: TypeSet,
68    path: Option<Cow<'static, str>>,
69    authority: Option<Cow<'static, str>>,
70    h3_connection: Arc<H3Connection>,
71    quic_connection: QuicConnection,
72    runtime: Runtime,
73}
74
75impl WebTransportConnection {
76    /// Construct a new `WebTransportConnection`. Internal API used by trillium-client and the
77    /// trillium-webtransport server handler to assemble a session from its parts. Not for
78    /// downstream use.
79    #[doc(hidden)]
80    #[allow(clippy::too_many_arguments)]
81    pub fn new(
82        session_id: u64,
83        bidi_rx: Receiver<InboundBidiStream>,
84        uni_rx: Receiver<InboundUniStream>,
85        datagram_rx: Receiver<Datagram>,
86        swansong: Swansong,
87        request_headers: Headers,
88        response_headers: Headers,
89        state: TypeSet,
90        path: Option<Cow<'static, str>>,
91        authority: Option<Cow<'static, str>>,
92        h3_connection: Arc<H3Connection>,
93        quic_connection: QuicConnection,
94        runtime: Runtime,
95    ) -> Self {
96        Self {
97            session_id,
98            bidi_rx,
99            uni_rx,
100            datagram_rx,
101            swansong,
102            request_headers,
103            response_headers,
104            state,
105            path,
106            authority,
107            h3_connection,
108            quic_connection,
109            runtime,
110        }
111    }
112
113    /// Accept the next inbound bidirectional stream for this session.
114    ///
115    /// Returns `None` when the session is shutting down or the QUIC connection has closed.
116    pub async fn accept_bidi(&self) -> Option<InboundBidiStream> {
117        self.swansong.interrupt(self.bidi_rx.recv()).await?.ok()
118    }
119
120    /// Returns the async runtime for this server.
121    pub fn runtime(&self) -> &Runtime {
122        &self.runtime
123    }
124
125    /// Returns the underlying HTTP/3 connection.
126    pub fn h3_connection(&self) -> &H3Connection {
127        &self.h3_connection
128    }
129
130    /// The headers from the CONNECT request that established this WebTransport session.
131    pub fn request_headers(&self) -> &Headers {
132        &self.request_headers
133    }
134
135    /// Mutably borrow the CONNECT request headers.
136    pub fn request_headers_mut(&mut self) -> &mut Headers {
137        &mut self.request_headers
138    }
139
140    /// The headers from the CONNECT response that established this WebTransport session.
141    ///
142    /// On the client side, these are the headers the server sent alongside its `200` response
143    /// to the extended CONNECT (e.g. server identification, custom extension hints,
144    /// CDN-injected headers, `Set-Cookie`). On the server side this is empty — by the time the
145    /// handler runs, the response has already been sent.
146    pub fn response_headers(&self) -> &Headers {
147        &self.response_headers
148    }
149
150    /// Mutably borrow the CONNECT response headers.
151    pub fn response_headers_mut(&mut self) -> &mut Headers {
152        &mut self.response_headers
153    }
154
155    /// Borrow the [`TypeSet`] of state accumulated by the handler chain before the upgrade.
156    pub fn state(&self) -> &TypeSet {
157        &self.state
158    }
159
160    /// Mutably borrow the [`TypeSet`] of state.
161    pub fn state_mut(&mut self) -> &mut TypeSet {
162        &mut self.state
163    }
164
165    /// The `:path` of the CONNECT request that established this session, identifying which
166    /// WebTransport endpoint the peer asked for.
167    pub fn path(&self) -> Option<&str> {
168        self.path.as_deref()
169    }
170
171    /// The `:authority` of the CONNECT request that established this session.
172    pub fn authority(&self) -> Option<&str> {
173        self.authority.as_deref()
174    }
175
176    /// The peer's socket address.
177    pub fn peer_addr(&self) -> SocketAddr {
178        self.quic_connection.remote_address()
179    }
180
181    /// Accept the next inbound unidirectional stream for this session.
182    ///
183    /// Returns `None` when the session is shutting down or the QUIC connection has closed.
184    pub async fn accept_uni(&self) -> Option<InboundUniStream> {
185        self.swansong.interrupt(self.uni_rx.recv()).await?.ok()
186    }
187
188    /// Receive the next datagram for this session.
189    ///
190    /// Returns `None` when the session is shutting down or the QUIC connection has closed.
191    pub async fn recv_datagram(&self) -> Option<Datagram> {
192        self.swansong.interrupt(self.datagram_rx.recv()).await?.ok()
193    }
194
195    /// Accept the next inbound stream for this session.
196    ///
197    /// Races the bidi and uni stream channels and returns whichever arrives first.
198    /// Returns `None` when the session ends.
199    ///
200    /// Datagrams are intentionally excluded — use [`recv_datagram`](Self::recv_datagram)
201    /// in a separate concurrent loop, as datagrams typically require lower latency
202    /// than stream acceptance.
203    pub async fn accept_next_stream(&self) -> Option<InboundStream> {
204        futures_lite::future::race(
205            async { self.accept_bidi().await.map(InboundStream::Bidi) },
206            async { self.accept_uni().await.map(InboundStream::Uni) },
207        )
208        .await
209    }
210
211    /// Send an unreliable datagram to the client.
212    ///
213    /// Returns an error if the QUIC connection does not support datagrams or the payload is
214    /// too large.
215    pub fn send_datagram(&self, payload: &[u8]) -> io::Result<()> {
216        let quarter_id = self.session_id / 4;
217        let header_len = quic_varint::encoded_len(quarter_id);
218        let mut buf = vec![0u8; header_len + payload.len()];
219        quic_varint::encode(quarter_id, &mut buf).unwrap();
220        buf[header_len..].copy_from_slice(payload);
221        self.quic_connection.send_datagram(&buf)
222    }
223
224    /// Open a new server-initiated bidirectional stream for this session.
225    pub async fn open_bidi(&self) -> io::Result<OutboundBidiStream> {
226        let (_stream_id, mut transport) = self.quic_connection.open_bidi().await?;
227        transport
228            .write_all(&wt_bidi_header(self.session_id))
229            .await?;
230        Ok(OutboundBidiStream::new(transport))
231    }
232
233    /// Open a new server-initiated unidirectional stream for this session.
234    pub async fn open_uni(&self) -> io::Result<OutboundUniStream> {
235        let (_stream_id, mut stream) = self.quic_connection.open_uni().await?;
236        stream.write_all(&wt_uni_header(self.session_id)).await?;
237        Ok(OutboundUniStream::new(stream))
238    }
239}
240
241/// Encode the bidi stream header: signal value 0x41 + session_id.
242fn wt_bidi_header(session_id: u64) -> Vec<u8> {
243    let mut buf =
244        vec![0u8; quic_varint::encoded_len(0x41u64) + quic_varint::encoded_len(session_id)];
245    let mut offset = quic_varint::encode(0x41u64, &mut buf).unwrap();
246    offset += quic_varint::encode(session_id, &mut buf[offset..]).unwrap();
247    buf.truncate(offset);
248    buf
249}
250
251/// Encode the uni stream header: stream type 0x54 + session_id.
252fn wt_uni_header(session_id: u64) -> Vec<u8> {
253    let mut buf =
254        vec![0u8; quic_varint::encoded_len(0x54u64) + quic_varint::encoded_len(session_id)];
255    let mut offset = quic_varint::encode(0x54u64, &mut buf).unwrap();
256    offset += quic_varint::encode(session_id, &mut buf[offset..]).unwrap();
257    buf.truncate(offset);
258    buf
259}
260
/// Number of datagrams buffered per session unless configured otherwise: 16.
///
/// Override it with [`WebTransport::with_max_datagram_buffer`], whose documentation also
/// describes what happens when the buffer fills up.
pub const DEFAULT_MAX_DATAGRAM_BUFFER: usize = 16;
264
265/// A Trillium [`Handler`] that accepts WebTransport sessions.
266///
267/// Add this to your handler chain and provide a [`WebTransportHandler`] (or a closure) to
268/// process each session.
269///
270/// # Example
271///
272/// ```no_run
273/// use trillium_webtransport::{WebTransport, WebTransportConnection};
274///
275/// let handler = WebTransport::new(|conn: WebTransportConnection| async move {
276///     while let Some(stream) = conn.accept_next_stream().await {
277///         // handle stream...
278/// # drop(stream);
279///     }
280/// });
281/// ```
282pub struct WebTransport<H> {
283    runtime: OnceLock<Runtime>,
284    max_datagram_buffer: usize,
285    handler: H,
286}
287
288/// A handler for WebTransport sessions.
289///
290/// Any `Fn(WebTransportConnection) -> impl Future<Output = ()>` automatically implements this
291/// trait, so you can pass a closure or async function directly to [`WebTransport::new`].
292pub trait WebTransportHandler: Send + Sync + 'static {
293    /// Handle a WebTransport session. Called once per client-initiated session.
294    fn run(
295        &self,
296        web_transport_connection: WebTransportConnection,
297    ) -> impl Future<Output = ()> + Send;
298}
299
300impl<Fun, Fut> WebTransportHandler for Fun
301where
302    Fun: Fn(WebTransportConnection) -> Fut + Send + Sync + 'static,
303    Fut: Future<Output = ()> + Send,
304{
305    async fn run(&self, web_transport_connection: WebTransportConnection) {
306        self(web_transport_connection).await
307    }
308}
309
310impl<H> WebTransport<H>
311where
312    H: WebTransportHandler,
313{
314    /// Create a new `WebTransport` handler that passes each session to `handler`.
315    pub fn new(handler: H) -> Self {
316        Self {
317            handler,
318            runtime: Default::default(),
319            max_datagram_buffer: DEFAULT_MAX_DATAGRAM_BUFFER,
320        }
321    }
322
323    /// Set the maximum number of datagrams to buffer per session.
324    ///
325    /// When the buffer is full, the oldest datagram is dropped to make room for the newest.
326    ///
327    /// - **`max > 1`** — FIFO ring-buffer that tolerates bursts up to `max` datagrams before
328    ///   dropping. Good for ordered event streams where some loss is acceptable.
329    /// - **`max = 1`** — "latest-only" semantics: if multiple datagrams arrive while your
330    ///   [`recv_datagram`](WebTransportConnection::recv_datagram) loop is busy, only the most
331    ///   recent is retained. Good for streaming state (positions, sensor readings) where older
332    ///   values are invalidated by newer ones.
333    ///
334    /// Default: 16.
335    pub fn with_max_datagram_buffer(mut self, max: usize) -> Self {
336        self.max_datagram_buffer = max;
337        self
338    }
339
340    fn runtime(&self) -> &Runtime {
341        self.runtime.get().unwrap()
342    }
343}
344
/// Marker that [`WebTransport`]'s `run` stores in a conn's state to flag it for the
/// WebTransport upgrade; `has_upgrade` checks for its presence.
struct WTUpgrade;
346
347impl<H> Handler for WebTransport<H>
348where
349    H: WebTransportHandler,
350{
351    async fn run(&self, conn: Conn) -> Conn {
352        let inner: &trillium_http::Conn<Box<dyn Transport>> = conn.as_ref();
353        if inner.state().contains::<QuicConnection>()
354            && conn.method() == Method::Connect
355            && inner.protocol() == Some("webtransport")
356        {
357            conn.with_state(WTUpgrade).with_status(Status::Ok).halt()
358        } else {
359            conn
360        }
361    }
362
363    async fn init(&mut self, info: &mut Info) {
364        self.runtime.get_or_init(|| {
365            info.shared_state::<Runtime>()
366                .cloned()
367                .expect("webtransport requires a Runtime")
368        });
369
370        info.config_mut()
371            .set_h3_datagrams_enabled(true)
372            .set_webtransport_enabled(true)
373            .set_extended_connect_enabled(true);
374    }
375
376    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
377        upgrade.state().get::<WTUpgrade>().is_some()
378    }
379
380    async fn upgrade(&self, mut upgrade: Upgrade) {
381        let Some(h3_connection) = upgrade.h3_connection() else {
382            log::error!("missing H3Connection in upgrade state");
383            return;
384        };
385        let Some(quic_connection) = upgrade.state_mut().take::<QuicConnection>() else {
386            log::error!("missing QuicConnection in upgrade state");
387            return;
388        };
389        let Some(stream_id) = upgrade.state_mut().take::<StreamId>() else {
390            log::error!("missing StreamId in upgrade state");
391            return;
392        };
393        let Some(dispatcher) = upgrade.state().get::<WebTransportDispatcher>().cloned() else {
394            log::error!("missing WebTransportDispatcher in upgrade state");
395            return;
396        };
397
398        let max_datagram_buffer = self.max_datagram_buffer;
399        let Some(router) = dispatcher.get_or_init_with(|| Router::new(max_datagram_buffer)) else {
400            log::error!("WebTransportDispatcher has a handler of an unexpected type");
401            return;
402        };
403
404        // No-op if a previous session on this connection already started the task.
405        router
406            .clone()
407            .spawn_routing_task(quic_connection.clone(), self.runtime().clone());
408
409        let session_id = stream_id.into();
410        log::trace!("starting webtransport session {session_id}");
411        let session_swansong = h3_connection.swansong().child();
412        let (bidi_rx, uni_rx, datagram_rx) = router.sessions().lock().await.register(session_id);
413
414        let runtime = self.runtime().clone();
415
416        let inner = upgrade.as_mut();
417        let request_headers = std::mem::take(inner.request_headers_mut());
418        let response_headers = std::mem::take(inner.response_headers_mut());
419        let state = std::mem::take(inner.state_mut());
420        let authority = inner.take_authority();
421        let path = Some(std::mem::take(inner.path_mut()));
422
423        self.handler
424            .run(WebTransportConnection {
425                session_id,
426                bidi_rx,
427                uni_rx,
428                datagram_rx,
429                swansong: session_swansong.clone(),
430                request_headers,
431                response_headers,
432                state,
433                path,
434                authority,
435                h3_connection,
436                quic_connection,
437                runtime,
438            })
439            .await;
440
441        log::trace!("finished handler, cleaning up");
442
443        session_swansong.shut_down().await;
444        router.sessions().lock().await.unregister(session_id);
445    }
446}