Skip to main content

trillium_rustls/
client.rs

1use crate::crypto_provider;
2use RustlsClientTransportInner::{Tcp, Tls};
3#[cfg(feature = "dangerous")]
4use futures_rustls::rustls::{
5    DigitallySignedStruct, SignatureScheme,
6    client::danger::{HandshakeSignatureValid, ServerCertVerified},
7    crypto::{verify_tls12_signature, verify_tls13_signature},
8    pki_types::{CertificateDer, UnixTime},
9};
10use futures_rustls::{
11    TlsConnector,
12    client::TlsStream,
13    rustls::{
14        ClientConfig, ClientConnection, RootCertStore,
15        client::{WebPkiServerVerifier, danger::ServerCertVerifier},
16        crypto::CryptoProvider,
17        pki_types::ServerName,
18    },
19};
20use std::{
21    fmt::{self, Debug, Formatter},
22    io::{Error, ErrorKind, IoSlice, Result},
23    net::SocketAddr,
24    pin::Pin,
25    sync::Arc,
26    task::{Context, Poll},
27};
28use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Transport, Url, url::Host};
29
30/// Rustls [`ClientConfig`] wrapper used by [`RustlsConfig`].
31///
32/// [`RustlsClientConfig::default`] trusts the platform or webpki roots (depending on the
33/// `platform-verifier` feature). Use [`RustlsClientConfig::from_root_cert_pem`] to trust a specific
34/// private or self-signed certificate instead, or convert an existing [`ClientConfig`] via
35/// [`From`].
36#[derive(Clone, Debug)]
37pub struct RustlsClientConfig(Arc<ClientConfig>);
38
39/// Client configuration for RustlsConnector
40#[derive(Clone, Default)]
41pub struct RustlsConfig<Config> {
42    /// configuration for rustls itself
43    pub rustls_config: RustlsClientConfig,
44
45    /// configuration for the inner transport
46    pub tcp_config: Config,
47}
48
49impl<C: Connector> RustlsConfig<C> {
50    /// build a new default rustls config with this tcp config
51    pub fn new(rustls_config: impl Into<RustlsClientConfig>, tcp_config: C) -> Self {
52        Self {
53            rustls_config: rustls_config.into(),
54            tcp_config,
55        }
56    }
57}
58
59impl Default for RustlsClientConfig {
60    fn default() -> Self {
61        Self(Arc::new(default_client_config()))
62    }
63}
64
/// Build the default server certificate verifier backed by `rustls_platform_verifier`.
///
/// # Panics
///
/// Panics if the platform verifier cannot be constructed with `provider`.
#[cfg(feature = "platform-verifier")]
fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
    let platform_verifier = rustls_platform_verifier::Verifier::new(provider)
        .expect("failed to initialize the platform certificate verifier");
    Arc::new(platform_verifier)
}
69
70#[cfg(not(feature = "platform-verifier"))]
71fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
72    let roots = Arc::new(RootCertStore::from_iter(
73        webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
74    ));
75    WebPkiServerVerifier::builder_with_provider(roots, provider)
76        .build()
77        .unwrap()
78}
79
80fn client_config_with_verifier(verifier: Arc<dyn ServerCertVerifier>) -> ClientConfig {
81    let mut config = ClientConfig::builder_with_provider(crypto_provider())
82        .with_safe_default_protocol_versions()
83        .expect("crypto provider did not support safe default protocol versions")
84        .dangerous()
85        .with_custom_certificate_verifier(verifier)
86        .with_no_client_auth();
87
88    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
89
90    config
91}
92
93fn default_client_config() -> ClientConfig {
94    client_config_with_verifier(verifier(crypto_provider()))
95}
96
97impl RustlsClientConfig {
98    /// Build a client configuration that trusts exactly the certificate(s) in `pem`.
99    ///
100    /// Unlike [`RustlsClientConfig::default`], this consults neither the platform trust store nor
101    /// the webpki root bundle — the provided roots are the only trust anchors. Server
102    /// authentication is otherwise unchanged: certificate chains, signatures, expiry, and server
103    /// name are all still verified against these roots. This is the right tool for talking to a
104    /// service that presents a private or self-signed certificate.
105    ///
106    /// The crate's configured crypto provider and default ALPN protocol list (`h2`, `http/1.1`)
107    /// are reused.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if `pem` contains no certificates or cannot be parsed, or if the resulting
112    /// trust anchors are rejected by the verifier builder.
113    pub fn from_root_cert_pem(pem: &[u8]) -> Result<Self> {
114        let mut roots = RootCertStore::empty();
115        let mut reader = pem;
116        for cert in rustls_pemfile::certs(&mut reader) {
117            roots.add(cert?).map_err(Error::other)?;
118        }
119
120        if roots.is_empty() {
121            return Err(Error::new(
122                ErrorKind::InvalidInput,
123                "no certificates found in pem",
124            ));
125        }
126
127        let verifier =
128            WebPkiServerVerifier::builder_with_provider(Arc::new(roots), crypto_provider())
129                .build()
130                .map_err(Error::other)?;
131
132        Ok(Self(Arc::new(client_config_with_verifier(verifier))))
133    }
134}
135
136impl From<ClientConfig> for RustlsClientConfig {
137    fn from(rustls_config: ClientConfig) -> Self {
138        Self(Arc::new(rustls_config))
139    }
140}
141
142impl From<Arc<ClientConfig>> for RustlsClientConfig {
143    fn from(rustls_config: Arc<ClientConfig>) -> Self {
144        Self(rustls_config)
145    }
146}
147
/// Server certificate verifier that accepts every certificate.
///
/// It holds the crypto provider whose signature verification algorithms are used to check
/// handshake signatures.
#[cfg(feature = "dangerous")]
#[derive(Debug)]
struct AcceptAnyServerCert(Arc<CryptoProvider>);
151
#[cfg(feature = "dangerous")]
impl ServerCertVerifier for AcceptAnyServerCert {
    /// Accept the presented certificate unconditionally; every argument is ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, futures_rustls::rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Check a TLS 1.2 handshake signature with the provider's verification algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, futures_rustls::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Check a TLS 1.3 handshake signature with the provider's verification algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, futures_rustls::rustls::Error> {
        let algorithms = &self.0.signature_verification_algorithms;
        verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Report the signature schemes supported by the provider's verification algorithms.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
197
#[cfg(feature = "dangerous")]
#[cfg_attr(docsrs, doc(cfg(feature = "dangerous")))]
impl RustlsClientConfig {
    /// Build a client configuration that accepts **any** server certificate without
    /// verification.
    ///
    /// ⚠️ Server authentication is disabled entirely. Handshake signatures are still checked,
    /// but the certificate is never validated against any trust anchor, which leaves the
    /// connection open to man-in-the-middle attacks. Intended for development against
    /// throwaway self-signed certificates and for `--insecure`-style CLI flags. To talk to a
    /// service with a known private certificate, prefer
    /// [`RustlsClientConfig::from_root_cert_pem`], which keeps authentication intact.
    ///
    /// Only available with the `dangerous` crate feature enabled. A warning is logged on
    /// every call.
    pub fn dangerously_accept_any_cert() -> Self {
        log::warn!(
            "constructing a rustls client config that accepts any server certificate; server \
             authentication is disabled and connections are vulnerable to interception"
        );

        let accept_all: Arc<dyn ServerCertVerifier> =
            Arc::new(AcceptAnyServerCert(crypto_provider()));
        let config = client_config_with_verifier(accept_all);
        Self(Arc::new(config))
    }
}
221
222impl<C: Connector> RustlsConfig<C> {
223    /// replace the tcp config
224    pub fn with_tcp_config(mut self, config: C) -> Self {
225        self.tcp_config = config;
226        self
227    }
228
229    /// Drop `h2` from the ALPN protocol list, forcing HTTP/1.1 over TLS.
230    ///
231    /// `RustlsConfig::default()` advertises `[h2, http/1.1]` so HTTP/2 is the preferred
232    /// protocol when the server supports it. Call this to opt out and pin the connection to
233    /// HTTP/1.1.
234    #[must_use]
235    pub fn without_http2(mut self) -> Self {
236        let config = Arc::make_mut(&mut self.rustls_config.0);
237        config.alpn_protocols.retain(|p| p != b"h2");
238        self
239    }
240}
241
242impl<Config: Debug> Debug for RustlsConfig<Config> {
243    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
244        f.debug_struct("RustlsConfig")
245            .field("rustls_config", &format_args!(".."))
246            .field("tcp_config", &self.tcp_config)
247            .finish()
248    }
249}
250
251impl<C: Connector> Connector for RustlsConfig<C> {
252    type Runtime = C::Runtime;
253    type Transport = RustlsClientTransport<C::Transport>;
254    type Udp = C::Udp;
255
256    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
257        match url.scheme() {
258            "https" => {
259                let mut http = url.clone();
260                http.set_scheme("http").ok();
261                http.set_port(url.port_or_known_default()).ok();
262
263                let connector: TlsConnector = Arc::clone(&self.rustls_config.0).into();
264                // Derive the TLS server name from the URL host. A domain becomes a DNS
265                // `ServerName` (sent via SNI); an IP literal becomes an `IpAddress` server
266                // name (no SNI, validated against the certificate's IP SAN) — `url.domain()`
267                // returns `None` for IPs, so matching on `host()` is what lets us connect to
268                // an IP address over TLS at all.
269                let domain = match url.host() {
270                    Some(Host::Domain(domain)) => {
271                        ServerName::try_from(domain.to_owned()).map_err(|e| {
272                            Error::other(format!("invalid server name {domain:?}: {e}"))
273                        })?
274                    }
275                    Some(Host::Ipv4(ip)) => ServerName::IpAddress(std::net::IpAddr::V4(ip).into()),
276                    Some(Host::Ipv6(ip)) => ServerName::IpAddress(std::net::IpAddr::V6(ip).into()),
277                    None => return Err(Error::other("url has no host")),
278                };
279
280                connector
281                    .connect(domain, self.tcp_config.connect(&http).await?)
282                    .await
283                    .map_err(|e| Error::other(e.to_string()))
284                    .map(Into::into)
285            }
286
287            "http" => self.tcp_config.connect(url).await.map(Into::into),
288
289            unknown => Err(Error::new(
290                ErrorKind::InvalidInput,
291                format!("unknown scheme {unknown}"),
292            )),
293        }
294    }
295
296    fn runtime(&self) -> Self::Runtime {
297        self.tcp_config.runtime()
298    }
299
300    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
301        self.tcp_config.resolve(host, port).await
302    }
303}
304
305#[derive(Debug)]
306enum RustlsClientTransportInner<T> {
307    Tcp(T),
308    Tls(Box<TlsStream<T>>),
309}
310
311/// Transport for the rustls connector
312///
313/// This may represent either an encrypted tls connection or a plaintext
314/// connection, depending on the request schema
315#[derive(Debug)]
316pub struct RustlsClientTransport<T>(RustlsClientTransportInner<T>);
317impl<T> From<T> for RustlsClientTransport<T> {
318    fn from(value: T) -> Self {
319        Self(Tcp(value))
320    }
321}
322
323impl<T> From<TlsStream<T>> for RustlsClientTransport<T> {
324    fn from(value: TlsStream<T>) -> Self {
325        Self(Tls(Box::new(value)))
326    }
327}
328
329impl<C> AsyncRead for RustlsClientTransport<C>
330where
331    C: AsyncWrite + AsyncRead + Unpin,
332{
333    fn poll_read(
334        mut self: Pin<&mut Self>,
335        cx: &mut Context<'_>,
336        buf: &mut [u8],
337    ) -> Poll<Result<usize>> {
338        match &mut self.0 {
339            Tcp(c) => Pin::new(c).poll_read(cx, buf),
340            Tls(c) => Pin::new(c).poll_read(cx, buf),
341        }
342    }
343
344    fn poll_read_vectored(
345        mut self: Pin<&mut Self>,
346        cx: &mut Context<'_>,
347        bufs: &mut [std::io::IoSliceMut<'_>],
348    ) -> Poll<Result<usize>> {
349        match &mut self.0 {
350            Tcp(c) => Pin::new(c).poll_read_vectored(cx, bufs),
351            Tls(c) => Pin::new(c).poll_read_vectored(cx, bufs),
352        }
353    }
354}
355
356impl<C> AsyncWrite for RustlsClientTransport<C>
357where
358    C: AsyncRead + AsyncWrite + Unpin,
359{
360    fn poll_write(
361        mut self: Pin<&mut Self>,
362        cx: &mut Context<'_>,
363        buf: &[u8],
364    ) -> Poll<Result<usize>> {
365        match &mut self.0 {
366            Tcp(c) => Pin::new(c).poll_write(cx, buf),
367            Tls(c) => Pin::new(&mut *c).poll_write(cx, buf),
368        }
369    }
370
371    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
372        match &mut self.0 {
373            Tcp(c) => Pin::new(c).poll_flush(cx),
374            Tls(c) => Pin::new(&mut *c).poll_flush(cx),
375        }
376    }
377
378    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
379        match &mut self.0 {
380            Tcp(c) => Pin::new(c).poll_close(cx),
381            Tls(c) => Pin::new(&mut *c).poll_close(cx),
382        }
383    }
384
385    fn poll_write_vectored(
386        mut self: Pin<&mut Self>,
387        cx: &mut Context<'_>,
388        bufs: &[IoSlice<'_>],
389    ) -> Poll<Result<usize>> {
390        match &mut self.0 {
391            Tcp(c) => Pin::new(c).poll_write_vectored(cx, bufs),
392            Tls(c) => Pin::new(&mut *c).poll_write_vectored(cx, bufs),
393        }
394    }
395}
396
397impl<T: Transport> Transport for RustlsClientTransport<T> {
398    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
399        self.as_ref().peer_addr()
400    }
401
402    fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
403        self.tls_state()
404            .and_then(|conn| conn.alpn_protocol())
405            .map(std::borrow::Cow::Borrowed)
406    }
407}
408
409impl<T> AsRef<T> for RustlsClientTransport<T> {
410    fn as_ref(&self) -> &T {
411        match &self.0 {
412            Tcp(x) => x,
413            Tls(x) => x.get_ref().0,
414        }
415    }
416}
417
418impl<T> RustlsClientTransport<T> {
419    /// Retrieve the tls [`ClientConnection`] if this transport is Tls
420    pub fn tls_state_mut(&mut self) -> Option<&mut ClientConnection> {
421        match &mut self.0 {
422            Tls(x) => Some(x.get_mut().1),
423            _ => None,
424        }
425    }
426
427    /// Retrieve the tls [`ClientConnection`] if this transport is Tls
428    pub fn tls_state(&self) -> Option<&ClientConnection> {
429        match &self.0 {
430            Tls(x) => Some(x.get_ref().1),
431            _ => None,
432        }
433    }
434}