Skip to main content

trillium_openssl/
client.rs

1use crate::alpn::encode_alpn;
2use OpenSslClientTransportInner::{Tcp, Tls};
3use async_openssl::SslStream;
4use openssl::ssl::{SslConnector, SslMethod};
5use std::{
6    fmt::{self, Debug, Formatter},
7    io::{Error, ErrorKind, IoSliceMut, Result},
8    net::SocketAddr,
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Transport, Url};
14
15/// A reference-counted [`SslConnector`] with a sensible default.
16///
17/// The `Default` impl uses [`SslMethod::tls_client`] and advertises `[h2, http/1.1]` via ALPN.
18/// To customize, build a [`SslConnector`] yourself and convert via `From`/`Into`.
19#[derive(Clone, Debug)]
20pub struct OpenSslClientConfig(Arc<SslConnector>);
21
22impl OpenSslClientConfig {
23    /// borrow the inner [`SslConnector`]
24    pub fn as_inner(&self) -> &SslConnector {
25        &self.0
26    }
27}
28
29impl Default for OpenSslClientConfig {
30    fn default() -> Self {
31        let mut builder =
32            SslConnector::builder(SslMethod::tls_client()).expect("could not build SslConnector");
33        let alpn_wire = encode_alpn(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
34        builder
35            .set_alpn_protos(&alpn_wire)
36            .expect("could not set ALPN protocols");
37        Self(Arc::new(builder.build()))
38    }
39}
40
41impl From<SslConnector> for OpenSslClientConfig {
42    fn from(connector: SslConnector) -> Self {
43        Self(Arc::new(connector))
44    }
45}
46
47impl From<Arc<SslConnector>> for OpenSslClientConfig {
48    fn from(connector: Arc<SslConnector>) -> Self {
49        Self(connector)
50    }
51}
52
53/// Configuration for the openssl client connector
54#[derive(Clone, Default)]
55pub struct OpenSslConfig<Config> {
56    /// configuration for the inner Connector (usually tcp)
57    pub tcp_config: Config,
58
59    /// the openssl client configuration
60    pub ssl_config: OpenSslClientConfig,
61}
62
63impl<C: Connector> OpenSslConfig<C> {
64    /// build a new `OpenSslConfig` from a ssl client configuration and a tcp config
65    pub fn new(ssl_config: impl Into<OpenSslClientConfig>, tcp_config: C) -> Self {
66        Self {
67            tcp_config,
68            ssl_config: ssl_config.into(),
69        }
70    }
71
72    /// replace the tcp config
73    #[must_use]
74    pub fn with_tcp_config(mut self, config: C) -> Self {
75        self.tcp_config = config;
76        self
77    }
78}
79
80impl<Config: Debug> Debug for OpenSslConfig<Config> {
81    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
82        f.debug_struct("OpenSslConfig")
83            .field("tcp_config", &self.tcp_config)
84            .field("ssl_config", &self.ssl_config)
85            .finish()
86    }
87}
88
89impl<C> AsRef<C> for OpenSslConfig<C> {
90    fn as_ref(&self) -> &C {
91        &self.tcp_config
92    }
93}
94
95impl<C: Connector> Connector for OpenSslConfig<C> {
96    type Runtime = C::Runtime;
97    type Transport = OpenSslClientTransport<C::Transport>;
98    type Udp = C::Udp;
99
100    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
101        match url.scheme() {
102            "https" => {
103                let mut http = url.clone();
104                http.set_scheme("http").ok();
105                http.set_port(url.port_or_known_default()).ok();
106
107                let domain = url.domain().ok_or_else(|| Error::other("missing domain"))?;
108                let ssl = self
109                    .ssl_config
110                    .as_inner()
111                    .configure()
112                    .map_err(Error::other)?
113                    .into_ssl(domain)
114                    .map_err(Error::other)?;
115
116                let inner = self.tcp_config.connect(&http).await?;
117                let mut stream = SslStream::new(ssl, inner).map_err(Error::other)?;
118                Pin::new(&mut stream)
119                    .connect()
120                    .await
121                    .map_err(Error::other)?;
122                Ok(OpenSslClientTransport(Tls(Box::new(stream))))
123            }
124
125            "http" => self
126                .tcp_config
127                .connect(url)
128                .await
129                .map(|t| OpenSslClientTransport(Tcp(t))),
130
131            unknown => Err(Error::new(
132                ErrorKind::InvalidInput,
133                format!("unknown scheme {unknown}"),
134            )),
135        }
136    }
137
138    fn runtime(&self) -> Self::Runtime {
139        self.tcp_config.runtime()
140    }
141
142    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
143        self.tcp_config.resolve(host, port).await
144    }
145}
146
147#[derive(Debug)]
148enum OpenSslClientTransportInner<T: Unpin> {
149    Tcp(T),
150    Tls(Box<SslStream<T>>),
151}
152
153/// Transport for the openssl connector
154///
155/// May represent either an encrypted tls connection or a plaintext connection,
156/// depending on the request scheme.
157#[derive(Debug)]
158pub struct OpenSslClientTransport<T: Unpin>(OpenSslClientTransportInner<T>);
159
160impl<T: Unpin> OpenSslClientTransport<T> {
161    /// Borrow the underlying [`SslStream`] if this transport is TLS.
162    pub fn as_tls(&self) -> Option<&SslStream<T>> {
163        match &self.0 {
164            Tcp(_) => None,
165            Tls(s) => Some(s),
166        }
167    }
168}
169
170impl<T: Unpin> AsRef<T> for OpenSslClientTransport<T> {
171    fn as_ref(&self) -> &T {
172        match &self.0 {
173            Tcp(t) => t,
174            Tls(s) => s.get_ref(),
175        }
176    }
177}
178
179impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for OpenSslClientTransport<T> {
180    fn poll_read(
181        mut self: Pin<&mut Self>,
182        cx: &mut Context<'_>,
183        buf: &mut [u8],
184    ) -> Poll<Result<usize>> {
185        match &mut self.0 {
186            Tcp(t) => Pin::new(t).poll_read(cx, buf),
187            Tls(s) => Pin::new(&mut **s).poll_read(cx, buf),
188        }
189    }
190
191    fn poll_read_vectored(
192        mut self: Pin<&mut Self>,
193        cx: &mut Context<'_>,
194        bufs: &mut [IoSliceMut<'_>],
195    ) -> Poll<Result<usize>> {
196        match &mut self.0 {
197            Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
198            Tls(s) => Pin::new(&mut **s).poll_read_vectored(cx, bufs),
199        }
200    }
201}
202
203impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for OpenSslClientTransport<T> {
204    fn poll_write(
205        mut self: Pin<&mut Self>,
206        cx: &mut Context<'_>,
207        buf: &[u8],
208    ) -> Poll<Result<usize>> {
209        match &mut self.0 {
210            Tcp(t) => Pin::new(t).poll_write(cx, buf),
211            Tls(s) => Pin::new(&mut **s).poll_write(cx, buf),
212        }
213    }
214
215    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
216        match &mut self.0 {
217            Tcp(t) => Pin::new(t).poll_flush(cx),
218            Tls(s) => Pin::new(&mut **s).poll_flush(cx),
219        }
220    }
221
222    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
223        match &mut self.0 {
224            Tcp(t) => Pin::new(t).poll_close(cx),
225            Tls(s) => Pin::new(&mut **s).poll_close(cx),
226        }
227    }
228}
229
230impl<T: Transport> Transport for OpenSslClientTransport<T> {
231    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
232        self.as_ref().peer_addr()
233    }
234
235    fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
236        match &self.0 {
237            Tcp(_) => None,
238            Tls(s) => s
239                .ssl()
240                .selected_alpn_protocol()
241                .map(std::borrow::Cow::Borrowed),
242        }
243    }
244}