Skip to main content

trillium_server_common/
h3.rs

1//! HTTP/3 specific exports
2
3pub mod web_transport;
4use crate::{
5    ArcHandler, QuicConnection, QuicConnectionTrait, QuicEndpoint, QuicTransportReceive,
6    QuicTransportSend, Runtime,
7};
8use std::sync::Arc;
9use trillium::{Handler, Upgrade};
10use trillium_http::{
11    HttpContext,
12    h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
13};
14use web_transport::{WebTransportDispatcher, WebTransportStream};
15
/// A QUIC stream identifier.
///
/// The HTTP/3 server inserts one of these into the state of every conn it builds from an
/// inbound request stream, recording which QUIC stream that request arrived on. Convert to
/// and from the raw `u64` with [`From`] / [`Into`].
///
/// Equality, ordering and hashing all follow the underlying `u64`, so a `StreamId` can be
/// used directly as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(u64);

impl From<StreamId> for u64 {
    fn from(val: StreamId) -> Self {
        val.0
    }
}

impl From<u64> for StreamId {
    fn from(value: u64) -> Self {
        Self(value)
    }
}
30
31pub(crate) async fn run_h3<QE: QuicEndpoint>(
32    quic_binding: QE,
33    context: Arc<HttpContext>,
34    handler: ArcHandler<impl Handler>,
35    runtime: Runtime,
36) {
37    let swansong = context.swansong();
38    while let Some(connection) = swansong.interrupt(quic_binding.accept()).await.flatten() {
39        let h3 = H3Connection::new(context.clone());
40        let handler = handler.clone();
41        let runtime = runtime.clone();
42        runtime
43            .clone()
44            .spawn(run_h3_connection(connection, h3, handler, runtime));
45    }
46}
47
48async fn run_h3_connection<QC: QuicConnectionTrait>(
49    connection: QC,
50    h3: Arc<H3Connection>,
51    handler: ArcHandler<impl Handler>,
52    runtime: Runtime,
53) {
54    let wt_dispatcher = h3
55        .context()
56        .config()
57        .webtransport_enabled()
58        .then(WebTransportDispatcher::new);
59
60    log::trace!("new quic connection from {}", connection.remote_address());
61
62    spawn_outbound_control_stream(&connection, &h3, &runtime);
63    spawn_qpack_encoder_stream(&connection, &h3, &runtime);
64    spawn_qpack_decoder_stream(&connection, &h3, &runtime);
65    spawn_inbound_uni_streams(&connection, &h3, &runtime, &wt_dispatcher);
66    handle_inbound_bidi_streams(connection, h3, handler, runtime, wt_dispatcher).await;
67}
68
69async fn handle_inbound_bidi_streams<QC: QuicConnectionTrait>(
70    connection: QC,
71    h3: Arc<H3Connection>,
72    handler: ArcHandler<impl Handler>,
73    runtime: Runtime,
74    wt_dispatcher: Option<WebTransportDispatcher>,
75) {
76    let swansong = h3.swansong().clone();
77    while let Some(Ok((stream_id, transport))) = swansong.interrupt(connection.accept_bidi()).await
78    {
79        let (h3, handler, connection, wt_dispatcher) = (
80            h3.clone(),
81            handler.clone(),
82            connection.clone(),
83            wt_dispatcher.clone(),
84        );
85        let peer_ip = connection.remote_address().ip();
86        runtime.spawn(async move {
87            let handler = &handler;
88            let quic_connection = connection.clone();
89            let wt_dispatcher = wt_dispatcher.clone();
90            let result = h3
91                .clone()
92                .process_inbound_bidi(
93                    transport,
94                    {
95                        let wt_dispatcher = wt_dispatcher.clone();
96                        |mut conn| async move {
97                            conn.set_peer_ip(Some(peer_ip));
98                            conn.set_secure(true);
99                            let state = conn.state_mut();
100                            state.insert(quic_connection.clone());
101                            state.insert(QuicConnection::from(quic_connection));
102                            state.insert(StreamId(stream_id));
103                            if let Some(dispatcher) = wt_dispatcher {
104                                state.insert(dispatcher);
105                            }
106                            let conn = handler.run(conn.into()).await;
107                            let conn = handler.before_send(conn).await;
108                            conn.into_inner()
109                        }
110                    },
111                    stream_id,
112                )
113                .await;
114
115            match result {
116                Ok(H3StreamResult::Request(conn)) if conn.should_upgrade() => {
117                    let upgrade = Upgrade::from(conn);
118                    if handler.has_upgrade(&upgrade) {
119                        log::debug!("upgrading h3 stream");
120                        handler.upgrade(upgrade).await;
121                    } else {
122                        log::error!("h3 upgrade specified but no upgrade handler provided");
123                    }
124                }
125                Ok(H3StreamResult::Request(_)) => {}
126                Ok(H3StreamResult::WebTransport {
127                    session_id,
128                    mut transport,
129                    buffer,
130                }) => {
131                    if let Some(dispatcher) = &wt_dispatcher {
132                        dispatcher.dispatch(WebTransportStream::Bidi {
133                            session_id,
134                            stream: Box::new(transport),
135                            buffer: buffer.into(),
136                        });
137                    } else {
138                        transport.stop(H3ErrorCode::StreamCreationError.into());
139                        transport.reset(H3ErrorCode::StreamCreationError.into());
140                    }
141                }
142                Err(error) => handle_h3_error(error, &connection, &h3).await,
143            }
144        });
145    }
146}
147
148fn spawn_inbound_uni_streams<QC: QuicConnectionTrait>(
149    connection: &QC,
150    h3: &Arc<H3Connection>,
151    runtime: &Runtime,
152    wt_dispatcher: &Option<WebTransportDispatcher>,
153) {
154    let (connection, h3, runtime, wt_dispatcher) = (
155        connection.clone(),
156        h3.clone(),
157        runtime.clone(),
158        wt_dispatcher.clone(),
159    );
160    runtime.clone().spawn(async move {
161        while let Ok((_stream_id, recv)) = connection.accept_uni().await {
162            let (connection, h3, wt_dispatcher) =
163                (connection.clone(), h3.clone(), wt_dispatcher.clone());
164            runtime.spawn(async move {
165                match h3.process_inbound_uni(recv).await {
166                    Ok(UniStreamResult::Handled) => {}
167                    Ok(UniStreamResult::WebTransport {
168                        session_id,
169                        mut stream,
170                        buffer,
171                    }) => {
172                        if let Some(dispatcher) = &wt_dispatcher {
173                            dispatcher.dispatch(WebTransportStream::Uni {
174                                session_id,
175                                stream: Box::new(stream),
176                                buffer: buffer.into(),
177                            });
178                        } else {
179                            stream.stop(H3ErrorCode::StreamCreationError.into());
180                        }
181                    }
182                    Ok(UniStreamResult::Unknown { mut stream, .. }) => {
183                        stream.stop(H3ErrorCode::StreamCreationError.into());
184                    }
185                    Err(error) => {
186                        handle_h3_error(error, &connection, &h3).await;
187                    }
188                }
189            });
190        }
191    });
192}
193
194fn spawn_qpack_decoder_stream<QC: QuicConnectionTrait>(
195    connection: &QC,
196    h3: &Arc<H3Connection>,
197    runtime: &Runtime,
198) {
199    let (connection, h3) = (connection.clone(), h3.clone());
200    runtime.spawn(async move {
201        let result: Result<(), H3Error> =
202            async { h3.run_decoder(connection.open_uni().await?.1).await }.await;
203        if let Err(error) = result {
204            handle_h3_error(error, &connection, &h3).await;
205        }
206    });
207}
208
209fn spawn_qpack_encoder_stream<QC: QuicConnectionTrait>(
210    connection: &QC,
211    h3: &Arc<H3Connection>,
212    runtime: &Runtime,
213) {
214    let (connection, h3) = (connection.clone(), h3.clone());
215    runtime.spawn(async move {
216        let result: Result<(), H3Error> =
217            async { h3.run_encoder(connection.open_uni().await?.1).await }.await;
218        if let Err(error) = result {
219            handle_h3_error(error, &connection, &h3).await;
220        }
221    });
222}
223
224fn spawn_outbound_control_stream<QC: QuicConnectionTrait>(
225    connection: &QC,
226    h3: &Arc<H3Connection>,
227    runtime: &Runtime,
228) {
229    let (connection, h3) = (connection.clone(), h3.clone());
230    runtime.spawn(async move {
231        let guard = h3.swansong().guard();
232
233        let result: Result<(), H3Error> = async {
234            h3.run_outbound_control(connection.open_uni().await?.1)
235                .await
236        }
237        .await;
238        drop(guard);
239        if let Err(error) = result {
240            handle_h3_error(error, &connection, &h3).await;
241        }
242    });
243}
244
245async fn handle_h3_error(error: H3Error, connection: &impl QuicConnectionTrait, h3: &H3Connection) {
246    log::debug!("H3 error: {error}");
247    if let H3Error::Protocol(code) = error {
248        connection.close(code.into(), code.reason().as_bytes());
249    }
250    h3.shut_down().await;
251}