1use crate::{
2 connection::QuinnConnection,
3 runtime::{SocketTransport, TrilliumRuntime},
4};
5use rustls::server::ResolvesServerCert;
6use std::{io, net::SocketAddr, sync::Arc};
7use trillium_server_common::{Info, QuicConfig as QuicConfigTrait, QuicEndpoint, Server};
8
9pub struct QuicConfig(quinn::ServerConfig);
21
22impl QuicConfig {
23 pub fn from_single_cert(cert_pem: &[u8], key_pem: &[u8]) -> Self {
28 let certs: Vec<_> = rustls_pemfile::certs(&mut io::BufReader::new(cert_pem))
29 .collect::<Result<_, _>>()
30 .expect("parsing certificate PEM");
31
32 let key = rustls_pemfile::private_key(&mut io::BufReader::new(key_pem))
33 .expect("parsing private key PEM")
34 .expect("no private key found in PEM");
35
36 let mut tls_config =
37 rustls::ServerConfig::builder_with_provider(crate::crypto_provider::crypto_provider())
38 .with_safe_default_protocol_versions()
39 .expect("building TLS config with protocol versions")
40 .with_no_client_auth()
41 .with_single_cert(certs, key)
42 .expect("building TLS config");
43
44 tls_config.alpn_protocols = vec![b"h3".to_vec()];
45
46 let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config))
47 .expect("building QUIC TLS config");
48
49 Self(quinn::ServerConfig::with_crypto(Arc::new(quic_tls)))
50 }
51
52 pub fn from_rustls_server_config(tls_config: rustls::ServerConfig) -> Self {
57 let mut tls_config = tls_config;
58 if !tls_config.alpn_protocols.contains(&b"h3".to_vec()) {
59 tls_config.alpn_protocols.push(b"h3".to_vec());
60 }
61 let quic_tls = quinn::crypto::rustls::QuicServerConfig::try_from(Arc::new(tls_config))
62 .expect("building QUIC TLS config");
63 Self(quinn::ServerConfig::with_crypto(Arc::new(quic_tls)))
64 }
65
66 pub fn from_quinn_server_config(config: quinn::ServerConfig) -> Self {
72 Self(config)
73 }
74
75 pub fn from_cert_resolver(resolver: Arc<dyn ResolvesServerCert>) -> Self {
87 let tls_config =
88 rustls::ServerConfig::builder_with_provider(crate::crypto_provider::crypto_provider())
89 .with_safe_default_protocol_versions()
90 .expect("building TLS config with protocol versions")
91 .with_no_client_auth()
92 .with_cert_resolver(resolver);
93 Self::from_rustls_server_config(tls_config)
94 }
95}
96
97impl<S> QuicConfigTrait<S> for QuicConfig
98where
99 S: Server,
100 S::Runtime: Unpin,
101 S::UdpTransport: SocketTransport,
102{
103 type Endpoint = QuinnEndpoint;
104
105 fn bind(
106 self,
107 addr: SocketAddr,
108 runtime: S::Runtime,
109 _info: &mut Info,
110 ) -> Option<io::Result<Self::Endpoint>> {
111 let quinn_runtime = TrilliumRuntime::<S::Runtime, S::UdpTransport>::new(runtime);
112 let socket = match std::net::UdpSocket::bind(addr) {
113 Ok(s) => s,
114 Err(e) => return Some(Err(e)),
115 };
116
117 Some(
118 quinn::Endpoint::new(
119 quinn::EndpointConfig::default(),
120 Some(self.0),
121 socket,
122 quinn_runtime,
123 )
124 .map(QuinnEndpoint::new),
125 )
126 }
127}
128
129pub struct QuinnEndpoint(quinn::Endpoint);
131
132impl QuinnEndpoint {
133 pub(crate) fn new(endpoint: quinn::Endpoint) -> Self {
135 Self(endpoint)
136 }
137}
138
139impl QuicEndpoint for QuinnEndpoint {
140 type Connection = QuinnConnection;
141
142 async fn accept(&self) -> Option<Self::Connection> {
143 loop {
144 let incoming = self.0.accept().await?;
145 match incoming.await {
146 Ok(connection) => return Some(QuinnConnection::new(connection)),
147 Err(e) => log::error!("QUIC accept failed: {e}"),
148 }
149 }
150 }
151
152 async fn connect(&self, addr: SocketAddr, server_name: &str) -> io::Result<Self::Connection> {
153 let connection = self
154 .0
155 .connect(addr, server_name)
156 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?
157 .await
158 .map_err(|e| io::Error::new(io::ErrorKind::ConnectionRefused, e))?;
159 Ok(QuinnConnection::new(connection))
160 }
161}