trillium_openssl/
client.rs1use crate::alpn::encode_alpn;
2use OpenSslClientTransportInner::{Tcp, Tls};
3use async_openssl::SslStream;
4use openssl::ssl::{SslConnector, SslMethod};
5use std::{
6 fmt::{self, Debug, Formatter},
7 io::{Error, ErrorKind, IoSliceMut, Result},
8 net::SocketAddr,
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12};
13use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Transport, Url};
14
/// Shared OpenSSL client settings used to establish TLS connections.
///
/// Wraps an [`SslConnector`] in an [`Arc`], so clones share the same connector.
/// The [`Default`] value advertises `h2` and `http/1.1` via ALPN.
#[derive(Clone, Debug)]
pub struct OpenSslClientConfig(Arc<SslConnector>);
21
22impl OpenSslClientConfig {
23 pub fn as_inner(&self) -> &SslConnector {
25 &self.0
26 }
27}
28
29impl Default for OpenSslClientConfig {
30 fn default() -> Self {
31 let mut builder =
32 SslConnector::builder(SslMethod::tls_client()).expect("could not build SslConnector");
33 let alpn_wire = encode_alpn(&[b"h2".to_vec(), b"http/1.1".to_vec()]);
34 builder
35 .set_alpn_protos(&alpn_wire)
36 .expect("could not set ALPN protocols");
37 Self(Arc::new(builder.build()))
38 }
39}
40
41impl From<SslConnector> for OpenSslClientConfig {
42 fn from(connector: SslConnector) -> Self {
43 Self(Arc::new(connector))
44 }
45}
46
47impl From<Arc<SslConnector>> for OpenSslClientConfig {
48 fn from(connector: Arc<SslConnector>) -> Self {
49 Self(connector)
50 }
51}
52
/// Client [`Connector`] that opens connections with `tcp_config` and, for
/// `https` URLs, performs a TLS handshake over them using `ssl_config`.
///
/// `http` URLs are passed straight through to `tcp_config` without TLS.
#[derive(Clone, Default)]
pub struct OpenSslConfig<Config> {
    /// Connector used to open the underlying plaintext transport.
    pub tcp_config: Config,

    /// OpenSSL client configuration used for `https` connections.
    pub ssl_config: OpenSslClientConfig,
}
62
63impl<C: Connector> OpenSslConfig<C> {
64 pub fn new(ssl_config: impl Into<OpenSslClientConfig>, tcp_config: C) -> Self {
66 Self {
67 tcp_config,
68 ssl_config: ssl_config.into(),
69 }
70 }
71
72 #[must_use]
74 pub fn with_tcp_config(mut self, config: C) -> Self {
75 self.tcp_config = config;
76 self
77 }
78}
79
80impl<Config: Debug> Debug for OpenSslConfig<Config> {
81 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
82 f.debug_struct("OpenSslConfig")
83 .field("tcp_config", &self.tcp_config)
84 .field("ssl_config", &self.ssl_config)
85 .finish()
86 }
87}
88
89impl<C> AsRef<C> for OpenSslConfig<C> {
90 fn as_ref(&self) -> &C {
91 &self.tcp_config
92 }
93}
94
95impl<C: Connector> Connector for OpenSslConfig<C> {
96 type Runtime = C::Runtime;
97 type Transport = OpenSslClientTransport<C::Transport>;
98 type Udp = C::Udp;
99
100 async fn connect(&self, url: &Url) -> Result<Self::Transport> {
101 match url.scheme() {
102 "https" => {
103 let mut http = url.clone();
104 http.set_scheme("http").ok();
105 http.set_port(url.port_or_known_default()).ok();
106
107 let domain = url.domain().ok_or_else(|| Error::other("missing domain"))?;
108 let ssl = self
109 .ssl_config
110 .as_inner()
111 .configure()
112 .map_err(Error::other)?
113 .into_ssl(domain)
114 .map_err(Error::other)?;
115
116 let inner = self.tcp_config.connect(&http).await?;
117 let mut stream = SslStream::new(ssl, inner).map_err(Error::other)?;
118 Pin::new(&mut stream)
119 .connect()
120 .await
121 .map_err(Error::other)?;
122 Ok(OpenSslClientTransport(Tls(Box::new(stream))))
123 }
124
125 "http" => self
126 .tcp_config
127 .connect(url)
128 .await
129 .map(|t| OpenSslClientTransport(Tcp(t))),
130
131 unknown => Err(Error::new(
132 ErrorKind::InvalidInput,
133 format!("unknown scheme {unknown}"),
134 )),
135 }
136 }
137
138 fn runtime(&self) -> Self::Runtime {
139 self.tcp_config.runtime()
140 }
141
142 async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
143 self.tcp_config.resolve(host, port).await
144 }
145}
146
/// Internal state of an [`OpenSslClientTransport`].
#[derive(Debug)]
enum OpenSslClientTransportInner<T: Unpin> {
    /// Plaintext transport, produced for `http` URLs.
    Tcp(T),
    /// TLS stream layered over the transport, produced for `https` URLs.
    // NOTE(review): boxed, presumably to keep this enum small when
    // `SslStream<T>` is much larger than `T` — confirm.
    Tls(Box<SslStream<T>>),
}
152
/// Transport produced by [`OpenSslConfig`]'s [`Connector`] implementation.
///
/// Holds either a plaintext `T` (for `http`) or a TLS stream over `T` (for
/// `https`). Reads, writes and [`Transport`] queries are delegated to
/// whichever is present.
#[derive(Debug)]
pub struct OpenSslClientTransport<T: Unpin>(OpenSslClientTransportInner<T>);
159
160impl<T: Unpin> OpenSslClientTransport<T> {
161 pub fn as_tls(&self) -> Option<&SslStream<T>> {
163 match &self.0 {
164 Tcp(_) => None,
165 Tls(s) => Some(s),
166 }
167 }
168}
169
170impl<T: Unpin> AsRef<T> for OpenSslClientTransport<T> {
171 fn as_ref(&self) -> &T {
172 match &self.0 {
173 Tcp(t) => t,
174 Tls(s) => s.get_ref(),
175 }
176 }
177}
178
179impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for OpenSslClientTransport<T> {
180 fn poll_read(
181 mut self: Pin<&mut Self>,
182 cx: &mut Context<'_>,
183 buf: &mut [u8],
184 ) -> Poll<Result<usize>> {
185 match &mut self.0 {
186 Tcp(t) => Pin::new(t).poll_read(cx, buf),
187 Tls(s) => Pin::new(&mut **s).poll_read(cx, buf),
188 }
189 }
190
191 fn poll_read_vectored(
192 mut self: Pin<&mut Self>,
193 cx: &mut Context<'_>,
194 bufs: &mut [IoSliceMut<'_>],
195 ) -> Poll<Result<usize>> {
196 match &mut self.0 {
197 Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
198 Tls(s) => Pin::new(&mut **s).poll_read_vectored(cx, bufs),
199 }
200 }
201}
202
203impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for OpenSslClientTransport<T> {
204 fn poll_write(
205 mut self: Pin<&mut Self>,
206 cx: &mut Context<'_>,
207 buf: &[u8],
208 ) -> Poll<Result<usize>> {
209 match &mut self.0 {
210 Tcp(t) => Pin::new(t).poll_write(cx, buf),
211 Tls(s) => Pin::new(&mut **s).poll_write(cx, buf),
212 }
213 }
214
215 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
216 match &mut self.0 {
217 Tcp(t) => Pin::new(t).poll_flush(cx),
218 Tls(s) => Pin::new(&mut **s).poll_flush(cx),
219 }
220 }
221
222 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
223 match &mut self.0 {
224 Tcp(t) => Pin::new(t).poll_close(cx),
225 Tls(s) => Pin::new(&mut **s).poll_close(cx),
226 }
227 }
228}
229
230impl<T: Transport> Transport for OpenSslClientTransport<T> {
231 fn peer_addr(&self) -> Result<Option<SocketAddr>> {
232 self.as_ref().peer_addr()
233 }
234
235 fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
236 match &self.0 {
237 Tcp(_) => None,
238 Tls(s) => s
239 .ssl()
240 .selected_alpn_protocol()
241 .map(std::borrow::Cow::Borrowed),
242 }
243 }
244}