trillium_native_tls/
client.rs1use async_native_tls::{TlsConnector, TlsStream};
2use std::{
3 fmt::{Debug, Formatter},
4 io::{Error, IoSlice, IoSliceMut, Result},
5 net::SocketAddr,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9};
10use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Destination, Transport, Url};
11
12#[derive(Clone)]
14pub struct NativeTlsConfig<Config> {
15 pub tcp_config: Config,
17
18 pub tls_connector: Arc<TlsConnector>,
23}
24
25impl<C: Connector> NativeTlsConfig<C> {
26 pub fn with_tcp_config(mut self, config: C) -> Self {
28 self.tcp_config = config;
29 self
30 }
31}
32
33impl<C: Connector> From<C> for NativeTlsConfig<C> {
34 fn from(tcp_config: C) -> Self {
35 Self {
36 tcp_config,
37 tls_connector: Arc::new(TlsConnector::default()),
38 }
39 }
40}
41
42impl<Config: Debug> Debug for NativeTlsConfig<Config> {
43 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("NativeTlsConfig")
45 .field("tcp_config", &self.tcp_config)
46 .field("tls_connector", &format_args!(".."))
47 .finish()
48 }
49}
50
51impl<Config: Default> Default for NativeTlsConfig<Config> {
52 fn default() -> Self {
53 Self {
54 tcp_config: Config::default(),
55 tls_connector: Arc::new(TlsConnector::default()),
56 }
57 }
58}
59
60impl<Config> AsRef<Config> for NativeTlsConfig<Config> {
61 fn as_ref(&self) -> &Config {
62 &self.tcp_config
63 }
64}
65
66impl<T: Connector> Connector for NativeTlsConfig<T> {
67 type Runtime = T::Runtime;
68 type Transport = NativeTlsClientTransport<T::Transport>;
69 type Udp = T::Udp;
70
71 async fn connect(&self, url: &Url) -> Result<Self::Transport> {
72 self.connect_to(Destination::from_url(url)?).await
73 }
74
75 async fn connect_to(&self, destination: Destination) -> Result<Self::Transport> {
76 if !destination.secure() {
77 return self
78 .tcp_config
79 .connect_to(destination)
80 .await
81 .map(NativeTlsClientTransport::from);
82 }
83
84 let domain = destination.host().map(str::to_owned);
88 let inner_stream = self
89 .tcp_config
90 .connect_to(destination.with_secure(false))
91 .await?;
92 let host = match domain {
93 Some(domain) => domain,
94 None => inner_stream
95 .peer_addr()?
96 .ok_or_else(|| Error::other("no peer address for bare-ip destination"))?
97 .ip()
98 .to_string(),
99 };
100
101 self.tls_connector
105 .connect(host.as_str(), inner_stream)
106 .await
107 .map_err(|e| Error::other(e.to_string()))
108 .map(NativeTlsClientTransport::from)
109 }
110
111 fn runtime(&self) -> Self::Runtime {
112 self.tcp_config.runtime()
113 }
114
115 async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
116 self.tcp_config.resolve(host, port).await
117 }
118}
119
120#[derive(Debug)]
126pub struct NativeTlsClientTransport<T>(NativeTlsClientTransportInner<T>);
127
128impl<T: AsyncWrite + AsyncRead + Unpin> NativeTlsClientTransport<T> {
129 pub fn as_tls(&self) -> Option<&TlsStream<T>> {
133 match &self.0 {
134 Tcp(_) => None,
135 Tls(tls) => Some(tls),
136 }
137 }
138}
139
140impl<T> From<T> for NativeTlsClientTransport<T> {
141 fn from(value: T) -> Self {
142 Self(Tcp(value))
143 }
144}
145
146impl<T> From<TlsStream<T>> for NativeTlsClientTransport<T> {
147 fn from(value: TlsStream<T>) -> Self {
148 Self(Tls(value))
149 }
150}
151
152impl<T: Transport> AsRef<T> for NativeTlsClientTransport<T> {
153 fn as_ref(&self) -> &T {
154 match &self.0 {
155 Tcp(transport) => transport,
156 Tls(tls_stream) => tls_stream.get_ref(),
157 }
158 }
159}
160
161#[derive(Debug)]
162enum NativeTlsClientTransportInner<T> {
163 Tcp(T),
164 Tls(TlsStream<T>),
165}
166use NativeTlsClientTransportInner::{Tcp, Tls};
167
168impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsClientTransport<T> {
169 fn poll_read(
170 mut self: Pin<&mut Self>,
171 cx: &mut Context<'_>,
172 buf: &mut [u8],
173 ) -> Poll<Result<usize>> {
174 match &mut self.0 {
175 Tcp(t) => Pin::new(t).poll_read(cx, buf),
176 Tls(t) => Pin::new(t).poll_read(cx, buf),
177 }
178 }
179
180 fn poll_read_vectored(
181 mut self: Pin<&mut Self>,
182 cx: &mut Context<'_>,
183 bufs: &mut [IoSliceMut<'_>],
184 ) -> Poll<Result<usize>> {
185 match &mut self.0 {
186 Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
187 Tls(t) => Pin::new(t).poll_read_vectored(cx, bufs),
188 }
189 }
190}
191
192impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsClientTransport<T> {
193 fn poll_write(
194 mut self: Pin<&mut Self>,
195 cx: &mut Context<'_>,
196 buf: &[u8],
197 ) -> Poll<Result<usize>> {
198 match &mut self.0 {
199 Tcp(t) => Pin::new(t).poll_write(cx, buf),
200 Tls(t) => Pin::new(t).poll_write(cx, buf),
201 }
202 }
203
204 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
205 match &mut self.0 {
206 Tcp(t) => Pin::new(t).poll_flush(cx),
207 Tls(t) => Pin::new(t).poll_flush(cx),
208 }
209 }
210
211 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
212 match &mut self.0 {
213 Tcp(t) => Pin::new(t).poll_close(cx),
214 Tls(t) => Pin::new(t).poll_close(cx),
215 }
216 }
217
218 fn poll_write_vectored(
219 mut self: Pin<&mut Self>,
220 cx: &mut Context<'_>,
221 bufs: &[IoSlice<'_>],
222 ) -> Poll<Result<usize>> {
223 match &mut self.0 {
224 Tcp(t) => Pin::new(t).poll_write_vectored(cx, bufs),
225 Tls(t) => Pin::new(t).poll_write_vectored(cx, bufs),
226 }
227 }
228}
229
230impl<T: Transport> Transport for NativeTlsClientTransport<T> {
231 fn peer_addr(&self) -> Result<Option<SocketAddr>> {
232 self.as_ref().peer_addr()
233 }
234
235 }