trillium_openssl/
client.rs1use crate::alpn::encode_alpn;
2use OpenSslClientTransportInner::{Tcp, Tls};
3use async_openssl::SslStream;
4use openssl::ssl::{SslConnector, SslMethod};
5use std::{
6 fmt::{self, Debug, Formatter},
7 io::{Error, IoSliceMut, Result},
8 net::SocketAddr,
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12};
13use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Destination, Transport, Url};
14
15#[derive(Clone, Debug)]
20pub struct OpenSslClientConfig(Arc<SslConnector>);
21
22impl OpenSslClientConfig {
23 pub fn as_inner(&self) -> &SslConnector {
25 &self.0
26 }
27}
28
29impl Default for OpenSslClientConfig {
30 fn default() -> Self {
31 let mut builder =
32 SslConnector::builder(SslMethod::tls_client()).expect("could not build SslConnector");
33 let alpn_wire = encode_alpn(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
34 builder
35 .set_alpn_protos(&alpn_wire)
36 .expect("could not set ALPN protocols");
37 Self(Arc::new(builder.build()))
38 }
39}
40
41impl From<SslConnector> for OpenSslClientConfig {
42 fn from(connector: SslConnector) -> Self {
43 Self(Arc::new(connector))
44 }
45}
46
47impl From<Arc<SslConnector>> for OpenSslClientConfig {
48 fn from(connector: Arc<SslConnector>) -> Self {
49 Self(connector)
50 }
51}
52
53#[derive(Clone, Default)]
55pub struct OpenSslConfig<Config> {
56 pub tcp_config: Config,
58
59 pub ssl_config: OpenSslClientConfig,
61}
62
63impl<C: Connector> OpenSslConfig<C> {
64 pub fn new(ssl_config: impl Into<OpenSslClientConfig>, tcp_config: C) -> Self {
66 Self {
67 tcp_config,
68 ssl_config: ssl_config.into(),
69 }
70 }
71
72 #[must_use]
74 pub fn with_tcp_config(mut self, config: C) -> Self {
75 self.tcp_config = config;
76 self
77 }
78}
79
80impl<Config: Debug> Debug for OpenSslConfig<Config> {
81 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
82 f.debug_struct("OpenSslConfig")
83 .field("tcp_config", &self.tcp_config)
84 .field("ssl_config", &self.ssl_config)
85 .finish()
86 }
87}
88
89impl<C> AsRef<C> for OpenSslConfig<C> {
90 fn as_ref(&self) -> &C {
91 &self.tcp_config
92 }
93}
94
95impl<C: Connector> Connector for OpenSslConfig<C> {
96 type Runtime = C::Runtime;
97 type Transport = OpenSslClientTransport<C::Transport>;
98 type Udp = C::Udp;
99
100 async fn connect(&self, url: &Url) -> Result<Self::Transport> {
101 self.connect_to(Destination::from_url(url)?).await
102 }
103
104 async fn connect_to(&self, destination: Destination) -> Result<Self::Transport> {
105 if !destination.secure() {
106 return self
107 .tcp_config
108 .connect_to(destination)
109 .await
110 .map(|t| OpenSslClientTransport(Tcp(t)));
111 }
112
113 let mut ssl = self
115 .ssl_config
116 .as_inner()
117 .configure()
118 .map_err(Error::other)?
119 .into_ssl(
120 destination
121 .host()
122 .ok_or_else(|| Error::other("missing domain"))?,
123 )
124 .map_err(Error::other)?;
125
126 if let Some(alpn) = destination.alpn() {
129 ssl.set_alpn_protos(&encode_alpn(alpn))
130 .map_err(Error::other)?;
131 }
132
133 let inner = self
134 .tcp_config
135 .connect_to(destination.with_secure(false))
136 .await?;
137 let mut stream = SslStream::new(ssl, inner).map_err(Error::other)?;
138 Pin::new(&mut stream)
139 .connect()
140 .await
141 .map_err(Error::other)?;
142 Ok(OpenSslClientTransport(Tls(Box::new(stream))))
143 }
144
145 fn runtime(&self) -> Self::Runtime {
146 self.tcp_config.runtime()
147 }
148
149 async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
150 self.tcp_config.resolve(host, port).await
151 }
152}
153
154#[derive(Debug)]
155enum OpenSslClientTransportInner<T: Unpin> {
156 Tcp(T),
157 Tls(Box<SslStream<T>>),
158}
159
160#[derive(Debug)]
165pub struct OpenSslClientTransport<T: Unpin>(OpenSslClientTransportInner<T>);
166
167impl<T: Unpin> OpenSslClientTransport<T> {
168 pub fn as_tls(&self) -> Option<&SslStream<T>> {
170 match &self.0 {
171 Tcp(_) => None,
172 Tls(s) => Some(s),
173 }
174 }
175}
176
177impl<T: Unpin> AsRef<T> for OpenSslClientTransport<T> {
178 fn as_ref(&self) -> &T {
179 match &self.0 {
180 Tcp(t) => t,
181 Tls(s) => s.get_ref(),
182 }
183 }
184}
185
186impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for OpenSslClientTransport<T> {
187 fn poll_read(
188 mut self: Pin<&mut Self>,
189 cx: &mut Context<'_>,
190 buf: &mut [u8],
191 ) -> Poll<Result<usize>> {
192 match &mut self.0 {
193 Tcp(t) => Pin::new(t).poll_read(cx, buf),
194 Tls(s) => Pin::new(&mut **s).poll_read(cx, buf),
195 }
196 }
197
198 fn poll_read_vectored(
199 mut self: Pin<&mut Self>,
200 cx: &mut Context<'_>,
201 bufs: &mut [IoSliceMut<'_>],
202 ) -> Poll<Result<usize>> {
203 match &mut self.0 {
204 Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
205 Tls(s) => Pin::new(&mut **s).poll_read_vectored(cx, bufs),
206 }
207 }
208}
209
210impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for OpenSslClientTransport<T> {
211 fn poll_write(
212 mut self: Pin<&mut Self>,
213 cx: &mut Context<'_>,
214 buf: &[u8],
215 ) -> Poll<Result<usize>> {
216 match &mut self.0 {
217 Tcp(t) => Pin::new(t).poll_write(cx, buf),
218 Tls(s) => Pin::new(&mut **s).poll_write(cx, buf),
219 }
220 }
221
222 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
223 match &mut self.0 {
224 Tcp(t) => Pin::new(t).poll_flush(cx),
225 Tls(s) => Pin::new(&mut **s).poll_flush(cx),
226 }
227 }
228
229 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
230 match &mut self.0 {
231 Tcp(t) => Pin::new(t).poll_close(cx),
232 Tls(s) => Pin::new(&mut **s).poll_close(cx),
233 }
234 }
235}
236
237impl<T: Transport> Transport for OpenSslClientTransport<T> {
238 fn peer_addr(&self) -> Result<Option<SocketAddr>> {
239 self.as_ref().peer_addr()
240 }
241
242 fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
243 match &self.0 {
244 Tcp(_) => None,
245 Tls(s) => s
246 .ssl()
247 .selected_alpn_protocol()
248 .map(std::borrow::Cow::Borrowed),
249 }
250 }
251}