trillium_openssl/
server.rs1use async_openssl::SslStream;
2use openssl::{
3 pkcs12::Pkcs12,
4 pkey::{PKey, Private},
5 ssl::{AlpnError, Ssl, SslAcceptor, SslMethod},
6 x509::X509,
7};
8use std::{
9 borrow::Cow,
10 fmt::{Debug, Formatter},
11 io,
12 pin::Pin,
13 sync::Arc,
14 task::{Context, Poll},
15};
16use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
17
18#[derive(Clone)]
20pub struct OpenSslAcceptor(Inner);
21
22#[derive(Clone)]
23enum Inner {
24 Rebuildable {
29 acceptor: Arc<SslAcceptor>,
30 source: Source,
31 },
32 Custom(Arc<SslAcceptor>),
34}
35
36#[derive(Clone)]
37struct Source {
38 cert: X509,
39 chain: Vec<X509>,
40 key: PKey<Private>,
41 alpn: Vec<Vec<u8>>,
42}
43
44impl Debug for OpenSslAcceptor {
45 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46 f.debug_tuple("OpenSslAcceptor")
47 .field(&"<<SslAcceptor>>")
48 .finish()
49 }
50}
51
52impl OpenSslAcceptor {
53 pub fn new(acceptor: SslAcceptor) -> Self {
55 Self(Inner::Custom(Arc::new(acceptor)))
56 }
57
58 pub fn from_single_cert(cert: &[u8], key: &[u8]) -> Self {
72 let mut chain = X509::stack_from_pem(cert)
73 .expect("could not parse certificate chain")
74 .into_iter();
75 let leaf = chain.next().expect("certificate chain was empty");
76 let chain = chain.collect();
77 let key = PKey::private_key_from_pem(key).expect("could not parse private key");
78 Self::from_parts(leaf, chain, key, default_alpn())
79 }
80
81 pub fn from_pkcs12(der: &[u8], password: &str) -> Self {
83 let parsed = Pkcs12::from_der(der)
84 .expect("could not read pkcs12 archive")
85 .parse2(password)
86 .expect("could not parse pkcs12 archive");
87 let cert = parsed
88 .cert
89 .expect("pkcs12 archive contained no certificate");
90 let key = parsed
91 .pkey
92 .expect("pkcs12 archive contained no private key");
93 let chain = parsed
94 .ca
95 .map(|stack| stack.into_iter().collect())
96 .unwrap_or_default();
97 Self::from_parts(cert, chain, key, default_alpn())
98 }
99
100 fn from_parts(cert: X509, chain: Vec<X509>, key: PKey<Private>, alpn: Vec<Vec<u8>>) -> Self {
101 let source = Source {
102 cert,
103 chain,
104 key,
105 alpn,
106 };
107 let acceptor = build_acceptor(&source);
108 Self(Inner::Rebuildable {
109 acceptor: Arc::new(acceptor),
110 source,
111 })
112 }
113
114 #[must_use]
119 pub fn without_http2(self) -> Self {
120 match self.0 {
121 Inner::Rebuildable { mut source, .. } => {
122 source.alpn.retain(|p| p != b"h2");
123 let acceptor = build_acceptor(&source);
124 Self(Inner::Rebuildable {
125 acceptor: Arc::new(acceptor),
126 source,
127 })
128 }
129 other @ Inner::Custom(_) => Self(other),
130 }
131 }
132
133 fn acceptor(&self) -> &SslAcceptor {
134 match &self.0 {
135 Inner::Rebuildable { acceptor, .. } | Inner::Custom(acceptor) => acceptor,
136 }
137 }
138}
139
/// ALPN protocols offered by default: `h2` followed by `http/1.1`.
fn default_alpn() -> Vec<Vec<u8>> {
    [&b"h2"[..], &b"http/1.1"[..]]
        .iter()
        .map(|protocol| protocol.to_vec())
        .collect()
}
143
144fn build_acceptor(source: &Source) -> SslAcceptor {
145 let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())
146 .expect("could not build SslAcceptor");
147 builder
148 .set_certificate(&source.cert)
149 .expect("could not set certificate");
150 for ca in &source.chain {
151 builder
152 .add_extra_chain_cert(ca.clone())
153 .expect("could not add chain certificate");
154 }
155 builder
156 .set_private_key(&source.key)
157 .expect("could not set private key");
158 builder
159 .check_private_key()
160 .expect("private key did not match certificate");
161 if !source.alpn.is_empty() {
162 let server_protos = source.alpn.clone();
163 builder.set_alpn_select_callback(move |_ssl, client_wire| {
164 select_alpn(&server_protos, client_wire).ok_or(AlpnError::NOACK)
165 });
166 }
167 builder.build()
168}
169
/// Chooses the ALPN protocol to negotiate.
///
/// - `server` lists the protocols this server supports, most preferred first.
/// - `client_wire` is the client's offer in ALPN wire format: a sequence of entries, each a
///   one-byte length followed by that many bytes.
///
/// As RFC 7301 §3.2 recommends, the server's preference decides. The result is the first entry
/// of `server` that the client also offered, borrowed from `client_wire`.
///
/// Returns `None` when there is no overlap, or when `client_wire` is malformed (an entry's
/// declared length runs past the end of the buffer).
fn select_alpn<'c>(server: &[Vec<u8>], client_wire: &'c [u8]) -> Option<&'c [u8]> {
    let offered = parse_alpn_wire(client_wire)?;
    server.iter().find_map(|ours| {
        offered
            .iter()
            .copied()
            .find(|theirs| *theirs == ours.as_slice())
    })
}

/// Splits an ALPN wire-format list into its entries.
///
/// Returns `None` if an entry's declared length exceeds the remaining bytes.
fn parse_alpn_wire(wire: &[u8]) -> Option<Vec<&[u8]>> {
    let mut entries = Vec::new();
    let mut rest = wire;
    while let Some((&len, tail)) = rest.split_first() {
        let len = usize::from(len);
        if len > tail.len() {
            return None;
        }
        let (entry, remainder) = tail.split_at(len);
        entries.push(entry);
        rest = remainder;
    }
    Some(entries)
}
190
191impl From<SslAcceptor> for OpenSslAcceptor {
192 fn from(acceptor: SslAcceptor) -> Self {
193 Self::new(acceptor)
194 }
195}
196
197impl<Input> Acceptor<Input> for OpenSslAcceptor
198where
199 Input: Transport,
200{
201 type Error = io::Error;
202 type Output = OpenSslServerTransport<Input>;
203
204 async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
205 let ssl = Ssl::new(self.acceptor().context()).map_err(io::Error::other)?;
206 let mut stream = SslStream::new(ssl, input).map_err(io::Error::other)?;
207 Pin::new(&mut stream)
208 .accept()
209 .await
210 .map_err(io::Error::other)?;
211 Ok(OpenSslServerTransport(stream))
212 }
213}
214
215#[derive(Debug)]
217pub struct OpenSslServerTransport<T: Unpin>(SslStream<T>);
218
219impl<T: Unpin> OpenSslServerTransport<T> {
220 pub fn inner_transport(&self) -> &T {
222 self.0.get_ref()
223 }
224
225 pub fn inner_transport_mut(&mut self) -> &mut T {
227 self.0.get_mut()
228 }
229}
230
231impl<T: Unpin> AsRef<T> for OpenSslServerTransport<T> {
232 fn as_ref(&self) -> &T {
233 self.0.get_ref()
234 }
235}
236
237impl<T: Unpin> AsMut<T> for OpenSslServerTransport<T> {
238 fn as_mut(&mut self) -> &mut T {
239 self.0.get_mut()
240 }
241}
242
243impl<T: Unpin> AsRef<SslStream<T>> for OpenSslServerTransport<T> {
244 fn as_ref(&self) -> &SslStream<T> {
245 &self.0
246 }
247}
248
249impl<T: Unpin> AsMut<SslStream<T>> for OpenSslServerTransport<T> {
250 fn as_mut(&mut self) -> &mut SslStream<T> {
251 &mut self.0
252 }
253}
254
255impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for OpenSslServerTransport<T> {
256 fn poll_read(
257 mut self: Pin<&mut Self>,
258 cx: &mut Context<'_>,
259 buf: &mut [u8],
260 ) -> Poll<io::Result<usize>> {
261 Pin::new(&mut self.0).poll_read(cx, buf)
262 }
263}
264
265impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for OpenSslServerTransport<T> {
266 fn poll_write(
267 mut self: Pin<&mut Self>,
268 cx: &mut Context<'_>,
269 buf: &[u8],
270 ) -> Poll<io::Result<usize>> {
271 Pin::new(&mut self.0).poll_write(cx, buf)
272 }
273
274 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
275 Pin::new(&mut self.0).poll_flush(cx)
276 }
277
278 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
279 Pin::new(&mut self.0).poll_close(cx)
280 }
281}
282
283impl<T: Transport> Transport for OpenSslServerTransport<T> {
284 fn peer_addr(&self) -> io::Result<Option<std::net::SocketAddr>> {
285 self.0.get_ref().peer_addr()
286 }
287
288 fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
289 self.0.ssl().selected_alpn_protocol().map(Cow::Borrowed)
290 }
291}