1use crate::{
2 connection::QuinnConnection,
3 runtime::{SocketTransport, TrilliumRuntime},
4};
5use rustls::server::ResolvesServerCert;
6use std::{io, net::SocketAddr, sync::Arc};
7use trillium_server_common::{Info, QuicConfig as QuicConfigTrait, QuicEndpoint, Server};
8
9pub struct QuicConfig(quinn::ServerConfig);
21
22impl QuicConfig {
23 pub fn from_single_cert(cert_pem: &[u8], key_pem: &[u8]) -> Self {
28 let certs: Vec<_> = rustls_pemfile::certs(&mut io::BufReader::new(cert_pem))
29 .collect::<Result<_, _>>()
30 .expect("parsing certificate PEM");
31
32 let key = rustls_pemfile::private_key(&mut io::BufReader::new(key_pem))
33 .expect("parsing private key PEM")
34 .expect("no private key found in PEM");
35
36 let mut tls_config =
37 rustls::ServerConfig::builder_with_provider(crate::crypto_provider::crypto_provider())
38 .with_safe_default_protocol_versions()
39 .expect("building TLS config with protocol versions")
40 .with_no_client_auth()
41 .with_single_cert(certs, key)
42 .expect("building TLS config");
43
44 tls_config.alpn_protocols = vec![b"h3".to_vec()];
45
46 let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config))
47 .expect("building QUIC TLS config");
48
49 Self(quinn::ServerConfig::with_crypto(Arc::new(quic_tls)))
50 }
51
52 pub fn from_rustls_server_config(tls_config: rustls::ServerConfig) -> Self {
57 let mut tls_config = tls_config;
58 if !tls_config.alpn_protocols.contains(&b"h3".to_vec()) {
59 tls_config.alpn_protocols.push(b"h3".to_vec());
60 }
61 let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config))
62 .expect("building QUIC TLS config");
63 Self(quinn::ServerConfig::with_crypto(Arc::new(quic_tls)))
64 }
65
66 pub fn from_quinn_server_config(config: quinn::ServerConfig) -> Self {
72 Self(config)
73 }
74
75 pub fn from_cert_resolver(resolver: Arc<dyn ResolvesServerCert>) -> Self {
87 let tls_config =
88 rustls::ServerConfig::builder_with_provider(crate::crypto_provider::crypto_provider())
89 .with_safe_default_protocol_versions()
90 .expect("building TLS config with protocol versions")
91 .with_no_client_auth()
92 .with_cert_resolver(resolver);
93 Self::from_rustls_server_config(tls_config)
94 }
95}
96
97impl<S> QuicConfigTrait<S> for QuicConfig
98where
99 S: Server,
100 S::Runtime: Unpin,
101 S::UdpTransport: SocketTransport,
102{
103 type Endpoint = QuinnEndpoint;
104
105 fn bind(
106 self,
107 addr: SocketAddr,
108 runtime: S::Runtime,
109 info: &mut Info,
110 ) -> Option<io::Result<Self::Endpoint>> {
111 let socket = match std::net::UdpSocket::bind(addr) {
112 Ok(s) => s,
113 Err(e) => return Some(Err(e)),
114 };
115 Some(<Self as QuicConfigTrait<S>>::bind_with_socket(
116 self, socket, runtime, info,
117 ))
118 }
119
120 fn bind_with_socket(
121 self,
122 socket: std::net::UdpSocket,
123 runtime: S::Runtime,
124 _info: &mut Info,
125 ) -> io::Result<Self::Endpoint> {
126 let quinn_runtime = TrilliumRuntime::<S::Runtime, S::UdpTransport>::new(runtime);
127 quinn::Endpoint::new(
128 quinn::EndpointConfig::default(),
129 Some(self.0),
130 socket,
131 quinn_runtime,
132 )
133 .map(QuinnEndpoint::new)
134 }
135}
136
137pub struct QuinnEndpoint(quinn::Endpoint);
139
140impl QuinnEndpoint {
141 pub(crate) fn new(endpoint: quinn::Endpoint) -> Self {
143 Self(endpoint)
144 }
145}
146
147impl QuicEndpoint for QuinnEndpoint {
148 type Connection = QuinnConnection;
149
150 async fn accept(&self) -> Option<Self::Connection> {
151 loop {
152 let incoming = self.0.accept().await?;
153 match incoming.await {
154 Ok(connection) => return Some(QuinnConnection::new(connection)),
155 Err(e) => log::error!("QUIC accept failed: {e}"),
156 }
157 }
158 }
159
160 async fn connect(&self, addr: SocketAddr, server_name: &str) -> io::Result<Self::Connection> {
161 let connection = self
162 .0
163 .connect(addr, server_name)
164 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
165 .await
166 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;
167 Ok(QuinnConnection::new(connection))
168 }
169
170 fn local_addr(&self) -> io::Result<SocketAddr> {
171 self.0.local_addr()
172 }
173}