Skip to main content

trillium_server_common/h3/
web_transport.rs

1//! WebTransport types
2
3use crate::quic::{BoxedBidiStream, BoxedRecvStream};
4use std::{
5    any::Any,
6    fmt::{self, Debug},
7    sync::{Arc, RwLock},
8};
9
10/// An inbound WebTransport stream, dispatched by [`WebTransportDispatcher`] to the registered
11/// handler.
12#[derive(fieldwork::Fieldwork)]
13#[fieldwork(get)]
14pub enum WebTransportStream {
15    /// A bidirectional stream.
16    Bidi {
17        /// The WebTransport session ID (stream ID of the CONNECT request).
18        session_id: u64,
19        /// The stream transport, ready for application data.
20        stream: BoxedBidiStream,
21        /// Any bytes buffered after the session ID during stream negotiation.
22        buffer: Vec<u8>,
23    },
24    /// A unidirectional stream.
25    Uni {
26        /// The WebTransport session ID.
27        session_id: u64,
28        /// The receive stream, ready for application data.
29        stream: BoxedRecvStream,
30        /// Any bytes buffered after the session ID during stream negotiation.
31        buffer: Vec<u8>,
32    },
33}
34
35impl Debug for WebTransportStream {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        match self {
38            Self::Bidi { session_id, .. } => f
39                .debug_struct("WebTransportStream::Bidi")
40                .field("session_id", session_id)
41                .finish_non_exhaustive(),
42            Self::Uni { session_id, .. } => f
43                .debug_struct("WebTransportStream::Uni")
44                .field("session_id", session_id)
45                .finish_non_exhaustive(),
46        }
47    }
48}
49
50/// Trait for receiving dispatched WebTransport streams.
51///
52/// Implementors are registered with [`WebTransportDispatcher`] and receive each inbound stream
53/// via [`dispatch`](WebTransportDispatch::dispatch).
54pub trait WebTransportDispatch: Any + Send + Sync {
55    /// Handle an inbound WebTransport stream.
56    fn dispatch(&self, stream: WebTransportStream);
57}
58
59/// Routing state for inbound WebTransport streams on a single QUIC connection.
60enum DispatchState {
61    /// No handler registered yet. Early-arriving streams are buffered.
62    Buffering(Vec<WebTransportStream>),
63
64    /// A handler has been registered. Streams are dispatched directly.
65    Active(Arc<dyn WebTransportDispatch>),
66}
67
68/// Dispatcher for inbound WebTransport streams on a QUIC connection.
69///
70/// Bridges the QUIC connection handler, which delivers streams as they arrive, with WebTransport
71/// session handlers that register later via [`get_or_init_with`](Self::get_or_init_with).
72/// Streams that arrive before a handler registers are buffered and delivered when the handler
73/// registers.
74///
75/// Cheaply cloneable.
76#[derive(Clone)]
77pub struct WebTransportDispatcher(Arc<RwLock<DispatchState>>);
78impl Default for WebTransportDispatcher {
79    fn default() -> Self {
80        Self(Arc::new(RwLock::new(DispatchState::Buffering(Vec::new()))))
81    }
82}
83
84impl Debug for WebTransportDispatcher {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        let state = self.0.read().expect("dispatcher lock poisoned");
87        let label = match &*state {
88            DispatchState::Buffering(buf) => {
89                format!("Buffering({} streams)", buf.len())
90            }
91            DispatchState::Active(_) => "Active".to_string(),
92        };
93        f.debug_tuple("WebTransportDispatcher")
94            .field(&label)
95            .finish()
96    }
97}
98
99impl WebTransportDispatcher {
100    /// Create a new dispatcher in the buffering state.
101    pub fn new() -> Self {
102        Self::default()
103    }
104
105    /// Dispatch an inbound WebTransport stream to the registered handler, or buffer it.
106    pub fn dispatch(&self, stream: WebTransportStream) {
107        // Fast path: handler is registered, take a read lock.
108        {
109            let state = self.0.read().expect("dispatcher lock poisoned");
110            if let DispatchState::Active(handler) = &*state {
111                handler.dispatch(stream);
112                return;
113            }
114        }
115
116        // Slow path: still buffering, take a write lock.
117        {
118            let mut state = self.0.write().expect("dispatcher lock poisoned");
119            match &*state {
120                DispatchState::Buffering(_) => {
121                    let DispatchState::Buffering(buf) = &mut *state else {
122                        unreachable!()
123                    };
124                    buf.push(stream);
125                }
126                DispatchState::Active(handler) => handler.dispatch(stream),
127            }
128        }
129    }
130
131    /// Get or initialize the dispatch handler.
132    ///
133    /// If no handler is registered yet, calls `init` to create one, transitions from
134    /// buffering to active, and drains any buffered streams through the new handler.
135    ///
136    /// If a handler is already registered and its concrete type matches `T`, returns
137    /// a clone of the existing `Arc<T>`.
138    ///
139    /// Returns `None` if a handler is already registered but is a different concrete type.
140    pub fn get_or_init_with<T: WebTransportDispatch>(
141        &self,
142        init: impl FnOnce() -> T,
143    ) -> Option<Arc<T>> {
144        // Fast path: already active.
145        {
146            let state = self.0.read().expect("dispatcher lock poisoned");
147            if let DispatchState::Active(handler) = &*state {
148                return downcast_arc(handler.clone());
149            }
150        }
151
152        // Slow path: take write lock, initialize if still buffering.
153        let mut state = self.0.write().expect("dispatcher lock poisoned");
154        match &*state {
155            DispatchState::Active(handler) => downcast_arc(handler.clone()),
156            DispatchState::Buffering(_) => {
157                let handler = Arc::new(init());
158                let buffered = std::mem::replace(
159                    &mut *state,
160                    DispatchState::Active(handler.clone() as Arc<dyn WebTransportDispatch>),
161                );
162                let DispatchState::Buffering(buffered) = buffered else {
163                    unreachable!()
164                };
165                drop(state);
166
167                for stream in buffered {
168                    handler.dispatch(stream);
169                }
170
171                Some(handler)
172            }
173        }
174    }
175}
176
177fn downcast_arc<T: Any + Send + Sync>(arc: Arc<dyn WebTransportDispatch>) -> Option<Arc<T>> {
178    let any: Arc<dyn Any + Send + Sync> = arc;
179    any.downcast::<T>().ok()
180}