// Empty module that exists only in test builds; its sole purpose is to carry the
// crate README (pulled in at compile time) as its documentation.
#[cfg(test)]
#[doc = include_str!("../README.md")]
mod readme {}
22
23mod session_router;
24mod stream;
25
26#[doc(hidden)]
30pub use crate::session_router::{Router, SessionRouter};
31pub use crate::stream::{
32 Datagram, InboundBidiStream, InboundStream, InboundUniStream, OutboundBidiStream,
33 OutboundUniStream,
34};
35use async_channel::Receiver;
36use futures_lite::AsyncWriteExt;
37use std::{
38 borrow::Cow,
39 io,
40 net::SocketAddr,
41 sync::{Arc, OnceLock},
42};
43use swansong::Swansong;
44use trillium::{Conn, Handler, Info, Method, Status, Transport, Upgrade};
45use trillium_http::{
46 Headers, TypeSet,
47 h3::{H3Connection, quic_varint},
48};
49use trillium_server_common::{
50 QuicConnection, Runtime,
51 h3::{StreamId, web_transport::WebTransportDispatcher},
52};
53
54pub struct WebTransportConnection {
60 session_id: u64,
61 bidi_rx: Receiver<InboundBidiStream>,
62 uni_rx: Receiver<InboundUniStream>,
63 datagram_rx: Receiver<Datagram>,
64 swansong: Swansong,
65 request_headers: Headers,
66 response_headers: Headers,
67 state: TypeSet,
68 path: Option<Cow<'static, str>>,
69 authority: Option<Cow<'static, str>>,
70 h3_connection: Arc<H3Connection>,
71 quic_connection: QuicConnection,
72 runtime: Runtime,
73}
74
75impl WebTransportConnection {
76 #[doc(hidden)]
80 #[allow(clippy::too_many_arguments)]
81 pub fn new(
82 session_id: u64,
83 bidi_rx: Receiver<InboundBidiStream>,
84 uni_rx: Receiver<InboundUniStream>,
85 datagram_rx: Receiver<Datagram>,
86 swansong: Swansong,
87 request_headers: Headers,
88 response_headers: Headers,
89 state: TypeSet,
90 path: Option<Cow<'static, str>>,
91 authority: Option<Cow<'static, str>>,
92 h3_connection: Arc<H3Connection>,
93 quic_connection: QuicConnection,
94 runtime: Runtime,
95 ) -> Self {
96 Self {
97 session_id,
98 bidi_rx,
99 uni_rx,
100 datagram_rx,
101 swansong,
102 request_headers,
103 response_headers,
104 state,
105 path,
106 authority,
107 h3_connection,
108 quic_connection,
109 runtime,
110 }
111 }
112
113 pub async fn accept_bidi(&self) -> Option<InboundBidiStream> {
117 self.swansong.interrupt(self.bidi_rx.recv()).await?.ok()
118 }
119
120 pub fn runtime(&self) -> &Runtime {
122 &self.runtime
123 }
124
125 pub fn h3_connection(&self) -> &H3Connection {
127 &self.h3_connection
128 }
129
130 pub fn request_headers(&self) -> &Headers {
132 &self.request_headers
133 }
134
135 pub fn request_headers_mut(&mut self) -> &mut Headers {
137 &mut self.request_headers
138 }
139
140 pub fn response_headers(&self) -> &Headers {
147 &self.response_headers
148 }
149
150 pub fn response_headers_mut(&mut self) -> &mut Headers {
152 &mut self.response_headers
153 }
154
155 pub fn state(&self) -> &TypeSet {
157 &self.state
158 }
159
160 pub fn state_mut(&mut self) -> &mut TypeSet {
162 &mut self.state
163 }
164
165 pub fn path(&self) -> Option<&str> {
168 self.path.as_deref()
169 }
170
171 pub fn authority(&self) -> Option<&str> {
173 self.authority.as_deref()
174 }
175
176 pub fn peer_addr(&self) -> SocketAddr {
178 self.quic_connection.remote_address()
179 }
180
181 pub async fn accept_uni(&self) -> Option<InboundUniStream> {
185 self.swansong.interrupt(self.uni_rx.recv()).await?.ok()
186 }
187
188 pub async fn recv_datagram(&self) -> Option<Datagram> {
192 self.swansong.interrupt(self.datagram_rx.recv()).await?.ok()
193 }
194
195 pub async fn accept_next_stream(&self) -> Option<InboundStream> {
204 futures_lite::future::race(
205 async { self.accept_bidi().await.map(InboundStream::Bidi) },
206 async { self.accept_uni().await.map(InboundStream::Uni) },
207 )
208 .await
209 }
210
211 pub fn send_datagram(&self, payload: &[u8]) -> io::Result<()> {
216 let quarter_id = self.session_id / 4;
217 let header_len = quic_varint::encoded_len(quarter_id);
218 let mut buf = vec![0u8; header_len + payload.len()];
219 quic_varint::encode(quarter_id, &mut buf).unwrap();
220 buf[header_len..].copy_from_slice(payload);
221 self.quic_connection.send_datagram(&buf)
222 }
223
224 pub async fn open_bidi(&self) -> io::Result<OutboundBidiStream> {
226 let (_stream_id, mut transport) = self.quic_connection.open_bidi().await?;
227 transport
228 .write_all(&wt_bidi_header(self.session_id))
229 .await?;
230 Ok(OutboundBidiStream::new(transport))
231 }
232
233 pub async fn open_uni(&self) -> io::Result<OutboundUniStream> {
235 let (_stream_id, mut stream) = self.quic_connection.open_uni().await?;
236 stream.write_all(&wt_uni_header(self.session_id)).await?;
237 Ok(OutboundUniStream::new(stream))
238 }
239}
240
241fn wt_bidi_header(session_id: u64) -> Vec<u8> {
243 let mut buf =
244 vec![0u8; quic_varint::encoded_len(0x41u64) + quic_varint::encoded_len(session_id)];
245 let mut offset = quic_varint::encode(0x41u64, &mut buf).unwrap();
246 offset += quic_varint::encode(session_id, &mut buf[offset..]).unwrap();
247 buf.truncate(offset);
248 buf
249}
250
251fn wt_uni_header(session_id: u64) -> Vec<u8> {
253 let mut buf =
254 vec![0u8; quic_varint::encoded_len(0x54u64) + quic_varint::encoded_len(session_id)];
255 let mut offset = quic_varint::encode(0x54u64, &mut buf).unwrap();
256 offset += quic_varint::encode(session_id, &mut buf[offset..]).unwrap();
257 buf.truncate(offset);
258 buf
259}
260
/// Value [`WebTransport::new`] uses for its datagram buffer limit unless it is replaced
/// with [`WebTransport::with_max_datagram_buffer`].
pub const DEFAULT_MAX_DATAGRAM_BUFFER: usize = 16;
264
265pub struct WebTransport<H> {
283 runtime: OnceLock<Runtime>,
284 max_datagram_buffer: usize,
285 handler: H,
286}
287
288pub trait WebTransportHandler: Send + Sync + 'static {
293 fn run(
295 &self,
296 web_transport_connection: WebTransportConnection,
297 ) -> impl Future<Output = ()> + Send;
298}
299
300impl<Fun, Fut> WebTransportHandler for Fun
301where
302 Fun: Fn(WebTransportConnection) -> Fut + Send + Sync + 'static,
303 Fut: Future<Output = ()> + Send,
304{
305 async fn run(&self, web_transport_connection: WebTransportConnection) {
306 self(web_transport_connection).await
307 }
308}
309
310impl<H> WebTransport<H>
311where
312 H: WebTransportHandler,
313{
314 pub fn new(handler: H) -> Self {
316 Self {
317 handler,
318 runtime: Default::default(),
319 max_datagram_buffer: DEFAULT_MAX_DATAGRAM_BUFFER,
320 }
321 }
322
323 pub fn with_max_datagram_buffer(mut self, max: usize) -> Self {
336 self.max_datagram_buffer = max;
337 self
338 }
339
340 fn runtime(&self) -> &Runtime {
341 self.runtime.get().unwrap()
342 }
343}
344
/// Marker placed in conn state by `run` so that `has_upgrade` can recognize
/// connections this handler has claimed.
struct WTUpgrade;
346
347impl<H> Handler for WebTransport<H>
348where
349 H: WebTransportHandler,
350{
351 async fn run(&self, conn: Conn) -> Conn {
352 let inner: &trillium_http::Conn<Box<dyn Transport>> = conn.as_ref();
353 if inner.state().contains::<QuicConnection>()
354 && conn.method() == Method::Connect
355 && inner.protocol() == Some("webtransport")
356 {
357 conn.with_state(WTUpgrade).with_status(Status::Ok).halt()
358 } else {
359 conn
360 }
361 }
362
363 async fn init(&mut self, info: &mut Info) {
364 self.runtime.get_or_init(|| {
365 info.shared_state::<Runtime>()
366 .cloned()
367 .expect("webtransport requires a Runtime")
368 });
369
370 info.config_mut()
371 .set_h3_datagrams_enabled(true)
372 .set_webtransport_enabled(true)
373 .set_extended_connect_enabled(true);
374 }
375
376 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
377 upgrade.state().get::<WTUpgrade>().is_some()
378 }
379
380 async fn upgrade(&self, mut upgrade: Upgrade) {
381 let Some(h3_connection) = upgrade.h3_connection() else {
382 log::error!("missing H3Connection in upgrade state");
383 return;
384 };
385 let Some(quic_connection) = upgrade.state_mut().take::<QuicConnection>() else {
386 log::error!("missing QuicConnection in upgrade state");
387 return;
388 };
389 let Some(stream_id) = upgrade.state_mut().take::<StreamId>() else {
390 log::error!("missing StreamId in upgrade state");
391 return;
392 };
393 let Some(dispatcher) = upgrade.state().get::<WebTransportDispatcher>().cloned() else {
394 log::error!("missing WebTransportDispatcher in upgrade state");
395 return;
396 };
397
398 let max_datagram_buffer = self.max_datagram_buffer;
399 let Some(router) = dispatcher.get_or_init_with(|| Router::new(max_datagram_buffer)) else {
400 log::error!("WebTransportDispatcher has a handler of an unexpected type");
401 return;
402 };
403
404 router
406 .clone()
407 .spawn_routing_task(quic_connection.clone(), self.runtime().clone());
408
409 let session_id = stream_id.into();
410 log::trace!("starting webtransport session {session_id}");
411 let session_swansong = h3_connection.swansong().child();
412 let (bidi_rx, uni_rx, datagram_rx) = router.sessions().lock().await.register(session_id);
413
414 let runtime = self.runtime().clone();
415
416 let inner = upgrade.as_mut();
417 let request_headers = std::mem::take(inner.request_headers_mut());
418 let response_headers = std::mem::take(inner.response_headers_mut());
419 let state = std::mem::take(inner.state_mut());
420 let authority = inner.take_authority();
421 let path = Some(std::mem::take(inner.path_mut()));
422
423 self.handler
424 .run(WebTransportConnection {
425 session_id,
426 bidi_rx,
427 uni_rx,
428 datagram_rx,
429 swansong: session_swansong.clone(),
430 request_headers,
431 response_headers,
432 state,
433 path,
434 authority,
435 h3_connection,
436 quic_connection,
437 runtime,
438 })
439 .await;
440
441 log::trace!("finished handler, cleaning up");
442
443 session_swansong.shut_down().await;
444 router.sessions().lock().await.unregister(session_id);
445 }
446}