1pub mod web_transport;
4use crate::{
5 ArcHandler, QuicConnection, QuicConnectionTrait, QuicEndpoint, QuicTransportReceive,
6 QuicTransportSend, Runtime,
7};
8use std::sync::Arc;
9use trillium::{Handler, Upgrade};
10use trillium_http::{
11 HttpContext,
12 h3::{H3Connection, H3Error, H3ErrorCode, H3StreamResult, UniStreamResult},
13};
14use web_transport::{WebTransportDispatcher, WebTransportStream};
15
/// Identifier of a QUIC stream, stored as its raw `u64` value.
///
/// Converts losslessly to and from `u64` through the [`From`] impls below, and can be compared
/// for equality or used as a key in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct StreamId(u64);

impl From<StreamId> for u64 {
    /// Returns the raw numeric value wrapped by the [`StreamId`].
    fn from(val: StreamId) -> Self {
        val.0
    }
}

impl From<u64> for StreamId {
    /// Wraps a raw numeric value as a [`StreamId`].
    fn from(value: u64) -> Self {
        Self(value)
    }
}
30
31pub(crate) async fn run_h3<QE: QuicEndpoint>(
32 quic_binding: QE,
33 context: Arc<HttpContext>,
34 handler: ArcHandler<impl Handler>,
35 runtime: Runtime,
36) {
37 let swansong = context.swansong();
38 while let Some(connection) = swansong.interrupt(quic_binding.accept()).await.flatten() {
39 let h3 = H3Connection::new(context.clone());
40 let handler = handler.clone();
41 let runtime = runtime.clone();
42 runtime
43 .clone()
44 .spawn(run_h3_connection(connection, h3, handler, runtime));
45 }
46}
47
48async fn run_h3_connection<QC: QuicConnectionTrait>(
49 connection: QC,
50 h3: Arc<H3Connection>,
51 handler: ArcHandler<impl Handler>,
52 runtime: Runtime,
53) {
54 let wt_dispatcher = h3
55 .context()
56 .config()
57 .webtransport_enabled()
58 .then(WebTransportDispatcher::new);
59
60 log::trace!("new quic connection from {}", connection.remote_address());
61
62 spawn_outbound_control_stream(&connection, &h3, &runtime);
63 spawn_qpack_encoder_stream(&connection, &h3, &runtime);
64 spawn_qpack_decoder_stream(&connection, &h3, &runtime);
65 spawn_inbound_uni_streams(&connection, &h3, &runtime, &wt_dispatcher);
66 handle_inbound_bidi_streams(connection, h3, handler, runtime, wt_dispatcher).await;
67}
68
69async fn handle_inbound_bidi_streams<QC: QuicConnectionTrait>(
70 connection: QC,
71 h3: Arc<H3Connection>,
72 handler: ArcHandler<impl Handler>,
73 runtime: Runtime,
74 wt_dispatcher: Option<WebTransportDispatcher>,
75) {
76 let swansong = h3.swansong().clone();
77 while let Some(Ok((stream_id, transport))) = swansong.interrupt(connection.accept_bidi()).await
78 {
79 let (h3, handler, connection, wt_dispatcher) = (
80 h3.clone(),
81 handler.clone(),
82 connection.clone(),
83 wt_dispatcher.clone(),
84 );
85 let peer_ip = connection.remote_address().ip();
86 runtime.spawn(async move {
87 let handler = &handler;
88 let quic_connection = connection.clone();
89 let wt_dispatcher = wt_dispatcher.clone();
90 let result = h3
91 .clone()
92 .process_inbound_bidi(
93 transport,
94 {
95 let wt_dispatcher = wt_dispatcher.clone();
96 |mut conn| async move {
97 conn.set_peer_ip(Some(peer_ip));
98 conn.set_secure(true);
99 let state = conn.state_mut();
100 state.insert(quic_connection.clone());
101 state.insert(QuicConnection::from(quic_connection));
102 state.insert(StreamId(stream_id));
103 if let Some(dispatcher) = wt_dispatcher {
104 state.insert(dispatcher);
105 }
106 let conn = handler.run(conn.into()).await;
107 let conn = handler.before_send(conn).await;
108 conn.into_inner()
109 }
110 },
111 stream_id,
112 )
113 .await;
114
115 match result {
116 Ok(H3StreamResult::Request(conn)) if conn.should_upgrade() => {
117 let upgrade = Upgrade::from(conn);
118 if handler.has_upgrade(&upgrade) {
119 log::debug!("upgrading h3 stream");
120 handler.upgrade(upgrade).await;
121 } else {
122 log::error!("h3 upgrade specified but no upgrade handler provided");
123 }
124 }
125 Ok(H3StreamResult::Request(_)) => {}
126 Ok(H3StreamResult::WebTransport {
127 session_id,
128 mut transport,
129 buffer,
130 }) => {
131 if let Some(dispatcher) = &wt_dispatcher {
132 dispatcher.dispatch(WebTransportStream::Bidi {
133 session_id,
134 stream: Box::new(transport),
135 buffer: buffer.into(),
136 });
137 } else {
138 transport.stop(H3ErrorCode::StreamCreationError.into());
139 transport.reset(H3ErrorCode::StreamCreationError.into());
140 }
141 }
142 Err(error) => handle_h3_error(error, &connection, &h3).await,
143 }
144 });
145 }
146}
147
148fn spawn_inbound_uni_streams<QC: QuicConnectionTrait>(
149 connection: &QC,
150 h3: &Arc<H3Connection>,
151 runtime: &Runtime,
152 wt_dispatcher: &Option<WebTransportDispatcher>,
153) {
154 let (connection, h3, runtime, wt_dispatcher) = (
155 connection.clone(),
156 h3.clone(),
157 runtime.clone(),
158 wt_dispatcher.clone(),
159 );
160 runtime.clone().spawn(async move {
161 while let Ok((_stream_id, recv)) = connection.accept_uni().await {
162 let (connection, h3, wt_dispatcher) =
163 (connection.clone(), h3.clone(), wt_dispatcher.clone());
164 runtime.spawn(async move {
165 match h3.process_inbound_uni(recv).await {
166 Ok(UniStreamResult::Handled) => {}
167 Ok(UniStreamResult::WebTransport {
168 session_id,
169 mut stream,
170 buffer,
171 }) => {
172 if let Some(dispatcher) = &wt_dispatcher {
173 dispatcher.dispatch(WebTransportStream::Uni {
174 session_id,
175 stream: Box::new(stream),
176 buffer: buffer.into(),
177 });
178 } else {
179 stream.stop(H3ErrorCode::StreamCreationError.into());
180 }
181 }
182 Ok(UniStreamResult::Unknown { mut stream, .. }) => {
183 stream.stop(H3ErrorCode::StreamCreationError.into());
184 }
185 Err(error) => {
186 handle_h3_error(error, &connection, &h3).await;
187 }
188 }
189 });
190 }
191 });
192}
193
194fn spawn_qpack_decoder_stream<QC: QuicConnectionTrait>(
195 connection: &QC,
196 h3: &Arc<H3Connection>,
197 runtime: &Runtime,
198) {
199 let (connection, h3) = (connection.clone(), h3.clone());
200 runtime.spawn(async move {
201 let result: Result<(), H3Error> =
202 async { h3.run_decoder(connection.open_uni().await?.1).await }.await;
203 if let Err(error) = result {
204 handle_h3_error(error, &connection, &h3).await;
205 }
206 });
207}
208
209fn spawn_qpack_encoder_stream<QC: QuicConnectionTrait>(
210 connection: &QC,
211 h3: &Arc<H3Connection>,
212 runtime: &Runtime,
213) {
214 let (connection, h3) = (connection.clone(), h3.clone());
215 runtime.spawn(async move {
216 let result: Result<(), H3Error> =
217 async { h3.run_encoder(connection.open_uni().await?.1).await }.await;
218 if let Err(error) = result {
219 handle_h3_error(error, &connection, &h3).await;
220 }
221 });
222}
223
224fn spawn_outbound_control_stream<QC: QuicConnectionTrait>(
225 connection: &QC,
226 h3: &Arc<H3Connection>,
227 runtime: &Runtime,
228) {
229 let (connection, h3) = (connection.clone(), h3.clone());
230 runtime.spawn(async move {
231 let guard = h3.swansong().guard();
232
233 let result: Result<(), H3Error> = async {
234 h3.run_outbound_control(connection.open_uni().await?.1)
235 .await
236 }
237 .await;
238 drop(guard);
239 if let Err(error) = result {
240 handle_h3_error(error, &connection, &h3).await;
241 }
242 });
243}
244
245async fn handle_h3_error(error: H3Error, connection: &impl QuicConnectionTrait, h3: &H3Connection) {
246 log::debug!("H3 error: {error}");
247 if let H3Error::Protocol(code) = error {
248 connection.close(code.into(), code.reason().as_bytes());
249 }
250 h3.shut_down().await;
251}