trillium_server_common/h3/
web_transport.rs1use crate::quic::{BoxedBidiStream, BoxedRecvStream};
4use std::{
5 any::Any,
6 fmt::{self, Debug},
7 sync::{Arc, RwLock},
8};
9
/// A WebTransport stream together with the session it belongs to, as handed to a
/// [`WebTransportDispatch`] handler.
///
/// `#[fieldwork(get)]` derives getter methods for the variant fields.
#[derive(fieldwork::Fieldwork)]
#[fieldwork(get)]
pub enum WebTransportStream {
    /// A bidirectional stream.
    Bidi {
        /// Identifier of the WebTransport session this stream belongs to
        /// (presumably the session's CONNECT stream id — TODO confirm).
        session_id: u64,
        /// The underlying bidirectional stream.
        stream: BoxedBidiStream,
        /// Bytes carried alongside the stream; presumably data already read past
        /// the stream header — NOTE(review): confirm against the producer.
        buffer: Vec<u8>,
    },
    /// A unidirectional stream (wrapping a `BoxedRecvStream`).
    Uni {
        /// Identifier of the WebTransport session this stream belongs to
        /// (presumably the session's CONNECT stream id — TODO confirm).
        session_id: u64,
        /// The underlying receive stream.
        stream: BoxedRecvStream,
        /// Bytes carried alongside the stream; presumably data already read past
        /// the stream header — NOTE(review): confirm against the producer.
        buffer: Vec<u8>,
    },
}
34
35impl Debug for WebTransportStream {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 match self {
38 Self::Bidi { session_id, .. } => f
39 .debug_struct("WebTransportStream::Bidi")
40 .field("session_id", session_id)
41 .finish_non_exhaustive(),
42 Self::Uni { session_id, .. } => f
43 .debug_struct("WebTransportStream::Uni")
44 .field("session_id", session_id)
45 .finish_non_exhaustive(),
46 }
47 }
48}
49
/// Receiver for [`WebTransportStream`]s routed through a [`WebTransportDispatcher`].
///
/// The `Any` supertrait lets the dispatcher recover the concrete handler type
/// from the type-erased `Arc<dyn WebTransportDispatch>` it stores.
pub trait WebTransportDispatch: Any + Send + Sync {
    /// Handles one stream; the handler takes ownership of it.
    fn dispatch(&self, stream: WebTransportStream);
}
58
/// Internal state of a [`WebTransportDispatcher`].
enum DispatchState {
    /// No handler installed yet: streams are queued here and handed, in the
    /// order they were queued, to the handler that `get_or_init_with` installs.
    Buffering(Vec<WebTransportStream>),

    /// A handler is installed: streams are passed to it directly.
    Active(Arc<dyn WebTransportDispatch>),
}
67
/// Routes WebTransport streams to a single, lazily installed handler.
///
/// Until a handler is installed via `get_or_init_with`, streams passed to
/// `dispatch` are buffered; afterwards they are forwarded to the handler.
/// The state lives behind an `Arc`, so clones are cheap and share that state.
#[derive(Clone)]
pub struct WebTransportDispatcher(Arc<RwLock<DispatchState>>);
78impl Default for WebTransportDispatcher {
79 fn default() -> Self {
80 Self(Arc::new(RwLock::new(DispatchState::Buffering(Vec::new()))))
81 }
82}
83
84impl Debug for WebTransportDispatcher {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 let state = self.0.read().expect("dispatcher lock poisoned");
87 let label = match &*state {
88 DispatchState::Buffering(buf) => {
89 format!("Buffering({} streams)", buf.len())
90 }
91 DispatchState::Active(_) => "Active".to_string(),
92 };
93 f.debug_tuple("WebTransportDispatcher")
94 .field(&label)
95 .finish()
96 }
97}
98
99impl WebTransportDispatcher {
100 pub fn new() -> Self {
102 Self::default()
103 }
104
105 pub fn dispatch(&self, stream: WebTransportStream) {
107 {
109 let state = self.0.read().expect("dispatcher lock poisoned");
110 if let DispatchState::Active(handler) = &*state {
111 handler.dispatch(stream);
112 return;
113 }
114 }
115
116 {
118 let mut state = self.0.write().expect("dispatcher lock poisoned");
119 match &*state {
120 DispatchState::Buffering(_) => {
121 let DispatchState::Buffering(buf) = &mut *state else {
122 unreachable!()
123 };
124 buf.push(stream);
125 }
126 DispatchState::Active(handler) => handler.dispatch(stream),
127 }
128 }
129 }
130
131 pub fn get_or_init_with<T: WebTransportDispatch>(
141 &self,
142 init: impl FnOnce() -> T,
143 ) -> Option<Arc<T>> {
144 {
146 let state = self.0.read().expect("dispatcher lock poisoned");
147 if let DispatchState::Active(handler) = &*state {
148 return downcast_arc(handler.clone());
149 }
150 }
151
152 let mut state = self.0.write().expect("dispatcher lock poisoned");
154 match &*state {
155 DispatchState::Active(handler) => downcast_arc(handler.clone()),
156 DispatchState::Buffering(_) => {
157 let handler = Arc::new(init());
158 let buffered = std::mem::replace(
159 &mut *state,
160 DispatchState::Active(handler.clone() as Arc<dyn WebTransportDispatch>),
161 );
162 let DispatchState::Buffering(buffered) = buffered else {
163 unreachable!()
164 };
165 drop(state);
166
167 for stream in buffered {
168 handler.dispatch(stream);
169 }
170
171 Some(handler)
172 }
173 }
174 }
175}
176
177fn downcast_arc<T: Any + Send + Sync>(arc: Arc<dyn WebTransportDispatch>) -> Option<Arc<T>> {
178 let any: Arc<dyn Any + Send + Sync> = arc;
179 any.downcast::<T>().ok()
180}