1use crate::{
2 connection::QuinnConnection,
3 runtime::{SocketTransport, TrilliumRuntime},
4};
5use rustls::server::ResolvesServerCert;
6use std::{
7 borrow::Cow,
8 collections::HashMap,
9 io,
10 net::SocketAddr,
11 sync::{Arc, Mutex},
12};
13use trillium_server_common::{Info, QuicConfig as QuicConfigTrait, QuicEndpoint, Server};
14
/// QUIC server configuration for binding a [`QuinnEndpoint`].
///
/// A thin wrapper around a [`quinn::ServerConfig`], which is handed to
/// [`quinn::Endpoint::new`] when the endpoint is bound.
pub struct QuicConfig(quinn::ServerConfig);
27
28impl QuicConfig {
29 pub fn from_single_cert(cert_pem: &[u8], key_pem: &[u8]) -> Self {
34 let certs: Vec<_> = rustls_pemfile::certs(&mut io::BufReader::new(cert_pem))
35 .collect::<Result<_, _>>()
36 .expect("parsing certificate PEM");
37
38 let key = rustls_pemfile::private_key(&mut io::BufReader::new(key_pem))
39 .expect("parsing private key PEM")
40 .expect("no private key found in PEM");
41
42 let mut tls_config =
43 rustls::ServerConfig::builder_with_provider(crate::crypto_provider::crypto_provider())
44 .with_safe_default_protocol_versions()
45 .expect("building TLS config with protocol versions")
46 .with_no_client_auth()
47 .with_single_cert(certs, key)
48 .expect("building TLS config");
49
50 tls_config.alpn_protocols = vec![b"h3".to_vec()];
51
52 let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config))
53 .expect("building QUIC TLS config");
54
55 Self(quinn::ServerConfig::with_crypto(Arc::new(quic_tls)))
56 }
57
58 pub fn from_rustls_server_config(tls_config: rustls::ServerConfig) -> Self {
63 let mut tls_config = tls_config;
64 if !tls_config.alpn_protocols.contains(&b"h3".to_vec()) {
65 tls_config.alpn_protocols.push(b"h3".to_vec());
66 }
67 let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config))
68 .expect("building QUIC TLS config");
69 Self(quinn::ServerConfig::with_crypto(Arc::new(quic_tls)))
70 }
71
72 pub fn from_quinn_server_config(config: quinn::ServerConfig) -> Self {
78 Self(config)
79 }
80
81 #[must_use]
87 pub fn with_transport_config(mut self, transport: Arc<quinn::TransportConfig>) -> Self {
88 self.0.transport_config(transport);
89 self
90 }
91
92 pub fn from_cert_resolver(resolver: Arc<dyn ResolvesServerCert>) -> Self {
104 let tls_config =
105 rustls::ServerConfig::builder_with_provider(crate::crypto_provider::crypto_provider())
106 .with_safe_default_protocol_versions()
107 .expect("building TLS config with protocol versions")
108 .with_no_client_auth()
109 .with_cert_resolver(resolver);
110 Self::from_rustls_server_config(tls_config)
111 }
112}
113
114impl<S> QuicConfigTrait<S> for QuicConfig
115where
116 S: Server,
117 S::Runtime: Unpin,
118 S::UdpTransport: SocketTransport,
119{
120 type Endpoint = QuinnEndpoint;
121
122 fn bind(
123 self,
124 addr: SocketAddr,
125 runtime: S::Runtime,
126 info: &mut Info,
127 ) -> Option<io::Result<Self::Endpoint>> {
128 let socket = match std::net::UdpSocket::bind(addr) {
129 Ok(s) => s,
130 Err(e) => return Some(Err(e)),
131 };
132 Some(<Self as QuicConfigTrait<S>>::bind_with_socket(
133 self, socket, runtime, info,
134 ))
135 }
136
137 fn bind_with_socket(
138 self,
139 socket: std::net::UdpSocket,
140 runtime: S::Runtime,
141 _info: &mut Info,
142 ) -> io::Result<Self::Endpoint> {
143 let quinn_runtime = TrilliumRuntime::<S::Runtime, S::UdpTransport>::new(runtime);
144 quinn::Endpoint::new(
145 quinn::EndpointConfig::default(),
146 Some(self.0),
147 socket,
148 quinn_runtime,
149 )
150 .map(QuinnEndpoint::new)
151 }
152}
153
/// A QUIC endpoint backed by a [`quinn::Endpoint`], used to accept inbound
/// connections and to open outbound ones.
pub struct QuinnEndpoint {
    /// The underlying quinn endpoint.
    endpoint: quinn::Endpoint,
    /// Client TLS template. It is cloned with a caller-supplied ALPN list to build
    /// per-ALPN client configs. `None` when the endpoint was created without one
    /// (as `QuinnEndpoint::new` does).
    base_tls: Option<Arc<rustls::ClientConfig>>,
    /// Cache of client configs keyed by the exact ALPN protocol list, so each
    /// distinct list is only turned into a `quinn::ClientConfig` once.
    alpn_configs: Mutex<HashMap<Vec<Vec<u8>>, quinn::ClientConfig>>,
}
167
168impl QuinnEndpoint {
169 pub(crate) fn new(endpoint: quinn::Endpoint) -> Self {
172 Self {
173 endpoint,
174 base_tls: None,
175 alpn_configs: Mutex::new(HashMap::new()),
176 }
177 }
178
179 pub(crate) fn new_client(
181 endpoint: quinn::Endpoint,
182 base_tls: Option<Arc<rustls::ClientConfig>>,
183 ) -> Self {
184 Self {
185 endpoint,
186 base_tls,
187 alpn_configs: Mutex::new(HashMap::new()),
188 }
189 }
190
191 fn client_config_for_alpn(
194 &self,
195 alpn: &[Cow<'static, [u8]>],
196 ) -> io::Result<Option<quinn::ClientConfig>> {
197 let Some(base) = &self.base_tls else {
198 return Ok(None);
199 };
200 let key: Vec<Vec<u8>> = alpn.iter().map(|a| a.to_vec()).collect();
201
202 let mut cache = self.alpn_configs.lock().unwrap();
203 if let Some(config) = cache.get(&key) {
204 return Ok(Some(config.clone()));
205 }
206
207 let mut tls = (**base).clone();
208 tls.alpn_protocols = key.clone();
209 let quic_tls = quinn::crypto::rustls::QuicClientConfig::try_from(Arc::new(tls))
210 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
211 let config = quinn::ClientConfig::new(Arc::new(quic_tls));
212 cache.insert(key, config.clone());
213 Ok(Some(config))
214 }
215}
216
217impl QuicEndpoint for QuinnEndpoint {
218 type Connection = QuinnConnection;
219
220 async fn accept(&self) -> Option<Self::Connection> {
221 loop {
222 let incoming = self.endpoint.accept().await?;
223 match incoming.await {
224 Ok(connection) => return Some(QuinnConnection::new(connection)),
225 Err(e) => log::error!("QUIC accept failed: {e}"),
226 }
227 }
228 }
229
230 async fn connect(&self, addr: SocketAddr, server_name: &str) -> io::Result<Self::Connection> {
231 let connection = self
232 .endpoint
233 .connect(addr, server_name)
234 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
235 .await
236 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;
237 Ok(QuinnConnection::new(connection))
238 }
239
240 async fn connect_with_alpn(
241 &self,
242 addr: SocketAddr,
243 server_name: &str,
244 alpn: &[Cow<'static, [u8]>],
245 ) -> io::Result<Self::Connection> {
246 let Some(config) = (!alpn.is_empty())
249 .then(|| self.client_config_for_alpn(alpn))
250 .transpose()?
251 .flatten()
252 else {
253 return self.connect(addr, server_name).await;
254 };
255
256 let connection = self
257 .endpoint
258 .connect_with(config, addr, server_name)
259 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
260 .await
261 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;
262 Ok(QuinnConnection::new(connection))
263 }
264
265 fn local_addr(&self) -> io::Result<SocketAddr> {
266 self.endpoint.local_addr()
267 }
268}