Skip to main content

trillium_server_common/
h3.rs

1//! HTTP/3 specific exports
2
3pub mod web_transport;
4use crate::{
5    ArcHandler, QuicConnection, QuicConnectionTrait, QuicEndpoint, QuicTransportReceive,
6    QuicTransportSend, RuntimeTrait,
7};
8use std::sync::Arc;
9use trillium::{Handler, Upgrade};
10use trillium_http::{
11    HttpContext,
12    h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
13};
14use web_transport::{WebTransportDispatcher, WebTransportStream};
15
/// A QUIC stream identifier.
///
/// The server inserts one of these into the state of every conn that arrives on an inbound
/// HTTP/3 bidirectional stream, carrying the id of that stream. It converts losslessly to and
/// from the raw `u64` id.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(u64);

impl From<StreamId> for u64 {
    /// Unwrap the raw QUIC stream id.
    fn from(val: StreamId) -> Self {
        val.0
    }
}

impl From<u64> for StreamId {
    /// Wrap a raw QUIC stream id.
    fn from(value: u64) -> Self {
        Self(value)
    }
}
30
31pub(crate) async fn run_h3<QE: QuicEndpoint>(
32    quic_binding: QE,
33    context: Arc<HttpContext>,
34    handler: ArcHandler<impl Handler>,
35    runtime: impl RuntimeTrait,
36) {
37    let swansong = context.swansong();
38    while let Some(connection) = swansong.interrupt(quic_binding.accept()).await.flatten() {
39        let h3 = H3Connection::new(context.clone());
40        let handler = handler.clone();
41        let runtime = runtime.clone();
42        runtime
43            .clone()
44            .spawn(run_h3_connection(connection, h3, handler, runtime));
45    }
46}
47
48async fn run_h3_connection<QC: QuicConnectionTrait>(
49    connection: QC,
50    h3: Arc<H3Connection>,
51    handler: ArcHandler<impl Handler>,
52    runtime: impl RuntimeTrait,
53) {
54    let wt_dispatcher = h3
55        .context()
56        .config()
57        .webtransport_enabled()
58        .then(WebTransportDispatcher::new);
59
60    log::trace!("new quic connection from {}", connection.remote_address());
61
62    spawn_outbound_control_stream(&connection, &h3, &runtime);
63    spawn_qpack_encoder_stream(&connection, &h3, &runtime);
64    spawn_qpack_decoder_stream(&connection, &h3, &runtime);
65    spawn_inbound_uni_streams(&connection, &h3, &runtime, &wt_dispatcher);
66    handle_inbound_bidi_streams(connection, h3.clone(), handler, runtime, wt_dispatcher).await;
67}
68
69async fn handle_inbound_bidi_streams<QC: QuicConnectionTrait>(
70    connection: QC,
71    h3: Arc<H3Connection>,
72    handler: ArcHandler<impl Handler>,
73    runtime: impl RuntimeTrait,
74    wt_dispatcher: Option<WebTransportDispatcher>,
75) {
76    loop {
77        match h3.swansong().interrupt(connection.accept_bidi()).await {
78            None => {
79                log::trace!("H3 bidi accept loop: interrupted by swansong shutdown");
80                break;
81            }
82            Some(Err(e)) => {
83                log::debug!("H3 bidi accept loop: accept_bidi error: {e}");
84                break;
85            }
86            Some(Ok((stream_id, transport))) => {
87                handle_bidi_stream(
88                    stream_id,
89                    transport,
90                    &h3,
91                    &handler,
92                    &connection,
93                    &runtime,
94                    &wt_dispatcher,
95                );
96            }
97        }
98    }
99
100    h3.shut_down();
101}
102
103fn handle_bidi_stream<QC: QuicConnectionTrait>(
104    stream_id: u64,
105    transport: QC::BidiStream,
106    h3: &Arc<H3Connection>,
107    handler: &ArcHandler<impl Handler>,
108    connection: &QC,
109    runtime: &impl RuntimeTrait,
110    wt_dispatcher: &Option<WebTransportDispatcher>,
111) {
112    log::trace!("H3 bidi stream {stream_id}: spawning handler task");
113    let (h3, handler, connection, wt_dispatcher) = (
114        h3.clone(),
115        handler.clone(),
116        connection.clone(),
117        wt_dispatcher.clone(),
118    );
119
120    runtime.spawn(async move {
121        let handler = &handler;
122        let peer_ip = connection.remote_address().ip();
123        let quic_connection = connection.clone();
124        let wt_dispatcher = wt_dispatcher.clone();
125
126        let handler_fn = {
127            let wt_dispatcher = wt_dispatcher.clone();
128            |mut conn: trillium_http::Conn<_>| async move {
129                conn.set_peer_ip(Some(peer_ip));
130                conn.set_secure(true);
131
132                let state = conn.state_mut();
133                state.insert(quic_connection.clone());
134                state.insert(QuicConnection::from(quic_connection));
135                state.insert(StreamId(stream_id));
136                if let Some(dispatcher) = wt_dispatcher {
137                    state.insert(dispatcher);
138                }
139
140                let conn = handler.run(conn.into()).await;
141                let conn = handler.before_send(conn).await;
142
143                conn.into_inner()
144            }
145        };
146
147        let result = h3
148            .clone()
149            .process_inbound_bidi_with_reset(transport, handler_fn, stream_id, |t, code| {
150                // RFC 9114 §4.1.2: stream-level protocol errors (notably H3_MESSAGE_ERROR)
151                // MUST RST the stream. We stop the recv side and reset the send side with
152                // the same code so the peer sees the error on whichever direction it's
153                // listening on.
154                let raw = u64::from(code);
155                t.stop(raw);
156                t.reset(raw);
157            })
158            .await;
159
160        match result {
161            Ok(H3StreamResult::Request(conn)) if conn.should_upgrade() => {
162                let upgrade = Upgrade::from(conn);
163                if handler.has_upgrade(&upgrade) {
164                    log::debug!("upgrading h3 stream");
165                    handler.upgrade(upgrade).await;
166                } else {
167                    log::error!("h3 upgrade specified but no upgrade handler provided");
168                }
169            }
170
171            Ok(H3StreamResult::Request(_)) => {}
172
173            Ok(H3StreamResult::WebTransport {
174                session_id,
175                mut transport,
176                buffer,
177            }) => {
178                if let Some(dispatcher) = &wt_dispatcher {
179                    dispatcher.dispatch(WebTransportStream::Bidi {
180                        session_id,
181                        stream: Box::new(transport),
182                        buffer: buffer.into(),
183                    });
184                } else {
185                    transport.stop(H3ErrorCode::StreamCreationError.into());
186                    transport.reset(H3ErrorCode::StreamCreationError.into());
187                }
188            }
189
190            Err(error) => {
191                log::debug!("H3 bidi stream {stream_id}: error: {error}");
192                handle_h3_error(error, &connection, &h3);
193            }
194        }
195    });
196}
197
198fn spawn_inbound_uni_streams<QC: QuicConnectionTrait>(
199    connection: &QC,
200    h3: &Arc<H3Connection>,
201    runtime: &impl RuntimeTrait,
202    wt_dispatcher: &Option<WebTransportDispatcher>,
203) {
204    let (connection, h3, runtime, wt_dispatcher) = (
205        connection.clone(),
206        h3.clone(),
207        runtime.clone(),
208        wt_dispatcher.clone(),
209    );
210    runtime.clone().spawn(async move {
211        while let Some(Ok((_stream_id, recv))) =
212            h3.swansong().interrupt(connection.accept_uni()).await
213        {
214            let (connection, h3, wt_dispatcher) =
215                (connection.clone(), h3.clone(), wt_dispatcher.clone());
216
217            runtime.spawn(async move {
218                // RFC 9114 §8.1 / RFC 9204 §6 connection-level errors must close the
219                // QUIC connection while the recv stream is still alive — otherwise
220                // quinn's RecvStream::drop sends STOP_SENDING, and the peer's malformed
221                // RESET_STREAM response can race ahead and override our app error code
222                // with FINAL_SIZE_ERROR on the wire. The closure fires inside
223                // process_inbound_uni_with_close before stream drops, so the close sets
224                // quinn's conn.error first and the drop becomes a no-op.
225                let close_connection = {
226                    let connection = connection.clone();
227                    let h3 = h3.clone();
228                    move |code: H3ErrorCode| {
229                        connection.close(code.into(), code.reason().as_bytes());
230                        h3.shut_down();
231                    }
232                };
233                let result = h3
234                    .process_inbound_uni_with_close(recv, close_connection)
235                    .await;
236
237                match result {
238                    Ok(UniStreamResult::Handled) => {}
239                    Ok(UniStreamResult::WebTransport {
240                        session_id,
241                        mut stream,
242                        buffer,
243                    }) => {
244                        if let Some(dispatcher) = &wt_dispatcher {
245                            dispatcher.dispatch(WebTransportStream::Uni {
246                                session_id,
247                                stream: Box::new(stream),
248                                buffer: buffer.into(),
249                            });
250                        } else {
251                            stream.stop(H3ErrorCode::StreamCreationError.into());
252                        }
253                    }
254
255                    Ok(UniStreamResult::Unknown { mut stream, .. }) => {
256                        stream.stop(H3ErrorCode::StreamCreationError.into());
257                    }
258
259                    Err(error) => {
260                        // Connection-level protocol errors already fired the close
261                        // callback above; this call is a no-op for the close path
262                        // (idempotent) and still useful for logging plus I/O errors.
263                        handle_h3_error(error, &connection, &h3);
264                    }
265                }
266            });
267        }
268
269        h3.shut_down();
270    });
271}
272
273fn spawn_qpack_decoder_stream<QC: QuicConnectionTrait>(
274    connection: &QC,
275    h3: &Arc<H3Connection>,
276    runtime: &impl RuntimeTrait,
277) {
278    let (connection, h3) = (connection.clone(), h3.clone());
279
280    runtime.spawn(async move {
281        log::trace!("H3: opening outbound QPACK decoder stream");
282        let stream = match connection.open_uni().await {
283            Ok((_stream_id, stream)) => stream,
284            Err(err) => {
285                log::error!("H3: open_uni for QPACK decoder stream failed: {err:?}");
286                h3.shut_down();
287                return;
288            }
289        };
290
291        let result = h3.run_decoder(stream).await;
292
293        if let Err(error) = result {
294            handle_h3_error(error, &connection, &h3);
295        }
296
297        h3.shut_down();
298    });
299}
300
301fn spawn_qpack_encoder_stream<QC: QuicConnectionTrait>(
302    connection: &QC,
303    h3: &Arc<H3Connection>,
304    runtime: &impl RuntimeTrait,
305) {
306    let (connection, h3) = (connection.clone(), h3.clone());
307    runtime.spawn(async move {
308        log::trace!("H3: opening outbound QPACK encoder stream");
309        let stream = match connection.open_uni().await {
310            Ok((_stream_id, stream)) => stream,
311            Err(err) => {
312                log::error!("H3: open_uni for QPACK encoder stream failed: {err:?}");
313                h3.shut_down();
314                return;
315            }
316        };
317
318        let result = h3.run_encoder(stream).await;
319
320        if let Err(error) = result {
321            handle_h3_error(error, &connection, &h3);
322        }
323
324        h3.shut_down();
325    });
326}
327
328fn spawn_outbound_control_stream<QC: QuicConnectionTrait>(
329    connection: &QC,
330    h3: &Arc<H3Connection>,
331    runtime: &impl RuntimeTrait,
332) {
333    let (connection, h3) = (connection.clone(), h3.clone());
334    runtime.spawn(async move {
335        log::trace!("H3: opening outbound control stream");
336        let stream = match connection.open_uni().await {
337            Ok((_stream_id, stream)) => stream,
338            Err(err) => {
339                log::error!("H3: open_uni for outbound control stream failed: {err:?}");
340                h3.shut_down();
341                return;
342            }
343        };
344
345        let result = h3.run_outbound_control(stream).await;
346
347        if let Err(error) = result {
348            handle_h3_error(error, &connection, &h3);
349        }
350
351        h3.shut_down();
352    });
353}
354
355fn handle_h3_error(error: H3Error, connection: &impl QuicConnectionTrait, h3: &H3Connection) {
356    log::debug!("H3 error: {error}");
357    if let H3Error::Protocol(code) = error
358        && code.is_connection_error()
359    {
360        connection.close(code.into(), code.reason().as_bytes());
361        h3.shut_down();
362    }
363}