Skip to main content

trillium_openssl/
server.rs

1use async_openssl::SslStream;
2use openssl::{
3    pkcs12::Pkcs12,
4    pkey::{PKey, Private},
5    ssl::{AlpnError, Ssl, SslAcceptor, SslMethod},
6    x509::X509,
7};
8use std::{
9    borrow::Cow,
10    fmt::{Debug, Formatter},
11    io,
12    pin::Pin,
13    sync::Arc,
14    task::{Context, Poll},
15};
16use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
17
/// trillium [`Acceptor`] for openssl
///
/// Build one from an existing [`SslAcceptor`] with [`OpenSslAcceptor::new`], from PEM
/// material with [`OpenSslAcceptor::from_single_cert`], or from a pkcs12 archive with
/// [`OpenSslAcceptor::from_pkcs12`]. The wrapped `SslAcceptor` is held behind an `Arc`.
#[derive(Clone)]
pub struct OpenSslAcceptor(Inner);
21
// Private representation of an `OpenSslAcceptor`: either an acceptor this crate built
// (and can therefore rebuild), or one supplied ready-made by the caller.
#[derive(Clone)]
enum Inner {
    /// Built from cert + key inside this crate. We retain the parsed openssl handles
    /// (the same ones the `SslAcceptor` already holds internally) so chain methods
    /// like `without_http2` can rebuild with a different ALPN list without keeping
    /// a copy of the raw PEM bytes around.
    Rebuildable {
        // The acceptor currently built from `source`.
        acceptor: Arc<SslAcceptor>,
        // The inputs `acceptor` was built from, kept so it can be rebuilt.
        source: Source,
    },
    /// Constructed from a pre-built `SslAcceptor`. Chain methods are no-ops.
    Custom(Arc<SslAcceptor>),
}
35
// Everything `build_acceptor` needs to construct an `SslAcceptor`.
#[derive(Clone)]
struct Source {
    // Certificate installed with `set_certificate` (the first cert of a PEM chain, or the
    // pkcs12 archive's certificate).
    cert: X509,
    // Remaining certificates, each added with `add_extra_chain_cert` in this order.
    chain: Vec<X509>,
    // Private key; `build_acceptor` checks that it matches the installed certificate.
    key: PKey<Private>,
    // ALPN protocol names the server will accept, as raw bytes without a length prefix.
    // When empty, no ALPN selection callback is registered.
    alpn: Vec<Vec<u8>>,
}
43
44impl Debug for OpenSslAcceptor {
45    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46        f.debug_tuple("OpenSslAcceptor")
47            .field(&"<<SslAcceptor>>")
48            .finish()
49    }
50}
51
52impl OpenSslAcceptor {
53    /// build a new `OpenSslAcceptor` from a [`SslAcceptor`]
54    pub fn new(acceptor: SslAcceptor) -> Self {
55        Self(Inner::Custom(Arc::new(acceptor)))
56    }
57
58    /// build a new `OpenSslAcceptor` from a PEM-encoded cert chain and PEM-encoded private key.
59    ///
60    /// Defaults to advertising `[h2, http/1.1]` via ALPN. Use [`Self::without_http2`] to
61    /// drop HTTP/2.
62    ///
63    /// # Example
64    ///
65    /// ```rust,no_run
66    /// use trillium_openssl::OpenSslAcceptor;
67    /// const KEY: &[u8] = include_bytes!("../examples/key.pem");
68    /// const CERT: &[u8] = include_bytes!("../examples/cert.pem");
69    /// let acceptor = OpenSslAcceptor::from_single_cert(CERT, KEY);
70    /// ```
71    pub fn from_single_cert(cert: &[u8], key: &[u8]) -> Self {
72        let mut chain = X509::stack_from_pem(cert)
73            .expect("could not parse certificate chain")
74            .into_iter();
75        let leaf = chain.next().expect("certificate chain was empty");
76        let chain = chain.collect();
77        let key = PKey::private_key_from_pem(key).expect("could not parse private key");
78        Self::from_parts(leaf, chain, key, default_alpn())
79    }
80
81    /// build a new `OpenSslAcceptor` from a pkcs12 archive and password.
82    pub fn from_pkcs12(der: &[u8], password: &str) -> Self {
83        let parsed = Pkcs12::from_der(der)
84            .expect("could not read pkcs12 archive")
85            .parse2(password)
86            .expect("could not parse pkcs12 archive");
87        let cert = parsed
88            .cert
89            .expect("pkcs12 archive contained no certificate");
90        let key = parsed
91            .pkey
92            .expect("pkcs12 archive contained no private key");
93        let chain = parsed
94            .ca
95            .map(|stack| stack.into_iter().collect())
96            .unwrap_or_default();
97        Self::from_parts(cert, chain, key, default_alpn())
98    }
99
100    fn from_parts(cert: X509, chain: Vec<X509>, key: PKey<Private>, alpn: Vec<Vec<u8>>) -> Self {
101        let source = Source {
102            cert,
103            chain,
104            key,
105            alpn,
106        };
107        let acceptor = build_acceptor(&source);
108        Self(Inner::Rebuildable {
109            acceptor: Arc::new(acceptor),
110            source,
111        })
112    }
113
114    /// Drop `h2` from the ALPN protocol list, forcing HTTP/1.1 over TLS.
115    ///
116    /// Has no effect on acceptors constructed from a pre-built [`SslAcceptor`] via
117    /// [`Self::new`] — those manage their own ALPN configuration.
118    #[must_use]
119    pub fn without_http2(self) -> Self {
120        match self.0 {
121            Inner::Rebuildable { mut source, .. } => {
122                source.alpn.retain(|p| p != b"h2");
123                let acceptor = build_acceptor(&source);
124                Self(Inner::Rebuildable {
125                    acceptor: Arc::new(acceptor),
126                    source,
127                })
128            }
129            other @ Inner::Custom(_) => Self(other),
130        }
131    }
132
133    fn acceptor(&self) -> &SslAcceptor {
134        match &self.0 {
135            Inner::Rebuildable { acceptor, .. } | Inner::Custom(acceptor) => acceptor,
136        }
137    }
138}
139
/// The ALPN protocol list used when none is specified: `h2` followed by `http/1.1`.
fn default_alpn() -> Vec<Vec<u8>> {
    let protocols: [&[u8]; 2] = [b"h2", b"http/1.1"];
    protocols.iter().map(|proto| proto.to_vec()).collect()
}
143
144fn build_acceptor(source: &Source) -> SslAcceptor {
145    let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())
146        .expect("could not build SslAcceptor");
147    builder
148        .set_certificate(&source.cert)
149        .expect("could not set certificate");
150    for ca in &source.chain {
151        builder
152            .add_extra_chain_cert(ca.clone())
153            .expect("could not add chain certificate");
154    }
155    builder
156        .set_private_key(&source.key)
157        .expect("could not set private key");
158    builder
159        .check_private_key()
160        .expect("private key did not match certificate");
161    if !source.alpn.is_empty() {
162        let server_protos = source.alpn.clone();
163        builder.set_alpn_select_callback(move |_ssl, client_wire| {
164            select_alpn(&server_protos, client_wire).ok_or(AlpnError::NOACK)
165        });
166    }
167    builder.build()
168}
169
/// Pick the ALPN protocol to use for a connection, honoring the server's preference order.
///
/// `server` lists the protocols the server accepts, most preferred first, as raw names
/// without length prefixes. `client_wire` is the client's offer in ALPN wire format: a
/// sequence of entries, each a one-byte length followed by that many bytes of name.
///
/// Returns the first entry of `server` that also appears in the client's offer, or `None`
/// if there is no protocol in common. This follows RFC 7301 §3.2, under which the server
/// selects its own most-preferred protocol among those the client advertised, regardless
/// of the order the client listed them in.
///
/// The returned slice is a subslice of `client_wire`, which preserves its lifetime so the
/// `set_alpn_select_callback` closure type-checks.
///
/// If an entry's declared length runs past the end of `client_wire`, that entry and
/// everything after it are ignored; entries before it are still eligible.
fn select_alpn<'c>(server: &[Vec<u8>], client_wire: &'c [u8]) -> Option<&'c [u8]> {
    server.iter().find_map(|preferred| {
        // Re-scan the client's offer for each server protocol, in server order, so that
        // the first server protocol with any match wins.
        let mut rest = client_wire;
        while let Some((&len, tail)) = rest.split_first() {
            let len = usize::from(len);
            if len > tail.len() {
                // Truncated entry: nothing at or after this point can be trusted.
                return None;
            }
            let (proto, remaining) = tail.split_at(len);
            if proto == preferred.as_slice() {
                return Some(proto);
            }
            rest = remaining;
        }
        None
    })
}
190
191impl From<SslAcceptor> for OpenSslAcceptor {
192    fn from(acceptor: SslAcceptor) -> Self {
193        Self::new(acceptor)
194    }
195}
196
197impl<Input> Acceptor<Input> for OpenSslAcceptor
198where
199    Input: Transport,
200{
201    type Error = io::Error;
202    type Output = OpenSslServerTransport<Input>;
203
204    async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
205        let ssl = Ssl::new(self.acceptor().context()).map_err(io::Error::other)?;
206        let mut stream = SslStream::new(ssl, input).map_err(io::Error::other)?;
207        Pin::new(&mut stream)
208            .accept()
209            .await
210            .map_err(io::Error::other)?;
211        Ok(OpenSslServerTransport(stream))
212    }
213}
214
/// Transport for the openssl server acceptor
///
/// A thin wrapper around the `SslStream<T>` returned by [`OpenSslAcceptor`]'s `accept`
/// once the handshake has completed. `T` is the underlying transport.
#[derive(Debug)]
pub struct OpenSslServerTransport<T: Unpin>(SslStream<T>);
218
219impl<T: Unpin> OpenSslServerTransport<T> {
220    /// access the contained transport (eg `TcpStream`)
221    pub fn inner_transport(&self) -> &T {
222        self.0.get_ref()
223    }
224
225    /// mutably access the contained transport (eg `TcpStream`)
226    pub fn inner_transport_mut(&mut self) -> &mut T {
227        self.0.get_mut()
228    }
229}
230
231impl<T: Unpin> AsRef<T> for OpenSslServerTransport<T> {
232    fn as_ref(&self) -> &T {
233        self.0.get_ref()
234    }
235}
236
237impl<T: Unpin> AsMut<T> for OpenSslServerTransport<T> {
238    fn as_mut(&mut self) -> &mut T {
239        self.0.get_mut()
240    }
241}
242
243impl<T: Unpin> AsRef<SslStream<T>> for OpenSslServerTransport<T> {
244    fn as_ref(&self) -> &SslStream<T> {
245        &self.0
246    }
247}
248
249impl<T: Unpin> AsMut<SslStream<T>> for OpenSslServerTransport<T> {
250    fn as_mut(&mut self) -> &mut SslStream<T> {
251        &mut self.0
252    }
253}
254
255impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for OpenSslServerTransport<T> {
256    fn poll_read(
257        mut self: Pin<&mut Self>,
258        cx: &mut Context<'_>,
259        buf: &mut [u8],
260    ) -> Poll<io::Result<usize>> {
261        Pin::new(&mut self.0).poll_read(cx, buf)
262    }
263}
264
265impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for OpenSslServerTransport<T> {
266    fn poll_write(
267        mut self: Pin<&mut Self>,
268        cx: &mut Context<'_>,
269        buf: &[u8],
270    ) -> Poll<io::Result<usize>> {
271        Pin::new(&mut self.0).poll_write(cx, buf)
272    }
273
274    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
275        Pin::new(&mut self.0).poll_flush(cx)
276    }
277
278    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
279        Pin::new(&mut self.0).poll_close(cx)
280    }
281}
282
283impl<T: Transport> Transport for OpenSslServerTransport<T> {
284    fn peer_addr(&self) -> io::Result<Option<std::net::SocketAddr>> {
285        self.0.get_ref().peer_addr()
286    }
287
288    fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
289        self.0.ssl().selected_alpn_protocol().map(Cow::Borrowed)
290    }
291}