1pub mod web_transport;
4use crate::{
5 ArcHandler, ArcedQuicEndpoint, BoxedBidiStream, QuicConnection, QuicTransportReceive,
6 QuicTransportSend, RuntimeTrait,
7};
8use std::sync::Arc;
9use trillium::{Handler, KnownHeaderName, Listener, Upgrade};
10use trillium_http::{
11 HttpContext,
12 h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
13};
14use web_transport::{WebTransportDispatcher, WebTransportStream};
15
/// Identifier of the QUIC stream an HTTP/3 request arrived on.
///
/// A `StreamId` is inserted into the conn state for every inbound bidirectional
/// stream, and converts to and from the raw `u64` id without loss.
///
/// Equality, ordering and hashing are derived from the wrapped `u64`, so ids can
/// be compared directly and used as map or set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StreamId(u64);

impl From<StreamId> for u64 {
    /// Unwraps the raw stream id.
    fn from(val: StreamId) -> Self {
        val.0
    }
}

impl From<u64> for StreamId {
    /// Wraps a raw stream id.
    fn from(value: u64) -> Self {
        Self(value)
    }
}
30
31pub(crate) async fn run_h3(
32 quic_binding: ArcedQuicEndpoint,
33 context: Arc<HttpContext>,
34 handler: ArcHandler<impl Handler>,
35 runtime: impl RuntimeTrait,
36 listener: Option<Listener>,
37 local_alt_svc: Option<&'static str>,
38) {
39 let swansong = context.swansong();
40 while let Some(connection) = swansong.interrupt(quic_binding.accept()).await.flatten() {
41 let h3 = H3Connection::new(context.clone());
42 let handler = handler.clone();
43 let runtime = runtime.clone();
44 runtime.clone().spawn(run_h3_connection(
45 connection,
46 h3,
47 handler,
48 runtime,
49 listener.clone(),
50 local_alt_svc,
51 ));
52 }
53}
54
55async fn run_h3_connection(
56 connection: QuicConnection,
57 h3: Arc<H3Connection>,
58 handler: ArcHandler<impl Handler>,
59 runtime: impl RuntimeTrait,
60 listener: Option<Listener>,
61 local_alt_svc: Option<&'static str>,
62) {
63 let wt_dispatcher = h3
64 .context()
65 .config()
66 .webtransport_enabled()
67 .then(WebTransportDispatcher::new);
68
69 log::trace!("new quic connection from {}", connection.remote_address());
70
71 spawn_outbound_control_stream(&connection, &h3, &runtime);
72 spawn_qpack_encoder_stream(&connection, &h3, &runtime);
73 spawn_qpack_decoder_stream(&connection, &h3, &runtime);
74 spawn_inbound_uni_streams(&connection, &h3, &runtime, &wt_dispatcher);
75 handle_inbound_bidi_streams(
76 connection,
77 h3.clone(),
78 handler,
79 runtime,
80 wt_dispatcher,
81 listener,
82 local_alt_svc,
83 )
84 .await;
85}
86
87async fn handle_inbound_bidi_streams(
88 connection: QuicConnection,
89 h3: Arc<H3Connection>,
90 handler: ArcHandler<impl Handler>,
91 runtime: impl RuntimeTrait,
92 wt_dispatcher: Option<WebTransportDispatcher>,
93 listener: Option<Listener>,
94 local_alt_svc: Option<&'static str>,
95) {
96 loop {
97 match h3.swansong().interrupt(connection.accept_bidi()).await {
98 None => {
99 log::trace!("H3 bidi accept loop: interrupted by swansong shutdown");
100 break;
101 }
102 Some(Err(e)) => {
103 log::debug!("H3 bidi accept loop: accept_bidi error: {e}");
104 break;
105 }
106 Some(Ok((stream_id, transport))) => {
107 handle_bidi_stream(
108 stream_id,
109 transport,
110 &h3,
111 &handler,
112 &connection,
113 &runtime,
114 &wt_dispatcher,
115 listener.clone(),
116 local_alt_svc,
117 );
118 }
119 }
120 }
121
122 h3.shut_down();
123}
124
125#[allow(clippy::too_many_arguments)]
126fn handle_bidi_stream(
127 stream_id: u64,
128 transport: BoxedBidiStream,
129 h3: &Arc<H3Connection>,
130 handler: &ArcHandler<impl Handler>,
131 connection: &QuicConnection,
132 runtime: &impl RuntimeTrait,
133 wt_dispatcher: &Option<WebTransportDispatcher>,
134 listener: Option<Listener>,
135 local_alt_svc: Option<&'static str>,
136) {
137 log::trace!("H3 bidi stream {stream_id}: spawning handler task");
138 let (h3, handler, connection, wt_dispatcher) = (
139 h3.clone(),
140 handler.clone(),
141 connection.clone(),
142 wt_dispatcher.clone(),
143 );
144
145 runtime.spawn(async move {
146 let handler = &handler;
147 let peer_ip = connection.remote_address().ip();
148 let quic_connection = connection.clone();
149 let wt_dispatcher = wt_dispatcher.clone();
150
151 let handler_fn = {
152 let wt_dispatcher = wt_dispatcher.clone();
153 |mut conn: trillium_http::Conn<_>| async move {
154 conn.set_peer_ip(Some(peer_ip));
155 conn.set_secure(true);
156
157 let state = conn.state_mut();
158 state.insert(quic_connection);
159 state.insert(StreamId(stream_id));
160 if let Some(listener) = listener {
161 if let Some(addr) = listener.socket_addr() {
162 state.insert(addr);
163 }
164 state.insert(listener);
165 }
166 if let Some(dispatcher) = wt_dispatcher {
167 state.insert(dispatcher);
168 }
169 if let Some(alt_svc) = local_alt_svc {
170 conn.response_headers_mut()
171 .try_insert(KnownHeaderName::AltSvc, alt_svc);
172 }
173
174 let conn = handler.run(conn.into()).await;
175 let conn = handler.before_send(conn).await;
176
177 conn.into_inner()
178 }
179 };
180
181 let result = h3
182 .clone()
183 .process_inbound_bidi_with_reset(transport, handler_fn, stream_id, |t, code| {
184 let raw = u64::from(code);
189 t.stop(raw);
190 t.reset(raw);
191 })
192 .await;
193
194 match result {
195 Ok(H3StreamResult::Request(conn)) if conn.should_upgrade() => {
196 let upgrade = Upgrade::from(conn);
197 if handler.has_upgrade(&upgrade) {
198 log::debug!("upgrading h3 stream");
199 handler.upgrade(upgrade).await;
200 } else {
201 log::error!("h3 upgrade specified but no upgrade handler provided");
202 }
203 }
204
205 Ok(H3StreamResult::Request(_)) => {}
206
207 Ok(H3StreamResult::WebTransport {
208 session_id,
209 mut transport,
210 buffer,
211 }) => {
212 if let Some(dispatcher) = &wt_dispatcher {
213 dispatcher.dispatch(WebTransportStream::Bidi {
214 session_id,
215 stream: Box::new(transport),
216 buffer: buffer.into(),
217 });
218 } else {
219 transport.stop(H3ErrorCode::StreamCreationError.into());
220 transport.reset(H3ErrorCode::StreamCreationError.into());
221 }
222 }
223
224 Err(error) => {
225 log::debug!("H3 bidi stream {stream_id}: error: {error}");
226 handle_h3_error(error, &connection, &h3);
227 }
228 }
229 });
230}
231
232fn spawn_inbound_uni_streams(
233 connection: &QuicConnection,
234 h3: &Arc<H3Connection>,
235 runtime: &impl RuntimeTrait,
236 wt_dispatcher: &Option<WebTransportDispatcher>,
237) {
238 let (connection, h3, runtime, wt_dispatcher) = (
239 connection.clone(),
240 h3.clone(),
241 runtime.clone(),
242 wt_dispatcher.clone(),
243 );
244 runtime.clone().spawn(async move {
245 while let Some(Ok((_stream_id, recv))) =
246 h3.swansong().interrupt(connection.accept_uni()).await
247 {
248 let (connection, h3, wt_dispatcher) =
249 (connection.clone(), h3.clone(), wt_dispatcher.clone());
250
251 runtime.spawn(async move {
252 let close_connection = {
260 let connection = connection.clone();
261 let h3 = h3.clone();
262 move |code: H3ErrorCode| {
263 connection.close(code.into(), code.reason().as_bytes());
264 h3.shut_down();
265 }
266 };
267 let result = h3
268 .process_inbound_uni_with_close(recv, close_connection)
269 .await;
270
271 match result {
272 Ok(UniStreamResult::Handled) => {}
273 Ok(UniStreamResult::WebTransport {
274 session_id,
275 mut stream,
276 buffer,
277 }) => {
278 if let Some(dispatcher) = &wt_dispatcher {
279 dispatcher.dispatch(WebTransportStream::Uni {
280 session_id,
281 stream: Box::new(stream),
282 buffer: buffer.into(),
283 });
284 } else {
285 stream.stop(H3ErrorCode::StreamCreationError.into());
286 }
287 }
288
289 Ok(UniStreamResult::Unknown { mut stream, .. }) => {
290 stream.stop(H3ErrorCode::StreamCreationError.into());
291 }
292
293 Err(error) => {
294 handle_h3_error(error, &connection, &h3);
298 }
299 }
300 });
301 }
302
303 h3.shut_down();
304 });
305}
306
307fn spawn_qpack_decoder_stream(
308 connection: &QuicConnection,
309 h3: &Arc<H3Connection>,
310 runtime: &impl RuntimeTrait,
311) {
312 let (connection, h3) = (connection.clone(), h3.clone());
313
314 runtime.spawn(async move {
315 log::trace!("H3: opening outbound QPACK decoder stream");
316 let stream = match connection.open_uni().await {
317 Ok((_stream_id, stream)) => stream,
318 Err(err) => {
319 log::error!("H3: open_uni for QPACK decoder stream failed: {err:?}");
320 h3.shut_down();
321 return;
322 }
323 };
324
325 let result = h3.run_decoder(stream).await;
326
327 if let Err(error) = result {
328 handle_h3_error(error, &connection, &h3);
329 }
330
331 h3.shut_down();
332 });
333}
334
335fn spawn_qpack_encoder_stream(
336 connection: &QuicConnection,
337 h3: &Arc<H3Connection>,
338 runtime: &impl RuntimeTrait,
339) {
340 let (connection, h3) = (connection.clone(), h3.clone());
341 runtime.spawn(async move {
342 log::trace!("H3: opening outbound QPACK encoder stream");
343 let stream = match connection.open_uni().await {
344 Ok((_stream_id, stream)) => stream,
345 Err(err) => {
346 log::error!("H3: open_uni for QPACK encoder stream failed: {err:?}");
347 h3.shut_down();
348 return;
349 }
350 };
351
352 let result = h3.run_encoder(stream).await;
353
354 if let Err(error) = result {
355 handle_h3_error(error, &connection, &h3);
356 }
357
358 h3.shut_down();
359 });
360}
361
362fn spawn_outbound_control_stream(
363 connection: &QuicConnection,
364 h3: &Arc<H3Connection>,
365 runtime: &impl RuntimeTrait,
366) {
367 let (connection, h3) = (connection.clone(), h3.clone());
368 runtime.spawn(async move {
369 log::trace!("H3: opening outbound control stream");
370 let stream = match connection.open_uni().await {
371 Ok((_stream_id, stream)) => stream,
372 Err(err) => {
373 log::error!("H3: open_uni for outbound control stream failed: {err:?}");
374 h3.shut_down();
375 return;
376 }
377 };
378
379 let result = h3.run_outbound_control(stream).await;
380
381 if let Err(error) = result {
382 handle_h3_error(error, &connection, &h3);
383 }
384
385 h3.shut_down();
386 });
387}
388
389fn handle_h3_error(error: H3Error, connection: &QuicConnection, h3: &H3Connection) {
390 log::debug!("H3 error: {error}");
391 if let H3Error::Protocol(code) = error
392 && code.is_connection_error()
393 {
394 connection.close(code.into(), code.reason().as_bytes());
395 h3.shut_down();
396 }
397}