1pub mod web_transport;
4use crate::{
5 ArcHandler, QuicConnection, QuicConnectionTrait, QuicEndpoint, QuicTransportReceive,
6 QuicTransportSend, RuntimeTrait,
7};
8use std::sync::Arc;
9use trillium::{Handler, Upgrade};
10use trillium_http::{
11 HttpContext,
12 h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
13};
14use web_transport::{WebTransportDispatcher, WebTransportStream};
15
/// The numeric id of a QUIC stream, wrapped in a newtype.
///
/// The HTTP/3 server inserts one of these into the state of every conn it serves over an
/// inbound bidirectional stream, so handlers can look up which stream a request arrived on.
/// Convert to and from the raw `u64` with `From`/`Into`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct StreamId(u64);

impl From<StreamId> for u64 {
    fn from(val: StreamId) -> Self {
        val.0
    }
}

impl From<u64> for StreamId {
    fn from(value: u64) -> Self {
        Self(value)
    }
}
30
31pub(crate) async fn run_h3<QE: QuicEndpoint>(
32 quic_binding: QE,
33 context: Arc<HttpContext>,
34 handler: ArcHandler<impl Handler>,
35 runtime: impl RuntimeTrait,
36) {
37 let swansong = context.swansong();
38 while let Some(connection) = swansong.interrupt(quic_binding.accept()).await.flatten() {
39 let h3 = H3Connection::new(context.clone());
40 let handler = handler.clone();
41 let runtime = runtime.clone();
42 runtime
43 .clone()
44 .spawn(run_h3_connection(connection, h3, handler, runtime));
45 }
46}
47
48async fn run_h3_connection<QC: QuicConnectionTrait>(
49 connection: QC,
50 h3: Arc<H3Connection>,
51 handler: ArcHandler<impl Handler>,
52 runtime: impl RuntimeTrait,
53) {
54 let wt_dispatcher = h3
55 .context()
56 .config()
57 .webtransport_enabled()
58 .then(WebTransportDispatcher::new);
59
60 log::trace!("new quic connection from {}", connection.remote_address());
61
62 spawn_outbound_control_stream(&connection, &h3, &runtime);
63 spawn_qpack_encoder_stream(&connection, &h3, &runtime);
64 spawn_qpack_decoder_stream(&connection, &h3, &runtime);
65 spawn_inbound_uni_streams(&connection, &h3, &runtime, &wt_dispatcher);
66 handle_inbound_bidi_streams(connection, h3.clone(), handler, runtime, wt_dispatcher).await;
67}
68
69async fn handle_inbound_bidi_streams<QC: QuicConnectionTrait>(
70 connection: QC,
71 h3: Arc<H3Connection>,
72 handler: ArcHandler<impl Handler>,
73 runtime: impl RuntimeTrait,
74 wt_dispatcher: Option<WebTransportDispatcher>,
75) {
76 loop {
77 match h3.swansong().interrupt(connection.accept_bidi()).await {
78 None => {
79 log::trace!("H3 bidi accept loop: interrupted by swansong shutdown");
80 break;
81 }
82 Some(Err(e)) => {
83 log::debug!("H3 bidi accept loop: accept_bidi error: {e}");
84 break;
85 }
86 Some(Ok((stream_id, transport))) => {
87 handle_bidi_stream(
88 stream_id,
89 transport,
90 &h3,
91 &handler,
92 &connection,
93 &runtime,
94 &wt_dispatcher,
95 );
96 }
97 }
98 }
99
100 h3.shut_down();
101}
102
103fn handle_bidi_stream<QC: QuicConnectionTrait>(
104 stream_id: u64,
105 transport: QC::BidiStream,
106 h3: &Arc<H3Connection>,
107 handler: &ArcHandler<impl Handler>,
108 connection: &QC,
109 runtime: &impl RuntimeTrait,
110 wt_dispatcher: &Option<WebTransportDispatcher>,
111) {
112 log::trace!("H3 bidi stream {stream_id}: spawning handler task");
113 let (h3, handler, connection, wt_dispatcher) = (
114 h3.clone(),
115 handler.clone(),
116 connection.clone(),
117 wt_dispatcher.clone(),
118 );
119
120 runtime.spawn(async move {
121 let handler = &handler;
122 let peer_ip = connection.remote_address().ip();
123 let quic_connection = connection.clone();
124 let wt_dispatcher = wt_dispatcher.clone();
125
126 let handler_fn = {
127 let wt_dispatcher = wt_dispatcher.clone();
128 |mut conn: trillium_http::Conn<_>| async move {
129 conn.set_peer_ip(Some(peer_ip));
130 conn.set_secure(true);
131
132 let state = conn.state_mut();
133 state.insert(quic_connection.clone());
134 state.insert(QuicConnection::from(quic_connection));
135 state.insert(StreamId(stream_id));
136 if let Some(dispatcher) = wt_dispatcher {
137 state.insert(dispatcher);
138 }
139
140 let conn = handler.run(conn.into()).await;
141 let conn = handler.before_send(conn).await;
142
143 conn.into_inner()
144 }
145 };
146
147 let result = h3
148 .clone()
149 .process_inbound_bidi_with_reset(transport, handler_fn, stream_id, |t, code| {
150 let raw = u64::from(code);
155 t.stop(raw);
156 t.reset(raw);
157 })
158 .await;
159
160 match result {
161 Ok(H3StreamResult::Request(conn)) if conn.should_upgrade() => {
162 let upgrade = Upgrade::from(conn);
163 if handler.has_upgrade(&upgrade) {
164 log::debug!("upgrading h3 stream");
165 handler.upgrade(upgrade).await;
166 } else {
167 log::error!("h3 upgrade specified but no upgrade handler provided");
168 }
169 }
170
171 Ok(H3StreamResult::Request(_)) => {}
172
173 Ok(H3StreamResult::WebTransport {
174 session_id,
175 mut transport,
176 buffer,
177 }) => {
178 if let Some(dispatcher) = &wt_dispatcher {
179 dispatcher.dispatch(WebTransportStream::Bidi {
180 session_id,
181 stream: Box::new(transport),
182 buffer: buffer.into(),
183 });
184 } else {
185 transport.stop(H3ErrorCode::StreamCreationError.into());
186 transport.reset(H3ErrorCode::StreamCreationError.into());
187 }
188 }
189
190 Err(error) => {
191 log::debug!("H3 bidi stream {stream_id}: error: {error}");
192 handle_h3_error(error, &connection, &h3);
193 }
194 }
195 });
196}
197
198fn spawn_inbound_uni_streams<QC: QuicConnectionTrait>(
199 connection: &QC,
200 h3: &Arc<H3Connection>,
201 runtime: &impl RuntimeTrait,
202 wt_dispatcher: &Option<WebTransportDispatcher>,
203) {
204 let (connection, h3, runtime, wt_dispatcher) = (
205 connection.clone(),
206 h3.clone(),
207 runtime.clone(),
208 wt_dispatcher.clone(),
209 );
210 runtime.clone().spawn(async move {
211 while let Some(Ok((_stream_id, recv))) =
212 h3.swansong().interrupt(connection.accept_uni()).await
213 {
214 let (connection, h3, wt_dispatcher) =
215 (connection.clone(), h3.clone(), wt_dispatcher.clone());
216
217 runtime.spawn(async move {
218 let close_connection = {
226 let connection = connection.clone();
227 let h3 = h3.clone();
228 move |code: H3ErrorCode| {
229 connection.close(code.into(), code.reason().as_bytes());
230 h3.shut_down();
231 }
232 };
233 let result = h3
234 .process_inbound_uni_with_close(recv, close_connection)
235 .await;
236
237 match result {
238 Ok(UniStreamResult::Handled) => {}
239 Ok(UniStreamResult::WebTransport {
240 session_id,
241 mut stream,
242 buffer,
243 }) => {
244 if let Some(dispatcher) = &wt_dispatcher {
245 dispatcher.dispatch(WebTransportStream::Uni {
246 session_id,
247 stream: Box::new(stream),
248 buffer: buffer.into(),
249 });
250 } else {
251 stream.stop(H3ErrorCode::StreamCreationError.into());
252 }
253 }
254
255 Ok(UniStreamResult::Unknown { mut stream, .. }) => {
256 stream.stop(H3ErrorCode::StreamCreationError.into());
257 }
258
259 Err(error) => {
260 handle_h3_error(error, &connection, &h3);
264 }
265 }
266 });
267 }
268
269 h3.shut_down();
270 });
271}
272
273fn spawn_qpack_decoder_stream<QC: QuicConnectionTrait>(
274 connection: &QC,
275 h3: &Arc<H3Connection>,
276 runtime: &impl RuntimeTrait,
277) {
278 let (connection, h3) = (connection.clone(), h3.clone());
279
280 runtime.spawn(async move {
281 log::trace!("H3: opening outbound QPACK decoder stream");
282 let stream = match connection.open_uni().await {
283 Ok((_stream_id, stream)) => stream,
284 Err(err) => {
285 log::error!("H3: open_uni for QPACK decoder stream failed: {err:?}");
286 h3.shut_down();
287 return;
288 }
289 };
290
291 let result = h3.run_decoder(stream).await;
292
293 if let Err(error) = result {
294 handle_h3_error(error, &connection, &h3);
295 }
296
297 h3.shut_down();
298 });
299}
300
301fn spawn_qpack_encoder_stream<QC: QuicConnectionTrait>(
302 connection: &QC,
303 h3: &Arc<H3Connection>,
304 runtime: &impl RuntimeTrait,
305) {
306 let (connection, h3) = (connection.clone(), h3.clone());
307 runtime.spawn(async move {
308 log::trace!("H3: opening outbound QPACK encoder stream");
309 let stream = match connection.open_uni().await {
310 Ok((_stream_id, stream)) => stream,
311 Err(err) => {
312 log::error!("H3: open_uni for QPACK encoder stream failed: {err:?}");
313 h3.shut_down();
314 return;
315 }
316 };
317
318 let result = h3.run_encoder(stream).await;
319
320 if let Err(error) = result {
321 handle_h3_error(error, &connection, &h3);
322 }
323
324 h3.shut_down();
325 });
326}
327
328fn spawn_outbound_control_stream<QC: QuicConnectionTrait>(
329 connection: &QC,
330 h3: &Arc<H3Connection>,
331 runtime: &impl RuntimeTrait,
332) {
333 let (connection, h3) = (connection.clone(), h3.clone());
334 runtime.spawn(async move {
335 log::trace!("H3: opening outbound control stream");
336 let stream = match connection.open_uni().await {
337 Ok((_stream_id, stream)) => stream,
338 Err(err) => {
339 log::error!("H3: open_uni for outbound control stream failed: {err:?}");
340 h3.shut_down();
341 return;
342 }
343 };
344
345 let result = h3.run_outbound_control(stream).await;
346
347 if let Err(error) = result {
348 handle_h3_error(error, &connection, &h3);
349 }
350
351 h3.shut_down();
352 });
353}
354
355fn handle_h3_error(error: H3Error, connection: &impl QuicConnectionTrait, h3: &H3Connection) {
356 log::debug!("H3 error: {error}");
357 if let H3Error::Protocol(code) = error
358 && code.is_connection_error()
359 {
360 connection.close(code.into(), code.reason().as_bytes());
361 h3.shut_down();
362 }
363}