1use crate::crypto_provider;
2use RustlsClientTransportInner::{Tcp, Tls};
3#[cfg(feature = "dangerous")]
4use futures_rustls::rustls::{
5 DigitallySignedStruct, SignatureScheme,
6 client::danger::{HandshakeSignatureValid, ServerCertVerified},
7 crypto::{verify_tls12_signature, verify_tls13_signature},
8 pki_types::{CertificateDer, UnixTime},
9};
10use futures_rustls::{
11 TlsConnector,
12 client::TlsStream,
13 rustls::{
14 ClientConfig, ClientConnection, RootCertStore,
15 client::{WebPkiServerVerifier, danger::ServerCertVerifier},
16 crypto::CryptoProvider,
17 pki_types::ServerName,
18 },
19};
20use std::{
21 fmt::{self, Debug, Formatter},
22 io::{Error, ErrorKind, IoSlice, Result},
23 net::SocketAddr,
24 pin::Pin,
25 sync::Arc,
26 task::{Context, Poll},
27};
28use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Transport, Url, url::Host};
29
/// Client-side rustls configuration, held behind an [`Arc`].
///
/// Cloning this type only clones the `Arc` handle to the wrapped
/// [`ClientConfig`], not the configuration itself.
#[derive(Clone, Debug)]
pub struct RustlsClientConfig(Arc<ClientConfig>);
38
/// Connector configuration that pairs a rustls client configuration with the
/// configuration of an inner connector of type `Config`.
#[derive(Clone, Default)]
pub struct RustlsConfig<Config> {
    /// The rustls client configuration used when establishing TLS sessions.
    pub rustls_config: RustlsClientConfig,

    /// The inner connector configuration that this type delegates to.
    pub tcp_config: Config,
}
48
49impl<C: Connector> RustlsConfig<C> {
50 pub fn new(rustls_config: impl Into<RustlsClientConfig>, tcp_config: C) -> Self {
52 Self {
53 rustls_config: rustls_config.into(),
54 tcp_config,
55 }
56 }
57}
58
59impl Default for RustlsClientConfig {
60 fn default() -> Self {
61 Self(Arc::new(default_client_config()))
62 }
63}
64
/// Builds the server certificate verifier from `rustls_platform_verifier`
/// using the supplied crypto provider.
///
/// # Panics
///
/// Panics if `rustls_platform_verifier::Verifier::new` returns an error.
#[cfg(feature = "platform-verifier")]
fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
    let platform = rustls_platform_verifier::Verifier::new(provider).unwrap();
    Arc::new(platform)
}
69
70#[cfg(not(feature = "platform-verifier"))]
71fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
72 let roots = Arc::new(RootCertStore::from_iter(
73 webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
74 ));
75 WebPkiServerVerifier::builder_with_provider(roots, provider)
76 .build()
77 .unwrap()
78}
79
80fn client_config_with_verifier(verifier: Arc<dyn ServerCertVerifier>) -> ClientConfig {
81 let mut config = ClientConfig::builder_with_provider(crypto_provider())
82 .with_safe_default_protocol_versions()
83 .expect("crypto provider did not support safe default protocol versions")
84 .dangerous()
85 .with_custom_certificate_verifier(verifier)
86 .with_no_client_auth();
87
88 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
89
90 config
91}
92
93fn default_client_config() -> ClientConfig {
94 client_config_with_verifier(verifier(crypto_provider()))
95}
96
97impl RustlsClientConfig {
98 pub fn from_root_cert_pem(pem: &[u8]) -> Result<Self> {
114 let mut roots = RootCertStore::empty();
115 let mut reader = pem;
116 for cert in rustls_pemfile::certs(&mut reader) {
117 roots.add(cert?).map_err(Error::other)?;
118 }
119
120 if roots.is_empty() {
121 return Err(Error::new(
122 ErrorKind::InvalidInput,
123 "no certificates found in pem",
124 ));
125 }
126
127 let verifier =
128 WebPkiServerVerifier::builder_with_provider(Arc::new(roots), crypto_provider())
129 .build()
130 .map_err(Error::other)?;
131
132 Ok(Self(Arc::new(client_config_with_verifier(verifier))))
133 }
134}
135
136impl From<ClientConfig> for RustlsClientConfig {
137 fn from(rustls_config: ClientConfig) -> Self {
138 Self(Arc::new(rustls_config))
139 }
140}
141
142impl From<Arc<ClientConfig>> for RustlsClientConfig {
143 fn from(rustls_config: Arc<ClientConfig>) -> Self {
144 Self(rustls_config)
145 }
146}
147
/// Server certificate verifier that accepts every certificate.
///
/// The wrapped [`CryptoProvider`] supplies the signature verification
/// algorithms used for handshake signature checks.
#[cfg(feature = "dangerous")]
#[derive(Debug)]
struct AcceptAnyServerCert(Arc<CryptoProvider>);
151
#[cfg(feature = "dangerous")]
impl ServerCertVerifier for AcceptAnyServerCert {
    /// Accepts the presented certificate unconditionally; none of the
    /// arguments are inspected.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, futures_rustls::rustls::Error> {
        let accepted = ServerCertVerified::assertion();
        Ok(accepted)
    }

    /// Checks a TLS 1.2 handshake signature with the provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, futures_rustls::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature with the provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, futures_rustls::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Reports the signature schemes supported by the provider's algorithms.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
197
#[cfg(feature = "dangerous")]
#[cfg_attr(docsrs, doc(cfg(feature = "dangerous")))]
impl RustlsClientConfig {
    /// Builds a client config whose verifier accepts any server certificate.
    ///
    /// A warning is logged on every call, because server authentication is
    /// disabled for connections made with the returned config.
    pub fn dangerously_accept_any_cert() -> Self {
        log::warn!(
            "constructing a rustls client config that accepts any server certificate; server \
             authentication is disabled and connections are vulnerable to interception"
        );
        let provider = crypto_provider();
        let permissive = Arc::new(AcceptAnyServerCert(provider));
        let config = client_config_with_verifier(permissive);
        Self(Arc::new(config))
    }
}
221
222impl<C: Connector> RustlsConfig<C> {
223 pub fn with_tcp_config(mut self, config: C) -> Self {
225 self.tcp_config = config;
226 self
227 }
228
229 #[must_use]
235 pub fn without_http2(mut self) -> Self {
236 let config = Arc::make_mut(&mut self.rustls_config.0);
237 config.alpn_protocols.retain(|p| p != b"h2");
238 self
239 }
240}
241
242impl<Config: Debug> Debug for RustlsConfig<Config> {
243 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
244 f.debug_struct("RustlsConfig")
245 .field("rustls_config", &format_args!(".."))
246 .field("tcp_config", &self.tcp_config)
247 .finish()
248 }
249}
250
251impl<C: Connector> Connector for RustlsConfig<C> {
252 type Runtime = C::Runtime;
253 type Transport = RustlsClientTransport<C::Transport>;
254 type Udp = C::Udp;
255
256 async fn connect(&self, url: &Url) -> Result<Self::Transport> {
257 match url.scheme() {
258 "https" => {
259 let mut http = url.clone();
260 http.set_scheme("http").ok();
261 http.set_port(url.port_or_known_default()).ok();
262
263 let connector: TlsConnector = Arc::clone(&self.rustls_config.0).into();
264 let domain = match url.host() {
270 Some(Host::Domain(domain)) => {
271 ServerName::try_from(domain.to_owned()).map_err(|e| {
272 Error::other(format!("invalid server name {domain:?}: {e}"))
273 })?
274 }
275 Some(Host::Ipv4(ip)) => ServerName::IpAddress(std::net::IpAddr::V4(ip).into()),
276 Some(Host::Ipv6(ip)) => ServerName::IpAddress(std::net::IpAddr::V6(ip).into()),
277 None => return Err(Error::other("url has no host")),
278 };
279
280 connector
281 .connect(domain, self.tcp_config.connect(&http).await?)
282 .await
283 .map_err(|e| Error::other(e.to_string()))
284 .map(Into::into)
285 }
286
287 "http" => self.tcp_config.connect(url).await.map(Into::into),
288
289 unknown => Err(Error::new(
290 ErrorKind::InvalidInput,
291 format!("unknown scheme {unknown}"),
292 )),
293 }
294 }
295
296 fn runtime(&self) -> Self::Runtime {
297 self.tcp_config.runtime()
298 }
299
300 async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
301 self.tcp_config.resolve(host, port).await
302 }
303}
304
/// The two shapes a client transport can take: the bare inner transport or a
/// TLS stream wrapped around it.
#[derive(Debug)]
enum RustlsClientTransportInner<T> {
    /// The inner transport, used directly without TLS.
    Tcp(T),
    /// A boxed TLS client stream over the inner transport.
    Tls(Box<TlsStream<T>>),
}
310
/// Client transport that is either a plain inner transport `T` or a TLS
/// stream layered over a `T`.
#[derive(Debug)]
pub struct RustlsClientTransport<T>(RustlsClientTransportInner<T>);
317impl<T> From<T> for RustlsClientTransport<T> {
318 fn from(value: T) -> Self {
319 Self(Tcp(value))
320 }
321}
322
323impl<T> From<TlsStream<T>> for RustlsClientTransport<T> {
324 fn from(value: TlsStream<T>) -> Self {
325 Self(Tls(Box::new(value)))
326 }
327}
328
329impl<C> AsyncRead for RustlsClientTransport<C>
330where
331 C: AsyncWrite + AsyncRead + Unpin,
332{
333 fn poll_read(
334 mut self: Pin<&mut Self>,
335 cx: &mut Context<'_>,
336 buf: &mut [u8],
337 ) -> Poll<Result<usize>> {
338 match &mut self.0 {
339 Tcp(c) => Pin::new(c).poll_read(cx, buf),
340 Tls(c) => Pin::new(c).poll_read(cx, buf),
341 }
342 }
343
344 fn poll_read_vectored(
345 mut self: Pin<&mut Self>,
346 cx: &mut Context<'_>,
347 bufs: &mut [std::io::IoSliceMut<'_>],
348 ) -> Poll<Result<usize>> {
349 match &mut self.0 {
350 Tcp(c) => Pin::new(c).poll_read_vectored(cx, bufs),
351 Tls(c) => Pin::new(c).poll_read_vectored(cx, bufs),
352 }
353 }
354}
355
356impl<C> AsyncWrite for RustlsClientTransport<C>
357where
358 C: AsyncRead + AsyncWrite + Unpin,
359{
360 fn poll_write(
361 mut self: Pin<&mut Self>,
362 cx: &mut Context<'_>,
363 buf: &[u8],
364 ) -> Poll<Result<usize>> {
365 match &mut self.0 {
366 Tcp(c) => Pin::new(c).poll_write(cx, buf),
367 Tls(c) => Pin::new(&mut *c).poll_write(cx, buf),
368 }
369 }
370
371 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
372 match &mut self.0 {
373 Tcp(c) => Pin::new(c).poll_flush(cx),
374 Tls(c) => Pin::new(&mut *c).poll_flush(cx),
375 }
376 }
377
378 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
379 match &mut self.0 {
380 Tcp(c) => Pin::new(c).poll_close(cx),
381 Tls(c) => Pin::new(&mut *c).poll_close(cx),
382 }
383 }
384
385 fn poll_write_vectored(
386 mut self: Pin<&mut Self>,
387 cx: &mut Context<'_>,
388 bufs: &[IoSlice<'_>],
389 ) -> Poll<Result<usize>> {
390 match &mut self.0 {
391 Tcp(c) => Pin::new(c).poll_write_vectored(cx, bufs),
392 Tls(c) => Pin::new(&mut *c).poll_write_vectored(cx, bufs),
393 }
394 }
395}
396
397impl<T: Transport> Transport for RustlsClientTransport<T> {
398 fn peer_addr(&self) -> Result<Option<SocketAddr>> {
399 self.as_ref().peer_addr()
400 }
401
402 fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
403 self.tls_state()
404 .and_then(|conn| conn.alpn_protocol())
405 .map(std::borrow::Cow::Borrowed)
406 }
407}
408
409impl<T> AsRef<T> for RustlsClientTransport<T> {
410 fn as_ref(&self) -> &T {
411 match &self.0 {
412 Tcp(x) => x,
413 Tls(x) => x.get_ref().0,
414 }
415 }
416}
417
418impl<T> RustlsClientTransport<T> {
419 pub fn tls_state_mut(&mut self) -> Option<&mut ClientConnection> {
421 match &mut self.0 {
422 Tls(x) => Some(x.get_mut().1),
423 _ => None,
424 }
425 }
426
427 pub fn tls_state(&self) -> Option<&ClientConnection> {
429 match &self.0 {
430 Tls(x) => Some(x.get_ref().1),
431 _ => None,
432 }
433 }
434}