// trillium_native_tls/client.rs

1use async_native_tls::{TlsConnector, TlsStream};
2use std::{
3    fmt::{Debug, Formatter},
4    io::{Error, ErrorKind, IoSlice, IoSliceMut, Result},
5    net::SocketAddr,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9};
10use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Transport, Url};
11
12/// Configuration for the native tls client connector
13#[derive(Clone)]
14pub struct NativeTlsConfig<Config> {
15    /// configuration for the inner Connector (usually tcp)
16    pub tcp_config: Config,
17
18    /// native tls configuration
19    ///
20    /// Although async_native_tls calls this
21    /// a TlsConnector, it's actually a builder ¯\_(ツ)_/¯
22    pub tls_connector: Arc<TlsConnector>,
23}
24
25impl<C: Connector> NativeTlsConfig<C> {
26    /// replace the tcp config
27    pub fn with_tcp_config(mut self, config: C) -> Self {
28        self.tcp_config = config;
29        self
30    }
31}
32
33impl<C: Connector> From<C> for NativeTlsConfig<C> {
34    fn from(tcp_config: C) -> Self {
35        Self {
36            tcp_config,
37            tls_connector: Arc::new(TlsConnector::default()),
38        }
39    }
40}
41
42impl<Config: Debug> Debug for NativeTlsConfig<Config> {
43    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("NativeTlsConfig")
45            .field("tcp_config", &self.tcp_config)
46            .field("tls_connector", &format_args!(".."))
47            .finish()
48    }
49}
50
51impl<Config: Default> Default for NativeTlsConfig<Config> {
52    fn default() -> Self {
53        Self {
54            tcp_config: Config::default(),
55            tls_connector: Arc::new(TlsConnector::default()),
56        }
57    }
58}
59
60impl<Config> AsRef<Config> for NativeTlsConfig<Config> {
61    fn as_ref(&self) -> &Config {
62        &self.tcp_config
63    }
64}
65
66impl<T: Connector> Connector for NativeTlsConfig<T> {
67    type Runtime = T::Runtime;
68    type Transport = NativeTlsClientTransport<T::Transport>;
69    type Udp = T::Udp;
70
71    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
72        match url.scheme() {
73            "https" => {
74                let mut http = url.clone();
75                http.set_scheme("http").ok();
76                http.set_port(url.port_or_known_default()).ok();
77                let inner_stream = self.tcp_config.connect(&http).await?;
78
79                self.tls_connector
80                    .connect(url, inner_stream)
81                    .await
82                    .map_err(|e| Error::other(e.to_string()))
83                    .map(NativeTlsClientTransport::from)
84            }
85
86            "http" => self
87                .tcp_config
88                .connect(url)
89                .await
90                .map(NativeTlsClientTransport::from),
91
92            unknown => Err(Error::new(
93                ErrorKind::InvalidInput,
94                format!("unknown scheme {unknown}"),
95            )),
96        }
97    }
98
99    fn runtime(&self) -> Self::Runtime {
100        self.tcp_config.runtime()
101    }
102
103    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
104        self.tcp_config.resolve(host, port).await
105    }
106}
107
108/// Client [`Transport`] for the native tls connector
109///
110/// This may represent either an encrypted tls connection or a plaintext
111/// connection
112
113#[derive(Debug)]
114pub struct NativeTlsClientTransport<T>(NativeTlsClientTransportInner<T>);
115
116impl<T: AsyncWrite + AsyncRead + Unpin> NativeTlsClientTransport<T> {
117    /// Borrow the TlsStream, if this connection is tls.
118    ///
119    /// Returns None otherwise
120    pub fn as_tls(&self) -> Option<&TlsStream<T>> {
121        match &self.0 {
122            Tcp(_) => None,
123            Tls(tls) => Some(tls),
124        }
125    }
126}
127
128impl<T> From<T> for NativeTlsClientTransport<T> {
129    fn from(value: T) -> Self {
130        Self(Tcp(value))
131    }
132}
133
134impl<T> From<TlsStream<T>> for NativeTlsClientTransport<T> {
135    fn from(value: TlsStream<T>) -> Self {
136        Self(Tls(value))
137    }
138}
139
140impl<T: Transport> AsRef<T> for NativeTlsClientTransport<T> {
141    fn as_ref(&self) -> &T {
142        match &self.0 {
143            Tcp(transport) => transport,
144            Tls(tls_stream) => tls_stream.get_ref(),
145        }
146    }
147}
148
149#[derive(Debug)]
150enum NativeTlsClientTransportInner<T> {
151    Tcp(T),
152    Tls(TlsStream<T>),
153}
154use NativeTlsClientTransportInner::{Tcp, Tls};
155
156impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsClientTransport<T> {
157    fn poll_read(
158        mut self: Pin<&mut Self>,
159        cx: &mut Context<'_>,
160        buf: &mut [u8],
161    ) -> Poll<Result<usize>> {
162        match &mut self.0 {
163            Tcp(t) => Pin::new(t).poll_read(cx, buf),
164            Tls(t) => Pin::new(t).poll_read(cx, buf),
165        }
166    }
167
168    fn poll_read_vectored(
169        mut self: Pin<&mut Self>,
170        cx: &mut Context<'_>,
171        bufs: &mut [IoSliceMut<'_>],
172    ) -> Poll<Result<usize>> {
173        match &mut self.0 {
174            Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
175            Tls(t) => Pin::new(t).poll_read_vectored(cx, bufs),
176        }
177    }
178}
179
180impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsClientTransport<T> {
181    fn poll_write(
182        mut self: Pin<&mut Self>,
183        cx: &mut Context<'_>,
184        buf: &[u8],
185    ) -> Poll<Result<usize>> {
186        match &mut self.0 {
187            Tcp(t) => Pin::new(t).poll_write(cx, buf),
188            Tls(t) => Pin::new(t).poll_write(cx, buf),
189        }
190    }
191
192    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
193        match &mut self.0 {
194            Tcp(t) => Pin::new(t).poll_flush(cx),
195            Tls(t) => Pin::new(t).poll_flush(cx),
196        }
197    }
198
199    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
200        match &mut self.0 {
201            Tcp(t) => Pin::new(t).poll_close(cx),
202            Tls(t) => Pin::new(t).poll_close(cx),
203        }
204    }
205
206    fn poll_write_vectored(
207        mut self: Pin<&mut Self>,
208        cx: &mut Context<'_>,
209        bufs: &[IoSlice<'_>],
210    ) -> Poll<Result<usize>> {
211        match &mut self.0 {
212            Tcp(t) => Pin::new(t).poll_write_vectored(cx, bufs),
213            Tls(t) => Pin::new(t).poll_write_vectored(cx, bufs),
214        }
215    }
216}
217
218impl<T: Transport> Transport for NativeTlsClientTransport<T> {
219    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
220        self.as_ref().peer_addr()
221    }
222}