Skip to main content

trillium_rustls/
server.rs

1use crate::crypto_provider;
2use futures_rustls::{
3    TlsAcceptor,
4    rustls::{ServerConfig, ServerConnection},
5    server::TlsStream,
6};
7use std::{
8    borrow::Cow,
9    fmt::{Debug, Formatter},
10    io,
11    pin::Pin,
12    sync::Arc,
13    task::{Context, Poll},
14};
15use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
16
17/// trillium [`Acceptor`] for Rustls
18
19#[derive(Clone)]
20pub struct RustlsAcceptor(TlsAcceptor);
21impl Debug for RustlsAcceptor {
22    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
23        f.debug_tuple("Rustls").field(&"<<TlsAcceptor>>").finish()
24    }
25}
26
27impl RustlsAcceptor {
28    /// build a new RustlsAcceptor from a [`ServerConfig`] or a [`TlsAcceptor`]
29    pub fn new(t: impl Into<Self>) -> Self {
30        t.into()
31    }
32
33    /// build a new RustlsAcceptor from a cert chain (pem) and private key.
34    ///
35    /// See
36    /// [`ConfigBuilder::with_single_cert`][`crate::rustls::ConfigBuilder::with_single_cert`]
37    /// for accepted formats. If you need to customize the
38    /// [`ServerConfig`], use ServerConfig's `Into<RustlsAcceptor>`, eg
39    ///
40    /// ```rust,no_run
41    /// use trillium_rustls::{rustls::ServerConfig, RustlsAcceptor};
42    /// # let certs = vec![];
43    /// # let mut private_key = rustls_pemfile::private_key(&mut std::io::Cursor::new(b"")).unwrap().unwrap();
44    /// let rustls_acceptor: RustlsAcceptor = ServerConfig::builder()
45    ///     .with_no_client_auth()
46    ///     .with_single_cert(certs, private_key)
47    ///     .expect("could not build rustls ServerConfig")
48    ///     .into();
49    /// ```
50    ///
51    /// # Example
52    ///
53    /// ```rust,no_run
54    /// use trillium_rustls::RustlsAcceptor;
55    /// const KEY: &[u8] = include_bytes!("../examples/key.pem");
56    /// const CERT: &[u8] = include_bytes!("../examples/cert.pem");
57    /// let rustls_acceptor = RustlsAcceptor::from_single_cert(CERT, KEY);
58    /// ```
59    pub fn from_single_cert(cert: &[u8], key: &[u8]) -> Self {
60        Self::single_cert_with_alpn(cert, key, vec![b"h2".to_vec(), b"http/1.1".to_vec()])
61    }
62
63    /// build a [`RustlsAcceptor`] from a cert chain + private key that advertises only
64    /// `http/1.1` via ALPN, opting out of HTTP/2.
65    ///
66    /// This exists as a separate constructor because [`futures_rustls::TlsAcceptor`] does
67    /// not expose its inner [`ServerConfig`] for post-construction mutation. Callers needing
68    /// finer control should construct a [`ServerConfig`] directly and use its `Into`
69    /// conversion.
70    pub fn from_single_cert_no_h2(cert: &[u8], key: &[u8]) -> Self {
71        Self::single_cert_with_alpn(cert, key, vec![b"http/1.1".to_vec()])
72    }
73
74    fn single_cert_with_alpn(cert: &[u8], key: &[u8], alpn_protocols: Vec<Vec<u8>>) -> Self {
75        use std::io::Cursor;
76
77        let cert_chain = rustls_pemfile::certs(&mut Cursor::new(cert))
78            .collect::<Result<_, _>>()
79            .expect("could not read certificate");
80
81        let key_der = rustls_pemfile::private_key(&mut Cursor::new(key))
82            .expect("could not read key pemfile")
83            .expect("no private key found in `key`");
84
85        let mut config = ServerConfig::builder_with_provider(crypto_provider())
86            .with_safe_default_protocol_versions()
87            .expect("crypto provider did not support safe default protocol versions")
88            .with_no_client_auth()
89            .with_single_cert(cert_chain, key_der)
90            .expect("could not create a rustls ServerConfig from the supplied cert and key");
91        config.alpn_protocols = alpn_protocols;
92        config.into()
93    }
94}
95
96impl From<ServerConfig> for RustlsAcceptor {
97    fn from(sc: ServerConfig) -> Self {
98        Self(Arc::new(sc).into())
99    }
100}
101
102impl From<TlsAcceptor> for RustlsAcceptor {
103    fn from(ta: TlsAcceptor) -> Self {
104        Self(ta)
105    }
106}
107
108/// Transport for rustls server acceptor
109#[derive(Debug)]
110pub struct RustlsServerTransport<T>(TlsStream<T>);
111
112impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for RustlsServerTransport<T> {
113    fn poll_read(
114        mut self: Pin<&mut Self>,
115        cx: &mut Context<'_>,
116        buf: &mut [u8],
117    ) -> Poll<io::Result<usize>> {
118        Pin::new(&mut self.0).poll_read(cx, buf)
119    }
120}
121
122impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for RustlsServerTransport<T> {
123    fn poll_write(
124        mut self: Pin<&mut Self>,
125        cx: &mut Context<'_>,
126        buf: &[u8],
127    ) -> Poll<io::Result<usize>> {
128        Pin::new(&mut self.0).poll_write(cx, buf)
129    }
130
131    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
132        Pin::new(&mut self.0).poll_flush(cx)
133    }
134
135    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
136        Pin::new(&mut self.0).poll_close(cx)
137    }
138
139    fn poll_write_vectored(
140        mut self: Pin<&mut Self>,
141        cx: &mut Context<'_>,
142        bufs: &[io::IoSlice<'_>],
143    ) -> Poll<io::Result<usize>> {
144        Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
145    }
146}
147
148impl<T: Transport> Transport for RustlsServerTransport<T> {
149    fn peer_addr(&self) -> io::Result<Option<std::net::SocketAddr>> {
150        self.inner_transport().peer_addr()
151    }
152
153    fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
154        self.as_ref().alpn_protocol().map(Cow::Borrowed)
155    }
156}
157
158impl<T> RustlsServerTransport<T> {
159    /// access the contained transport type (eg TcpStream)
160    pub fn inner_transport(&self) -> &T {
161        self.0.get_ref().0
162    }
163
164    /// mutably access the contained transport type (eg TcpStream)
165    pub fn inner_transport_mut(&mut self) -> &mut T {
166        self.0.get_mut().0
167    }
168}
169
170impl<T> AsRef<ServerConnection> for RustlsServerTransport<T> {
171    fn as_ref(&self) -> &ServerConnection {
172        self.0.get_ref().1
173    }
174}
175
176impl<T> AsMut<ServerConnection> for RustlsServerTransport<T> {
177    fn as_mut(&mut self) -> &mut ServerConnection {
178        self.0.get_mut().1
179    }
180}
181
182impl<T> From<TlsStream<T>> for RustlsServerTransport<T> {
183    fn from(value: TlsStream<T>) -> Self {
184        Self(value)
185    }
186}
187
188impl<T> From<RustlsServerTransport<T>> for TlsStream<T> {
189    fn from(RustlsServerTransport(value): RustlsServerTransport<T>) -> Self {
190        value
191    }
192}
193
194impl<Input> Acceptor<Input> for RustlsAcceptor
195where
196    Input: Transport,
197{
198    type Error = io::Error;
199    type Output = RustlsServerTransport<Input>;
200
201    async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
202        self.0.accept(input).await.map(RustlsServerTransport)
203    }
204}