Skip to main content

trillium_openssl/
client.rs

1use crate::alpn::encode_alpn;
2use OpenSslClientTransportInner::{Tcp, Tls};
3use async_openssl::SslStream;
4use openssl::ssl::{SslConnector, SslMethod};
5use std::{
6    fmt::{self, Debug, Formatter},
7    io::{Error, IoSliceMut, Result},
8    net::SocketAddr,
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Destination, Transport, Url};
14
15/// A reference-counted [`SslConnector`] with a sensible default.
16///
17/// The `Default` impl uses [`SslMethod::tls_client`] and advertises `[h2, http/1.1]` via ALPN.
18/// To customize, build a [`SslConnector`] yourself and convert via `From`/`Into`.
19#[derive(Clone, Debug)]
20pub struct OpenSslClientConfig(Arc<SslConnector>);
21
22impl OpenSslClientConfig {
23    /// borrow the inner [`SslConnector`]
24    pub fn as_inner(&self) -> &SslConnector {
25        &self.0
26    }
27}
28
29impl Default for OpenSslClientConfig {
30    fn default() -> Self {
31        let mut builder =
32            SslConnector::builder(SslMethod::tls_client()).expect("could not build SslConnector");
33        let alpn_wire = encode_alpn(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
34        builder
35            .set_alpn_protos(&alpn_wire)
36            .expect("could not set ALPN protocols");
37        Self(Arc::new(builder.build()))
38    }
39}
40
41impl From<SslConnector> for OpenSslClientConfig {
42    fn from(connector: SslConnector) -> Self {
43        Self(Arc::new(connector))
44    }
45}
46
47impl From<Arc<SslConnector>> for OpenSslClientConfig {
48    fn from(connector: Arc<SslConnector>) -> Self {
49        Self(connector)
50    }
51}
52
53/// Configuration for the openssl client connector
54#[derive(Clone, Default)]
55pub struct OpenSslConfig<Config> {
56    /// configuration for the inner Connector (usually tcp)
57    pub tcp_config: Config,
58
59    /// the openssl client configuration
60    pub ssl_config: OpenSslClientConfig,
61}
62
63impl<C: Connector> OpenSslConfig<C> {
64    /// build a new `OpenSslConfig` from a ssl client configuration and a tcp config
65    pub fn new(ssl_config: impl Into<OpenSslClientConfig>, tcp_config: C) -> Self {
66        Self {
67            tcp_config,
68            ssl_config: ssl_config.into(),
69        }
70    }
71
72    /// replace the tcp config
73    #[must_use]
74    pub fn with_tcp_config(mut self, config: C) -> Self {
75        self.tcp_config = config;
76        self
77    }
78}
79
80impl<Config: Debug> Debug for OpenSslConfig<Config> {
81    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
82        f.debug_struct("OpenSslConfig")
83            .field("tcp_config", &self.tcp_config)
84            .field("ssl_config", &self.ssl_config)
85            .finish()
86    }
87}
88
89impl<C> AsRef<C> for OpenSslConfig<C> {
90    fn as_ref(&self) -> &C {
91        &self.tcp_config
92    }
93}
94
95impl<C: Connector> Connector for OpenSslConfig<C> {
96    type Runtime = C::Runtime;
97    type Transport = OpenSslClientTransport<C::Transport>;
98    type Udp = C::Udp;
99
100    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
101        self.connect_to(Destination::from_url(url)?).await
102    }
103
104    async fn connect_to(&self, destination: Destination) -> Result<Self::Transport> {
105        if !destination.secure() {
106            return self
107                .tcp_config
108                .connect_to(destination)
109                .await
110                .map(|t| OpenSslClientTransport(Tcp(t)));
111        }
112
113        // OpenSSL's `into_ssl` requires a domain for SNI; a bare-IP destination has none.
114        let mut ssl = self
115            .ssl_config
116            .as_inner()
117            .configure()
118            .map_err(Error::other)?
119            .into_ssl(
120                destination
121                    .host()
122                    .ok_or_else(|| Error::other("missing domain"))?,
123            )
124            .map_err(Error::other)?;
125
126        // A per-connection ALPN override replaces the connector's default; absent one, the
127        // configured default stays in place.
128        if let Some(alpn) = destination.alpn() {
129            ssl.set_alpn_protos(&encode_alpn(alpn))
130                .map_err(Error::other)?;
131        }
132
133        let inner = self
134            .tcp_config
135            .connect_to(destination.with_secure(false))
136            .await?;
137        let mut stream = SslStream::new(ssl, inner).map_err(Error::other)?;
138        Pin::new(&mut stream)
139            .connect()
140            .await
141            .map_err(Error::other)?;
142        Ok(OpenSslClientTransport(Tls(Box::new(stream))))
143    }
144
145    fn runtime(&self) -> Self::Runtime {
146        self.tcp_config.runtime()
147    }
148
149    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
150        self.tcp_config.resolve(host, port).await
151    }
152}
153
154#[derive(Debug)]
155enum OpenSslClientTransportInner<T: Unpin> {
156    Tcp(T),
157    Tls(Box<SslStream<T>>),
158}
159
160/// Transport for the openssl connector
161///
162/// May represent either an encrypted tls connection or a plaintext connection,
163/// depending on the request scheme.
164#[derive(Debug)]
165pub struct OpenSslClientTransport<T: Unpin>(OpenSslClientTransportInner<T>);
166
167impl<T: Unpin> OpenSslClientTransport<T> {
168    /// Borrow the underlying [`SslStream`] if this transport is TLS.
169    pub fn as_tls(&self) -> Option<&SslStream<T>> {
170        match &self.0 {
171            Tcp(_) => None,
172            Tls(s) => Some(s),
173        }
174    }
175}
176
177impl<T: Unpin> AsRef<T> for OpenSslClientTransport<T> {
178    fn as_ref(&self) -> &T {
179        match &self.0 {
180            Tcp(t) => t,
181            Tls(s) => s.get_ref(),
182        }
183    }
184}
185
186impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for OpenSslClientTransport<T> {
187    fn poll_read(
188        mut self: Pin<&mut Self>,
189        cx: &mut Context<'_>,
190        buf: &mut [u8],
191    ) -> Poll<Result<usize>> {
192        match &mut self.0 {
193            Tcp(t) => Pin::new(t).poll_read(cx, buf),
194            Tls(s) => Pin::new(&mut **s).poll_read(cx, buf),
195        }
196    }
197
198    fn poll_read_vectored(
199        mut self: Pin<&mut Self>,
200        cx: &mut Context<'_>,
201        bufs: &mut [IoSliceMut<'_>],
202    ) -> Poll<Result<usize>> {
203        match &mut self.0 {
204            Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
205            Tls(s) => Pin::new(&mut **s).poll_read_vectored(cx, bufs),
206        }
207    }
208}
209
210impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for OpenSslClientTransport<T> {
211    fn poll_write(
212        mut self: Pin<&mut Self>,
213        cx: &mut Context<'_>,
214        buf: &[u8],
215    ) -> Poll<Result<usize>> {
216        match &mut self.0 {
217            Tcp(t) => Pin::new(t).poll_write(cx, buf),
218            Tls(s) => Pin::new(&mut **s).poll_write(cx, buf),
219        }
220    }
221
222    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
223        match &mut self.0 {
224            Tcp(t) => Pin::new(t).poll_flush(cx),
225            Tls(s) => Pin::new(&mut **s).poll_flush(cx),
226        }
227    }
228
229    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
230        match &mut self.0 {
231            Tcp(t) => Pin::new(t).poll_close(cx),
232            Tls(s) => Pin::new(&mut **s).poll_close(cx),
233        }
234    }
235}
236
237impl<T: Transport> Transport for OpenSslClientTransport<T> {
238    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
239        self.as_ref().peer_addr()
240    }
241
242    fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
243        match &self.0 {
244            Tcp(_) => None,
245            Tls(s) => s
246                .ssl()
247                .selected_alpn_protocol()
248                .map(std::borrow::Cow::Borrowed),
249        }
250    }
251}