Skip to main content

trillium_rustls/
client.rs

1use crate::crypto_provider;
2use RustlsClientTransportInner::{Tcp, Tls};
3use futures_rustls::{
4    TlsConnector,
5    client::TlsStream,
6    rustls::{
7        ClientConfig, ClientConnection, client::danger::ServerCertVerifier, crypto::CryptoProvider,
8        pki_types::ServerName,
9    },
10};
11use std::{
12    fmt::{self, Debug, Formatter},
13    io::{Error, ErrorKind, IoSlice, Result},
14    net::SocketAddr,
15    pin::Pin,
16    sync::Arc,
17    task::{Context, Poll},
18};
19use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Transport, Url};
20
21#[derive(Clone, Debug)]
22pub struct RustlsClientConfig(Arc<ClientConfig>);
23
24/// Client configuration for RustlsConnector
25#[derive(Clone, Default)]
26pub struct RustlsConfig<Config> {
27    /// configuration for rustls itself
28    pub rustls_config: RustlsClientConfig,
29
30    /// configuration for the inner transport
31    pub tcp_config: Config,
32}
33
34impl<C: Connector> RustlsConfig<C> {
35    /// build a new default rustls config with this tcp config
36    pub fn new(rustls_config: impl Into<RustlsClientConfig>, tcp_config: C) -> Self {
37        Self {
38            rustls_config: rustls_config.into(),
39            tcp_config,
40        }
41    }
42}
43
44impl Default for RustlsClientConfig {
45    fn default() -> Self {
46        Self(Arc::new(default_client_config()))
47    }
48}
49
/// Build the server certificate verifier provided by `rustls_platform_verifier`, using
/// `provider` for its cryptography.
///
/// # Panics
///
/// Panics if the platform verifier cannot be constructed with `provider`.
#[cfg(feature = "platform-verifier")]
fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
    Arc::new(
        rustls_platform_verifier::Verifier::new(provider)
            .expect("failed to initialize the platform certificate verifier"),
    )
}
54
55#[cfg(not(feature = "platform-verifier"))]
56fn verifier(provider: Arc<CryptoProvider>) -> Arc<dyn ServerCertVerifier> {
57    let roots = Arc::new(futures_rustls::rustls::RootCertStore::from_iter(
58        webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
59    ));
60    futures_rustls::rustls::client::WebPkiServerVerifier::builder_with_provider(roots, provider)
61        .build()
62        .unwrap()
63}
64
65fn default_client_config() -> ClientConfig {
66    let provider = crypto_provider();
67    let verifier = verifier(Arc::clone(&provider));
68
69    let mut config = ClientConfig::builder_with_provider(provider)
70        .with_safe_default_protocol_versions()
71        .expect("crypto provider did not support safe default protocol versions")
72        .dangerous()
73        .with_custom_certificate_verifier(verifier)
74        .with_no_client_auth();
75
76    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
77
78    config
79}
80
81impl From<ClientConfig> for RustlsClientConfig {
82    fn from(rustls_config: ClientConfig) -> Self {
83        Self(Arc::new(rustls_config))
84    }
85}
86
87impl From<Arc<ClientConfig>> for RustlsClientConfig {
88    fn from(rustls_config: Arc<ClientConfig>) -> Self {
89        Self(rustls_config)
90    }
91}
92
93impl<C: Connector> RustlsConfig<C> {
94    /// replace the tcp config
95    pub fn with_tcp_config(mut self, config: C) -> Self {
96        self.tcp_config = config;
97        self
98    }
99
100    /// Drop `h2` from the ALPN protocol list, forcing HTTP/1.1 over TLS.
101    ///
102    /// `RustlsConfig::default()` advertises `[h2, http/1.1]` so HTTP/2 is the preferred
103    /// protocol when the server supports it. Call this to opt out and pin the connection to
104    /// HTTP/1.1.
105    #[must_use]
106    pub fn without_http2(mut self) -> Self {
107        let config = Arc::make_mut(&mut self.rustls_config.0);
108        config.alpn_protocols.retain(|p| p != b"h2");
109        self
110    }
111}
112
113impl<Config: Debug> Debug for RustlsConfig<Config> {
114    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
115        f.debug_struct("RustlsConfig")
116            .field("rustls_config", &format_args!(".."))
117            .field("tcp_config", &self.tcp_config)
118            .finish()
119    }
120}
121
122impl<C: Connector> Connector for RustlsConfig<C> {
123    type Runtime = C::Runtime;
124    type Transport = RustlsClientTransport<C::Transport>;
125    type Udp = C::Udp;
126
127    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
128        match url.scheme() {
129            "https" => {
130                let mut http = url.clone();
131                http.set_scheme("http").ok();
132                http.set_port(url.port_or_known_default()).ok();
133
134                let connector: TlsConnector = Arc::clone(&self.rustls_config.0).into();
135                let domain = url
136                    .domain()
137                    .and_then(|dns_name| ServerName::try_from(dns_name.to_string()).ok())
138                    .ok_or_else(|| Error::other("missing domain"))?;
139
140                connector
141                    .connect(domain, self.tcp_config.connect(&http).await?)
142                    .await
143                    .map_err(|e| Error::other(e.to_string()))
144                    .map(Into::into)
145            }
146
147            "http" => self.tcp_config.connect(url).await.map(Into::into),
148
149            unknown => Err(Error::new(
150                ErrorKind::InvalidInput,
151                format!("unknown scheme {unknown}"),
152            )),
153        }
154    }
155
156    fn runtime(&self) -> Self::Runtime {
157        self.tcp_config.runtime()
158    }
159
160    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
161        self.tcp_config.resolve(host, port).await
162    }
163}
164
165#[derive(Debug)]
166enum RustlsClientTransportInner<T> {
167    Tcp(T),
168    Tls(Box<TlsStream<T>>),
169}
170
171/// Transport for the rustls connector
172///
173/// This may represent either an encrypted tls connection or a plaintext
174/// connection, depending on the request schema
175#[derive(Debug)]
176pub struct RustlsClientTransport<T>(RustlsClientTransportInner<T>);
177impl<T> From<T> for RustlsClientTransport<T> {
178    fn from(value: T) -> Self {
179        Self(Tcp(value))
180    }
181}
182
183impl<T> From<TlsStream<T>> for RustlsClientTransport<T> {
184    fn from(value: TlsStream<T>) -> Self {
185        Self(Tls(Box::new(value)))
186    }
187}
188
189impl<C> AsyncRead for RustlsClientTransport<C>
190where
191    C: AsyncWrite + AsyncRead + Unpin,
192{
193    fn poll_read(
194        mut self: Pin<&mut Self>,
195        cx: &mut Context<'_>,
196        buf: &mut [u8],
197    ) -> Poll<Result<usize>> {
198        match &mut self.0 {
199            Tcp(c) => Pin::new(c).poll_read(cx, buf),
200            Tls(c) => Pin::new(c).poll_read(cx, buf),
201        }
202    }
203
204    fn poll_read_vectored(
205        mut self: Pin<&mut Self>,
206        cx: &mut Context<'_>,
207        bufs: &mut [std::io::IoSliceMut<'_>],
208    ) -> Poll<Result<usize>> {
209        match &mut self.0 {
210            Tcp(c) => Pin::new(c).poll_read_vectored(cx, bufs),
211            Tls(c) => Pin::new(c).poll_read_vectored(cx, bufs),
212        }
213    }
214}
215
216impl<C> AsyncWrite for RustlsClientTransport<C>
217where
218    C: AsyncRead + AsyncWrite + Unpin,
219{
220    fn poll_write(
221        mut self: Pin<&mut Self>,
222        cx: &mut Context<'_>,
223        buf: &[u8],
224    ) -> Poll<Result<usize>> {
225        match &mut self.0 {
226            Tcp(c) => Pin::new(c).poll_write(cx, buf),
227            Tls(c) => Pin::new(&mut *c).poll_write(cx, buf),
228        }
229    }
230
231    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
232        match &mut self.0 {
233            Tcp(c) => Pin::new(c).poll_flush(cx),
234            Tls(c) => Pin::new(&mut *c).poll_flush(cx),
235        }
236    }
237
238    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
239        match &mut self.0 {
240            Tcp(c) => Pin::new(c).poll_close(cx),
241            Tls(c) => Pin::new(&mut *c).poll_close(cx),
242        }
243    }
244
245    fn poll_write_vectored(
246        mut self: Pin<&mut Self>,
247        cx: &mut Context<'_>,
248        bufs: &[IoSlice<'_>],
249    ) -> Poll<Result<usize>> {
250        match &mut self.0 {
251            Tcp(c) => Pin::new(c).poll_write_vectored(cx, bufs),
252            Tls(c) => Pin::new(&mut *c).poll_write_vectored(cx, bufs),
253        }
254    }
255}
256
257impl<T: Transport> Transport for RustlsClientTransport<T> {
258    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
259        self.as_ref().peer_addr()
260    }
261
262    fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
263        self.tls_state()
264            .and_then(|conn| conn.alpn_protocol())
265            .map(std::borrow::Cow::Borrowed)
266    }
267}
268
269impl<T> AsRef<T> for RustlsClientTransport<T> {
270    fn as_ref(&self) -> &T {
271        match &self.0 {
272            Tcp(x) => x,
273            Tls(x) => x.get_ref().0,
274        }
275    }
276}
277
278impl<T> RustlsClientTransport<T> {
279    /// Retrieve the tls [`ClientConnection`] if this transport is Tls
280    pub fn tls_state_mut(&mut self) -> Option<&mut ClientConnection> {
281        match &mut self.0 {
282            Tls(x) => Some(x.get_mut().1),
283            _ => None,
284        }
285    }
286
287    /// Retrieve the tls [`ClientConnection`] if this transport is Tls
288    pub fn tls_state(&self) -> Option<&ClientConnection> {
289        match &self.0 {
290            Tls(x) => Some(x.get_ref().1),
291            _ => None,
292        }
293    }
294}