Skip to main content

trillium_native_tls/
client.rs

1use async_native_tls::{TlsConnector, TlsStream};
2use std::{
3    fmt::{Debug, Formatter},
4    io::{Error, IoSlice, IoSliceMut, Result},
5    net::SocketAddr,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9};
10use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Destination, Transport, Url};
11
/// Configuration for the native tls client connector
///
/// Pairs an inner connector configuration (used to open the underlying stream) with a shared
/// `async_native_tls` connector used to run the tls handshake on top of it.
#[derive(Clone)]
pub struct NativeTlsConfig<Config> {
    /// configuration for the inner Connector (usually tcp)
    pub tcp_config: Config,

    /// native tls configuration
    ///
    /// Although async_native_tls calls this
    /// a TlsConnector, it's actually a builder ¯\_(ツ)_/¯
    ///
    /// Held behind an [`Arc`] so that cloning the config shares one connector rather than
    /// duplicating it.
    pub tls_connector: Arc<TlsConnector>,
}
24
25impl<C: Connector> NativeTlsConfig<C> {
26    /// replace the tcp config
27    pub fn with_tcp_config(mut self, config: C) -> Self {
28        self.tcp_config = config;
29        self
30    }
31}
32
33impl<C: Connector> From<C> for NativeTlsConfig<C> {
34    fn from(tcp_config: C) -> Self {
35        Self {
36            tcp_config,
37            tls_connector: Arc::new(TlsConnector::default()),
38        }
39    }
40}
41
42impl<Config: Debug> Debug for NativeTlsConfig<Config> {
43    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("NativeTlsConfig")
45            .field("tcp_config", &self.tcp_config)
46            .field("tls_connector", &format_args!(".."))
47            .finish()
48    }
49}
50
51impl<Config: Default> Default for NativeTlsConfig<Config> {
52    fn default() -> Self {
53        Self {
54            tcp_config: Config::default(),
55            tls_connector: Arc::new(TlsConnector::default()),
56        }
57    }
58}
59
60impl<Config> AsRef<Config> for NativeTlsConfig<Config> {
61    fn as_ref(&self) -> &Config {
62        &self.tcp_config
63    }
64}
65
66impl<T: Connector> Connector for NativeTlsConfig<T> {
67    type Runtime = T::Runtime;
68    type Transport = NativeTlsClientTransport<T::Transport>;
69    type Udp = T::Udp;
70
71    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
72        self.connect_to(Destination::from_url(url)?).await
73    }
74
75    async fn connect_to(&self, destination: Destination) -> Result<Self::Transport> {
76        if !destination.secure() {
77            return self
78                .tcp_config
79                .connect_to(destination)
80                .await
81                .map(NativeTlsClientTransport::from);
82        }
83
84        // The server name comes from the destination, never the dialed address; capture a domain
85        // before the dial moves the destination, deferring the host-less (bare-IP) case to the
86        // address actually connected to.
87        let domain = destination.host().map(str::to_owned);
88        let inner_stream = self
89            .tcp_config
90            .connect_to(destination.with_secure(false))
91            .await?;
92        let host = match domain {
93            Some(domain) => domain,
94            None => inner_stream
95                .peer_addr()?
96                .ok_or_else(|| Error::other("no peer address for bare-ip destination"))?
97                .ip()
98                .to_string(),
99        };
100
101        // `destination.alpn()` is intentionally ignored: `async-native-tls` does not yet expose
102        // per-connection ALPN configuration. Honoring the override awaits the upstream ALPN support
103        // PR; until then this connector negotiates no ALPN (effectively h1-only).
104        self.tls_connector
105            .connect(host.as_str(), inner_stream)
106            .await
107            .map_err(|e| Error::other(e.to_string()))
108            .map(NativeTlsClientTransport::from)
109    }
110
111    fn runtime(&self) -> Self::Runtime {
112        self.tcp_config.runtime()
113    }
114
115    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
116        self.tcp_config.resolve(host, port).await
117    }
118}
119
120/// Client [`Transport`] for the native tls connector
121///
122/// This may represent either an encrypted tls connection or a plaintext
123/// connection
124
125#[derive(Debug)]
126pub struct NativeTlsClientTransport<T>(NativeTlsClientTransportInner<T>);
127
128impl<T: AsyncWrite + AsyncRead + Unpin> NativeTlsClientTransport<T> {
129    /// Borrow the TlsStream, if this connection is tls.
130    ///
131    /// Returns None otherwise
132    pub fn as_tls(&self) -> Option<&TlsStream<T>> {
133        match &self.0 {
134            Tcp(_) => None,
135            Tls(tls) => Some(tls),
136        }
137    }
138}
139
140impl<T> From<T> for NativeTlsClientTransport<T> {
141    fn from(value: T) -> Self {
142        Self(Tcp(value))
143    }
144}
145
146impl<T> From<TlsStream<T>> for NativeTlsClientTransport<T> {
147    fn from(value: TlsStream<T>) -> Self {
148        Self(Tls(value))
149    }
150}
151
152impl<T: Transport> AsRef<T> for NativeTlsClientTransport<T> {
153    fn as_ref(&self) -> &T {
154        match &self.0 {
155            Tcp(transport) => transport,
156            Tls(tls_stream) => tls_stream.get_ref(),
157        }
158    }
159}
160
/// The two forms a [`NativeTlsClientTransport`] can take.
#[derive(Debug)]
enum NativeTlsClientTransportInner<T> {
    /// The inner transport used without encryption (named `Tcp` because the inner connector is
    /// usually tcp).
    Tcp(T),
    /// A tls session layered over the inner transport.
    Tls(TlsStream<T>),
}
166use NativeTlsClientTransportInner::{Tcp, Tls};
167
168impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsClientTransport<T> {
169    fn poll_read(
170        mut self: Pin<&mut Self>,
171        cx: &mut Context<'_>,
172        buf: &mut [u8],
173    ) -> Poll<Result<usize>> {
174        match &mut self.0 {
175            Tcp(t) => Pin::new(t).poll_read(cx, buf),
176            Tls(t) => Pin::new(t).poll_read(cx, buf),
177        }
178    }
179
180    fn poll_read_vectored(
181        mut self: Pin<&mut Self>,
182        cx: &mut Context<'_>,
183        bufs: &mut [IoSliceMut<'_>],
184    ) -> Poll<Result<usize>> {
185        match &mut self.0 {
186            Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
187            Tls(t) => Pin::new(t).poll_read_vectored(cx, bufs),
188        }
189    }
190}
191
192impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsClientTransport<T> {
193    fn poll_write(
194        mut self: Pin<&mut Self>,
195        cx: &mut Context<'_>,
196        buf: &[u8],
197    ) -> Poll<Result<usize>> {
198        match &mut self.0 {
199            Tcp(t) => Pin::new(t).poll_write(cx, buf),
200            Tls(t) => Pin::new(t).poll_write(cx, buf),
201        }
202    }
203
204    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
205        match &mut self.0 {
206            Tcp(t) => Pin::new(t).poll_flush(cx),
207            Tls(t) => Pin::new(t).poll_flush(cx),
208        }
209    }
210
211    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
212        match &mut self.0 {
213            Tcp(t) => Pin::new(t).poll_close(cx),
214            Tls(t) => Pin::new(t).poll_close(cx),
215        }
216    }
217
218    fn poll_write_vectored(
219        mut self: Pin<&mut Self>,
220        cx: &mut Context<'_>,
221        bufs: &[IoSlice<'_>],
222    ) -> Poll<Result<usize>> {
223        match &mut self.0 {
224            Tcp(t) => Pin::new(t).poll_write_vectored(cx, bufs),
225            Tls(t) => Pin::new(t).poll_write_vectored(cx, bufs),
226        }
227    }
228}
229
230impl<T: Transport> Transport for NativeTlsClientTransport<T> {
231    fn peer_addr(&self) -> Result<Option<SocketAddr>> {
232        self.as_ref().peer_addr()
233    }
234
235    // `negotiated_alpn` is left at the trait default (`None`). `native-tls` exposes the negotiated
236    // ALPN protocol via `TlsStream::negotiated_alpn` (gated on the `alpn` feature), but the
237    // `async-native-tls` 0.6 wrapper keeps its inner `native_tls::TlsStream` private and offers no
238    // accessor for it, so we cannot reach the value from here. Lift this once async-native-tls
239    // exposes the read side.
240}