trillium_native_tls/
client.rs1use async_native_tls::{TlsConnector, TlsStream};
2use std::{
3 fmt::{Debug, Formatter},
4 io::{Error, ErrorKind, IoSlice, IoSliceMut, Result},
5 net::SocketAddr,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9};
10use trillium_server_common::{AsyncRead, AsyncWrite, Connector, Transport, Url};
11
12#[derive(Clone)]
14pub struct NativeTlsConfig<Config> {
15 pub tcp_config: Config,
17
18 pub tls_connector: Arc<TlsConnector>,
23}
24
25impl<C: Connector> NativeTlsConfig<C> {
26 pub fn with_tcp_config(mut self, config: C) -> Self {
28 self.tcp_config = config;
29 self
30 }
31}
32
33impl<C: Connector> From<C> for NativeTlsConfig<C> {
34 fn from(tcp_config: C) -> Self {
35 Self {
36 tcp_config,
37 tls_connector: Arc::new(TlsConnector::default()),
38 }
39 }
40}
41
42impl<Config: Debug> Debug for NativeTlsConfig<Config> {
43 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("NativeTlsConfig")
45 .field("tcp_config", &self.tcp_config)
46 .field("tls_connector", &format_args!(".."))
47 .finish()
48 }
49}
50
51impl<Config: Default> Default for NativeTlsConfig<Config> {
52 fn default() -> Self {
53 Self {
54 tcp_config: Config::default(),
55 tls_connector: Arc::new(TlsConnector::default()),
56 }
57 }
58}
59
60impl<Config> AsRef<Config> for NativeTlsConfig<Config> {
61 fn as_ref(&self) -> &Config {
62 &self.tcp_config
63 }
64}
65
66impl<T: Connector> Connector for NativeTlsConfig<T> {
67 type Runtime = T::Runtime;
68 type Transport = NativeTlsClientTransport<T::Transport>;
69 type Udp = T::Udp;
70
71 async fn connect(&self, url: &Url) -> Result<Self::Transport> {
72 match url.scheme() {
73 "https" => {
74 let mut http = url.clone();
75 http.set_scheme("http").ok();
76 http.set_port(url.port_or_known_default()).ok();
77 let inner_stream = self.tcp_config.connect(&http).await?;
78
79 self.tls_connector
80 .connect(url, inner_stream)
81 .await
82 .map_err(|e| Error::other(e.to_string()))
83 .map(NativeTlsClientTransport::from)
84 }
85
86 "http" => self
87 .tcp_config
88 .connect(url)
89 .await
90 .map(NativeTlsClientTransport::from),
91
92 unknown => Err(Error::new(
93 ErrorKind::InvalidInput,
94 format!("unknown scheme {unknown}"),
95 )),
96 }
97 }
98
99 fn runtime(&self) -> Self::Runtime {
100 self.tcp_config.runtime()
101 }
102
103 async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
104 self.tcp_config.resolve(host, port).await
105 }
106}
107
108#[derive(Debug)]
114pub struct NativeTlsClientTransport<T>(NativeTlsClientTransportInner<T>);
115
116impl<T: AsyncWrite + AsyncRead + Unpin> NativeTlsClientTransport<T> {
117 pub fn as_tls(&self) -> Option<&TlsStream<T>> {
121 match &self.0 {
122 Tcp(_) => None,
123 Tls(tls) => Some(tls),
124 }
125 }
126}
127
128impl<T> From<T> for NativeTlsClientTransport<T> {
129 fn from(value: T) -> Self {
130 Self(Tcp(value))
131 }
132}
133
134impl<T> From<TlsStream<T>> for NativeTlsClientTransport<T> {
135 fn from(value: TlsStream<T>) -> Self {
136 Self(Tls(value))
137 }
138}
139
140impl<T: Transport> AsRef<T> for NativeTlsClientTransport<T> {
141 fn as_ref(&self) -> &T {
142 match &self.0 {
143 Tcp(transport) => transport,
144 Tls(tls_stream) => tls_stream.get_ref(),
145 }
146 }
147}
148
149#[derive(Debug)]
150enum NativeTlsClientTransportInner<T> {
151 Tcp(T),
152 Tls(TlsStream<T>),
153}
154use NativeTlsClientTransportInner::{Tcp, Tls};
155
156impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsClientTransport<T> {
157 fn poll_read(
158 mut self: Pin<&mut Self>,
159 cx: &mut Context<'_>,
160 buf: &mut [u8],
161 ) -> Poll<Result<usize>> {
162 match &mut self.0 {
163 Tcp(t) => Pin::new(t).poll_read(cx, buf),
164 Tls(t) => Pin::new(t).poll_read(cx, buf),
165 }
166 }
167
168 fn poll_read_vectored(
169 mut self: Pin<&mut Self>,
170 cx: &mut Context<'_>,
171 bufs: &mut [IoSliceMut<'_>],
172 ) -> Poll<Result<usize>> {
173 match &mut self.0 {
174 Tcp(t) => Pin::new(t).poll_read_vectored(cx, bufs),
175 Tls(t) => Pin::new(t).poll_read_vectored(cx, bufs),
176 }
177 }
178}
179
180impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for NativeTlsClientTransport<T> {
181 fn poll_write(
182 mut self: Pin<&mut Self>,
183 cx: &mut Context<'_>,
184 buf: &[u8],
185 ) -> Poll<Result<usize>> {
186 match &mut self.0 {
187 Tcp(t) => Pin::new(t).poll_write(cx, buf),
188 Tls(t) => Pin::new(t).poll_write(cx, buf),
189 }
190 }
191
192 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
193 match &mut self.0 {
194 Tcp(t) => Pin::new(t).poll_flush(cx),
195 Tls(t) => Pin::new(t).poll_flush(cx),
196 }
197 }
198
199 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
200 match &mut self.0 {
201 Tcp(t) => Pin::new(t).poll_close(cx),
202 Tls(t) => Pin::new(t).poll_close(cx),
203 }
204 }
205
206 fn poll_write_vectored(
207 mut self: Pin<&mut Self>,
208 cx: &mut Context<'_>,
209 bufs: &[IoSlice<'_>],
210 ) -> Poll<Result<usize>> {
211 match &mut self.0 {
212 Tcp(t) => Pin::new(t).poll_write_vectored(cx, bufs),
213 Tls(t) => Pin::new(t).poll_write_vectored(cx, bufs),
214 }
215 }
216}
217
218impl<T: Transport> Transport for NativeTlsClientTransport<T> {
219 fn peer_addr(&self) -> Result<Option<SocketAddr>> {
220 self.as_ref().peer_addr()
221 }
222}