Skip to main content

trillium_server_common/
h3.rs

1//! HTTP/3 specific exports
2
3pub mod web_transport;
4use crate::{
5    ArcHandler, ArcedQuicEndpoint, BoxedBidiStream, QuicConnection, QuicTransportReceive,
6    QuicTransportSend, RuntimeTrait,
7};
8use std::sync::Arc;
9use trillium::{Handler, KnownHeaderName, Listener, Upgrade};
10use trillium_http::{
11    HttpContext,
12    h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
13};
14use web_transport::{WebTransportDispatcher, WebTransportStream};
15
/// A QUIC stream identifier.
///
/// An instance is inserted into the state of every conn served from an inbound
/// bidirectional stream, so handlers can tell which QUIC stream carried the request.
/// Convert to and from the raw `u64` with [`From`].
///
/// Comparison, hashing and ordering are derived so handlers can use stream ids as
/// map keys or compare them directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(u64);

impl From<StreamId> for u64 {
    fn from(val: StreamId) -> Self {
        val.0
    }
}

impl From<u64> for StreamId {
    fn from(value: u64) -> Self {
        Self(value)
    }
}
30
31pub(crate) async fn run_h3(
32    quic_binding: ArcedQuicEndpoint,
33    context: Arc<HttpContext>,
34    handler: ArcHandler<impl Handler>,
35    runtime: impl RuntimeTrait,
36    listener: Option<Listener>,
37    local_alt_svc: Option<&'static str>,
38) {
39    let swansong = context.swansong();
40    while let Some(connection) = swansong.interrupt(quic_binding.accept()).await.flatten() {
41        let h3 = H3Connection::new(context.clone());
42        let handler = handler.clone();
43        let runtime = runtime.clone();
44        runtime.clone().spawn(run_h3_connection(
45            connection,
46            h3,
47            handler,
48            runtime,
49            listener.clone(),
50            local_alt_svc,
51        ));
52    }
53}
54
55async fn run_h3_connection(
56    connection: QuicConnection,
57    h3: Arc<H3Connection>,
58    handler: ArcHandler<impl Handler>,
59    runtime: impl RuntimeTrait,
60    listener: Option<Listener>,
61    local_alt_svc: Option<&'static str>,
62) {
63    let wt_dispatcher = h3
64        .context()
65        .config()
66        .webtransport_enabled()
67        .then(WebTransportDispatcher::new);
68
69    log::trace!("new quic connection from {}", connection.remote_address());
70
71    spawn_outbound_control_stream(&connection, &h3, &runtime);
72    spawn_qpack_encoder_stream(&connection, &h3, &runtime);
73    spawn_qpack_decoder_stream(&connection, &h3, &runtime);
74    spawn_inbound_uni_streams(&connection, &h3, &runtime, &wt_dispatcher);
75    handle_inbound_bidi_streams(
76        connection,
77        h3.clone(),
78        handler,
79        runtime,
80        wt_dispatcher,
81        listener,
82        local_alt_svc,
83    )
84    .await;
85}
86
87async fn handle_inbound_bidi_streams(
88    connection: QuicConnection,
89    h3: Arc<H3Connection>,
90    handler: ArcHandler<impl Handler>,
91    runtime: impl RuntimeTrait,
92    wt_dispatcher: Option<WebTransportDispatcher>,
93    listener: Option<Listener>,
94    local_alt_svc: Option<&'static str>,
95) {
96    loop {
97        match h3.swansong().interrupt(connection.accept_bidi()).await {
98            None => {
99                log::trace!("H3 bidi accept loop: interrupted by swansong shutdown");
100                break;
101            }
102            Some(Err(e)) => {
103                log::debug!("H3 bidi accept loop: accept_bidi error: {e}");
104                break;
105            }
106            Some(Ok((stream_id, transport))) => {
107                handle_bidi_stream(
108                    stream_id,
109                    transport,
110                    &h3,
111                    &handler,
112                    &connection,
113                    &runtime,
114                    &wt_dispatcher,
115                    listener.clone(),
116                    local_alt_svc,
117                );
118            }
119        }
120    }
121
122    h3.shut_down();
123}
124
125#[allow(clippy::too_many_arguments)]
126fn handle_bidi_stream(
127    stream_id: u64,
128    transport: BoxedBidiStream,
129    h3: &Arc<H3Connection>,
130    handler: &ArcHandler<impl Handler>,
131    connection: &QuicConnection,
132    runtime: &impl RuntimeTrait,
133    wt_dispatcher: &Option<WebTransportDispatcher>,
134    listener: Option<Listener>,
135    local_alt_svc: Option<&'static str>,
136) {
137    log::trace!("H3 bidi stream {stream_id}: spawning handler task");
138    let (h3, handler, connection, wt_dispatcher) = (
139        h3.clone(),
140        handler.clone(),
141        connection.clone(),
142        wt_dispatcher.clone(),
143    );
144
145    runtime.spawn(async move {
146        let handler = &handler;
147        let peer_ip = connection.remote_address().ip();
148        let quic_connection = connection.clone();
149        let wt_dispatcher = wt_dispatcher.clone();
150
151        let handler_fn = {
152            let wt_dispatcher = wt_dispatcher.clone();
153            |mut conn: trillium_http::Conn<_>| async move {
154                conn.set_peer_ip(Some(peer_ip));
155                conn.set_secure(true);
156
157                let state = conn.state_mut();
158                state.insert(quic_connection);
159                state.insert(StreamId(stream_id));
160                if let Some(listener) = listener {
161                    if let Some(addr) = listener.socket_addr() {
162                        state.insert(addr);
163                    }
164                    state.insert(listener);
165                }
166                if let Some(dispatcher) = wt_dispatcher {
167                    state.insert(dispatcher);
168                }
169                if let Some(alt_svc) = local_alt_svc {
170                    conn.response_headers_mut()
171                        .try_insert(KnownHeaderName::AltSvc, alt_svc);
172                }
173
174                let conn = handler.run(conn.into()).await;
175                let conn = handler.before_send(conn).await;
176
177                conn.into_inner()
178            }
179        };
180
181        let result = h3
182            .clone()
183            .process_inbound_bidi_with_reset(transport, handler_fn, stream_id, |t, code| {
184                // RFC 9114 §4.1.2: stream-level protocol errors (notably H3_MESSAGE_ERROR)
185                // MUST RST the stream. We stop the recv side and reset the send side with
186                // the same code so the peer sees the error on whichever direction it's
187                // listening on.
188                let raw = u64::from(code);
189                t.stop(raw);
190                t.reset(raw);
191            })
192            .await;
193
194        match result {
195            Ok(H3StreamResult::Request(conn)) if conn.should_upgrade() => {
196                let upgrade = Upgrade::from(conn);
197                if handler.has_upgrade(&upgrade) {
198                    log::debug!("upgrading h3 stream");
199                    handler.upgrade(upgrade).await;
200                } else {
201                    log::error!("h3 upgrade specified but no upgrade handler provided");
202                }
203            }
204
205            Ok(H3StreamResult::Request(_)) => {}
206
207            Ok(H3StreamResult::WebTransport {
208                session_id,
209                mut transport,
210                buffer,
211            }) => {
212                if let Some(dispatcher) = &wt_dispatcher {
213                    dispatcher.dispatch(WebTransportStream::Bidi {
214                        session_id,
215                        stream: Box::new(transport),
216                        buffer: buffer.into(),
217                    });
218                } else {
219                    transport.stop(H3ErrorCode::StreamCreationError.into());
220                    transport.reset(H3ErrorCode::StreamCreationError.into());
221                }
222            }
223
224            Err(error) => {
225                log::debug!("H3 bidi stream {stream_id}: error: {error}");
226                handle_h3_error(error, &connection, &h3);
227            }
228        }
229    });
230}
231
232fn spawn_inbound_uni_streams(
233    connection: &QuicConnection,
234    h3: &Arc<H3Connection>,
235    runtime: &impl RuntimeTrait,
236    wt_dispatcher: &Option<WebTransportDispatcher>,
237) {
238    let (connection, h3, runtime, wt_dispatcher) = (
239        connection.clone(),
240        h3.clone(),
241        runtime.clone(),
242        wt_dispatcher.clone(),
243    );
244    runtime.clone().spawn(async move {
245        while let Some(Ok((_stream_id, recv))) =
246            h3.swansong().interrupt(connection.accept_uni()).await
247        {
248            let (connection, h3, wt_dispatcher) =
249                (connection.clone(), h3.clone(), wt_dispatcher.clone());
250
251            runtime.spawn(async move {
252                // RFC 9114 §8.1 / RFC 9204 §6 connection-level errors must close the
253                // QUIC connection while the recv stream is still alive — otherwise
254                // quinn's RecvStream::drop sends STOP_SENDING, and the peer's malformed
255                // RESET_STREAM response can race ahead and override our app error code
256                // with FINAL_SIZE_ERROR on the wire. The closure fires inside
257                // process_inbound_uni_with_close before stream drops, so the close sets
258                // quinn's conn.error first and the drop becomes a no-op.
259                let close_connection = {
260                    let connection = connection.clone();
261                    let h3 = h3.clone();
262                    move |code: H3ErrorCode| {
263                        connection.close(code.into(), code.reason().as_bytes());
264                        h3.shut_down();
265                    }
266                };
267                let result = h3
268                    .process_inbound_uni_with_close(recv, close_connection)
269                    .await;
270
271                match result {
272                    Ok(UniStreamResult::Handled) => {}
273                    Ok(UniStreamResult::WebTransport {
274                        session_id,
275                        mut stream,
276                        buffer,
277                    }) => {
278                        if let Some(dispatcher) = &wt_dispatcher {
279                            dispatcher.dispatch(WebTransportStream::Uni {
280                                session_id,
281                                stream: Box::new(stream),
282                                buffer: buffer.into(),
283                            });
284                        } else {
285                            stream.stop(H3ErrorCode::StreamCreationError.into());
286                        }
287                    }
288
289                    Ok(UniStreamResult::Unknown { mut stream, .. }) => {
290                        stream.stop(H3ErrorCode::StreamCreationError.into());
291                    }
292
293                    Err(error) => {
294                        // Connection-level protocol errors already fired the close
295                        // callback above; this call is a no-op for the close path
296                        // (idempotent) and still useful for logging plus I/O errors.
297                        handle_h3_error(error, &connection, &h3);
298                    }
299                }
300            });
301        }
302
303        h3.shut_down();
304    });
305}
306
307fn spawn_qpack_decoder_stream(
308    connection: &QuicConnection,
309    h3: &Arc<H3Connection>,
310    runtime: &impl RuntimeTrait,
311) {
312    let (connection, h3) = (connection.clone(), h3.clone());
313
314    runtime.spawn(async move {
315        log::trace!("H3: opening outbound QPACK decoder stream");
316        let stream = match connection.open_uni().await {
317            Ok((_stream_id, stream)) => stream,
318            Err(err) => {
319                log::error!("H3: open_uni for QPACK decoder stream failed: {err:?}");
320                h3.shut_down();
321                return;
322            }
323        };
324
325        let result = h3.run_decoder(stream).await;
326
327        if let Err(error) = result {
328            handle_h3_error(error, &connection, &h3);
329        }
330
331        h3.shut_down();
332    });
333}
334
335fn spawn_qpack_encoder_stream(
336    connection: &QuicConnection,
337    h3: &Arc<H3Connection>,
338    runtime: &impl RuntimeTrait,
339) {
340    let (connection, h3) = (connection.clone(), h3.clone());
341    runtime.spawn(async move {
342        log::trace!("H3: opening outbound QPACK encoder stream");
343        let stream = match connection.open_uni().await {
344            Ok((_stream_id, stream)) => stream,
345            Err(err) => {
346                log::error!("H3: open_uni for QPACK encoder stream failed: {err:?}");
347                h3.shut_down();
348                return;
349            }
350        };
351
352        let result = h3.run_encoder(stream).await;
353
354        if let Err(error) = result {
355            handle_h3_error(error, &connection, &h3);
356        }
357
358        h3.shut_down();
359    });
360}
361
362fn spawn_outbound_control_stream(
363    connection: &QuicConnection,
364    h3: &Arc<H3Connection>,
365    runtime: &impl RuntimeTrait,
366) {
367    let (connection, h3) = (connection.clone(), h3.clone());
368    runtime.spawn(async move {
369        log::trace!("H3: opening outbound control stream");
370        let stream = match connection.open_uni().await {
371            Ok((_stream_id, stream)) => stream,
372            Err(err) => {
373                log::error!("H3: open_uni for outbound control stream failed: {err:?}");
374                h3.shut_down();
375                return;
376            }
377        };
378
379        let result = h3.run_outbound_control(stream).await;
380
381        if let Err(error) = result {
382            handle_h3_error(error, &connection, &h3);
383        }
384
385        h3.shut_down();
386    });
387}
388
389fn handle_h3_error(error: H3Error, connection: &QuicConnection, h3: &H3Connection) {
390    log::debug!("H3 error: {error}");
391    if let H3Error::Protocol(code) = error
392        && code.is_connection_error()
393    {
394        connection.close(code.into(), code.reason().as_bytes());
395        h3.shut_down();
396    }
397}