trillium_rustls/
server.rs1use crate::crypto_provider;
2use futures_rustls::{
3 TlsAcceptor,
4 rustls::{ServerConfig, ServerConnection},
5 server::TlsStream,
6};
7use std::{
8 borrow::Cow,
9 fmt::{Debug, Formatter},
10 io,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14};
15use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
16
17#[derive(Clone)]
20pub struct RustlsAcceptor(TlsAcceptor);
21impl Debug for RustlsAcceptor {
22 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23 f.debug_tuple("Rustls").field(&"<<TlsAcceptor>>").finish()
24 }
25}
26
27impl RustlsAcceptor {
28 pub fn new(t: impl Into<Self>) -> Self {
30 t.into()
31 }
32
33 pub fn from_single_cert(cert: &[u8], key: &[u8]) -> Self {
60 Self::single_cert_with_alpn(cert, key, vec![b"h2".to_vec(), b"http/1.1".to_vec()])
61 }
62
63 pub fn from_single_cert_no_h2(cert: &[u8], key: &[u8]) -> Self {
71 Self::single_cert_with_alpn(cert, key, vec![b"http/1.1".to_vec()])
72 }
73
74 fn single_cert_with_alpn(cert: &[u8], key: &[u8], alpn_protocols: Vec<Vec<u8>>) -> Self {
75 use std::io::Cursor;
76
77 let cert_chain = rustls_pemfile::certs(&mut Cursor::new(cert))
78 .collect::<Result<_, _>>()
79 .expect("could not read certificate");
80
81 let key_der = rustls_pemfile::private_key(&mut Cursor::new(key))
82 .expect("could not read key pemfile")
83 .expect("no private key found in `key`");
84
85 let mut config = ServerConfig::builder_with_provider(crypto_provider())
86 .with_safe_default_protocol_versions()
87 .expect("crypto provider did not support safe default protocol versions")
88 .with_no_client_auth()
89 .with_single_cert(cert_chain, key_der)
90 .expect("could not create a rustls ServerConfig from the supplied cert and key");
91 config.alpn_protocols = alpn_protocols;
92 config.into()
93 }
94}
95
96impl From<ServerConfig> for RustlsAcceptor {
97 fn from(sc: ServerConfig) -> Self {
98 Self(Arc::new(sc).into())
99 }
100}
101
102impl From<TlsAcceptor> for RustlsAcceptor {
103 fn from(ta: TlsAcceptor) -> Self {
104 Self(ta)
105 }
106}
107
/// The transport produced by [`RustlsAcceptor`]: a futures-rustls server [`TlsStream`]
/// running over the accepted transport `T`.
///
/// Reads and writes are forwarded to the inner [`TlsStream`]. The underlying transport
/// is reachable through `inner_transport`/`inner_transport_mut`, and the rustls
/// [`ServerConnection`] through `AsRef`/`AsMut`.
#[derive(Debug)]
pub struct RustlsServerTransport<T>(TlsStream<T>);
111
112impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsServerTransport<T> {
113 fn poll_read(
114 mut self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 buf: &mut [u8],
117 ) -> Poll<io::Result<usize>> {
118 Pin::new(&mut self.0).poll_read(cx, buf)
119 }
120}
121
122impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for RustlsServerTransport<T> {
123 fn poll_write(
124 mut self: Pin<&mut Self>,
125 cx: &mut Context<'_>,
126 buf: &[u8],
127 ) -> Poll<io::Result<usize>> {
128 Pin::new(&mut self.0).poll_write(cx, buf)
129 }
130
131 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
132 Pin::new(&mut self.0).poll_flush(cx)
133 }
134
135 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136 Pin::new(&mut self.0).poll_close(cx)
137 }
138
139 fn poll_write_vectored(
140 mut self: Pin<&mut Self>,
141 cx: &mut Context<'_>,
142 bufs: &[io::IoSlice<'_>],
143 ) -> Poll<io::Result<usize>> {
144 Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
145 }
146}
147
148impl<T: Transport> Transport for RustlsServerTransport<T> {
149 fn peer_addr(&self) -> io::Result<Option<std::net::SocketAddr>> {
150 self.inner_transport().peer_addr()
151 }
152
153 fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
154 self.as_ref().alpn_protocol().map(Cow::Borrowed)
155 }
156}
157
158impl<T> RustlsServerTransport<T> {
159 pub fn inner_transport(&self) -> &T {
161 self.0.get_ref().0
162 }
163
164 pub fn inner_transport_mut(&mut self) -> &mut T {
166 self.0.get_mut().0
167 }
168}
169
170impl<T> AsRef<ServerConnection> for RustlsServerTransport<T> {
171 fn as_ref(&self) -> &ServerConnection {
172 self.0.get_ref().1
173 }
174}
175
176impl<T> AsMut<ServerConnection> for RustlsServerTransport<T> {
177 fn as_mut(&mut self) -> &mut ServerConnection {
178 self.0.get_mut().1
179 }
180}
181
182impl<T> From<TlsStream<T>> for RustlsServerTransport<T> {
183 fn from(value: TlsStream<T>) -> Self {
184 Self(value)
185 }
186}
187
188impl<T> From<RustlsServerTransport<T>> for TlsStream<T> {
189 fn from(RustlsServerTransport(value): RustlsServerTransport<T>) -> Self {
190 value
191 }
192}
193
194impl<Input> Acceptor<Input> for RustlsAcceptor
195where
196 Input: Transport,
197{
198 type Error = io::Error;
199 type Output = RustlsServerTransport<Input>;
200
201 async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
202 self.0.accept(input).await.map(RustlsServerTransport)
203 }
204}