1use crate::Identity;
2use async_native_tls::{Error, TlsAcceptor, TlsStream};
3use pem::Pem;
4use pkcs8::{
5 AlgorithmIdentifierRef, ObjectIdentifier, PrivateKeyInfo,
6 der::{Decode, Encode, asn1::AnyRef},
7};
8use std::{
9 io::{self, IoSlice, IoSliceMut},
10 net::SocketAddr,
11 pin::Pin,
12 task::{Context, Poll},
13};
14use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
15
/// Trillium [`Acceptor`] that performs the server side of a TLS handshake with
/// `async-native-tls`, wrapping an `async_native_tls` [`TlsAcceptor`].
///
/// Construct it with [`NativeTlsAcceptor::new`], one of the `from_*` constructors,
/// or any of the `From` conversions.
#[derive(Clone, Debug)]
pub struct NativeTlsAcceptor(TlsAcceptor);
20
21impl NativeTlsAcceptor {
22 pub fn new(t: impl Into<Self>) -> Self {
25 t.into()
26 }
27
28 pub fn from_cert_and_key(cert: &[u8], key: &[u8]) -> Self {
69 let cert_chain_der = extract_cert_chain_der(cert);
70 let key_pkcs8_der = normalize_key_to_pkcs8_der(key);
71
72 let cert_chain_pem = encode_cert_chain_pem(&cert_chain_der);
73 let key_pkcs8_pem = encode_pkcs8_pem(&key_pkcs8_der);
74 let pkcs8_err = match Identity::from_pkcs8(&cert_chain_pem, &key_pkcs8_pem) {
75 Ok(identity) => return identity.into(),
76 Err(e) => e,
77 };
78
79 let p12_der = build_pkcs12_der(&cert_chain_der, &key_pkcs8_der);
80 match Identity::from_pkcs12(&p12_der, INTERNAL_P12_PASSWORD) {
81 Ok(identity) => identity.into(),
82 Err(p12_err) => panic!(
83 "could not build Identity from provided cert and key.\n from_pkcs8 error: \
84 {pkcs8_err}\n from_pkcs12 fallback error: {p12_err}"
85 ),
86 }
87 }
88
89 pub fn from_pkcs12(der: &[u8], password: &str) -> Self {
96 Identity::from_pkcs12(der, password)
97 .expect("could not build Identity from provided pkcs12 key and password")
98 .into()
99 }
100
101 pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Self {
108 Identity::from_pkcs8(pem, key)
109 .expect("could not build Identity from provided pem and key")
110 .into()
111 }
112}
113
/// PEM label of a PKCS#8 `PrivateKeyInfo` block; passed through unchanged.
const PEM_TAG_PKCS8: &str = "PRIVATE KEY";
/// PEM label of a PKCS#1 RSA private key block; wrapped into PKCS#8.
const PEM_TAG_PKCS1: &str = "RSA PRIVATE KEY";
/// PEM label of a SEC1 EC private key block; wrapped into PKCS#8.
const PEM_TAG_SEC1: &str = "EC PRIVATE KEY";
/// PEM label of an X.509 certificate block.
const PEM_TAG_CERT: &str = "CERTIFICATE";

/// `rsaEncryption` algorithm OID, used as the PKCS#8 algorithm for wrapped PKCS#1 keys.
const RSA_ENCRYPTION_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
/// `id-ecPublicKey` algorithm OID, used as the PKCS#8 algorithm for wrapped SEC1 keys.
const EC_PUBLIC_KEY_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");

/// Password for the PKCS#12 archive that `from_cert_and_key` builds in memory as a
/// fallback. The same constant encrypts and immediately decrypts the archive. Because
/// it is hard-coded, it provides no confidentiality.
const INTERNAL_P12_PASSWORD: &str = "trillium";
126
127fn parse_pem_blocks(input: &[u8]) -> Vec<Pem> {
128 pem::parse_many(input).expect("could not parse PEM input")
129}
130
131fn normalize_key_to_pkcs8_der(input: &[u8]) -> Vec<u8> {
132 let blocks = parse_pem_blocks(input);
133 let key = blocks
134 .iter()
135 .find(|b| matches!(b.tag(), PEM_TAG_PKCS8 | PEM_TAG_PKCS1 | PEM_TAG_SEC1))
136 .expect(
137 "no private key block found in key input (expected PRIVATE KEY, RSA PRIVATE KEY, or \
138 EC PRIVATE KEY)",
139 );
140
141 match key.tag() {
142 PEM_TAG_PKCS8 => key.contents().to_vec(),
143 PEM_TAG_PKCS1 => wrap_pkcs1_in_pkcs8(key.contents()),
144 PEM_TAG_SEC1 => wrap_sec1_in_pkcs8(key.contents()),
145 _ => unreachable!(),
146 }
147}
148
149fn wrap_pkcs1_in_pkcs8(pkcs1_der: &[u8]) -> Vec<u8> {
150 let algorithm = AlgorithmIdentifierRef {
151 oid: RSA_ENCRYPTION_OID,
152 parameters: Some(AnyRef::NULL),
153 };
154 PrivateKeyInfo::new(algorithm, pkcs1_der)
155 .to_der()
156 .expect("could not encode PKCS#1 key as PKCS#8")
157}
158
159fn wrap_sec1_in_pkcs8(sec1_der: &[u8]) -> Vec<u8> {
160 let parsed =
161 sec1::EcPrivateKey::from_der(sec1_der).expect("could not parse SEC1 EC private key");
162 let curve_oid = parsed
163 .parameters
164 .and_then(|p| p.named_curve())
165 .expect("EC private key is missing namedCurve parameters");
166 let curve_param: AnyRef<'_> = (&curve_oid).into();
167 let algorithm = AlgorithmIdentifierRef {
168 oid: EC_PUBLIC_KEY_OID,
169 parameters: Some(curve_param),
170 };
171 PrivateKeyInfo::new(algorithm, sec1_der)
172 .to_der()
173 .expect("could not encode SEC1 key as PKCS#8")
174}
175
176fn encode_cert_chain_pem(cert_chain_der: &[Vec<u8>]) -> Vec<u8> {
177 let blocks: Vec<Pem> = cert_chain_der
178 .iter()
179 .map(|d| Pem::new(PEM_TAG_CERT, d.clone()))
180 .collect();
181 pem::encode_many(&blocks).into_bytes()
182}
183
184fn encode_pkcs8_pem(key_pkcs8_der: &[u8]) -> Vec<u8> {
185 pem::encode(&Pem::new(PEM_TAG_PKCS8, key_pkcs8_der.to_vec())).into_bytes()
186}
187
188fn extract_cert_chain_der(input: &[u8]) -> Vec<Vec<u8>> {
189 let certs: Vec<Vec<u8>> = parse_pem_blocks(input)
190 .into_iter()
191 .filter(|b| b.tag() == PEM_TAG_CERT)
192 .map(|b| b.into_contents())
193 .collect();
194 assert!(
195 !certs.is_empty(),
196 "no CERTIFICATE blocks found in cert input"
197 );
198 certs
199}
200
201fn build_pkcs12_der(cert_chain_der: &[Vec<u8>], key_pkcs8_der: &[u8]) -> Vec<u8> {
202 let leaf = cert_chain_der.first().expect("cert chain was empty");
203 let intermediates: Vec<&[u8]> = cert_chain_der.iter().skip(1).map(Vec::as_slice).collect();
204 let pfx = p12::PFX::new_with_cas(
205 leaf,
206 key_pkcs8_der,
207 &intermediates,
208 INTERNAL_P12_PASSWORD,
209 "",
210 )
211 .expect("could not build PKCS#12 archive from cert and key");
212 pfx.to_der()
213}
214
215impl From<Identity> for NativeTlsAcceptor {
216 fn from(i: Identity) -> Self {
217 native_tls::TlsAcceptor::new(i).unwrap().into()
218 }
219}
220
221impl From<native_tls::TlsAcceptor> for NativeTlsAcceptor {
222 fn from(i: native_tls::TlsAcceptor) -> Self {
223 Self(i.into())
224 }
225}
226
227impl From<TlsAcceptor> for NativeTlsAcceptor {
228 fn from(i: TlsAcceptor) -> Self {
229 Self(i)
230 }
231}
232
233impl From<(&[u8], &str)> for NativeTlsAcceptor {
234 fn from(i: (&[u8], &str)) -> Self {
235 Self::from_pkcs12(i.0, i.1)
236 }
237}
238
239impl<Input> Acceptor<Input> for NativeTlsAcceptor
240where
241 Input: Transport,
242{
243 type Error = Error;
244 type Output = NativeTlsServerTransport<Input>;
245
246 async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
247 self.0.accept(input).await.map(NativeTlsServerTransport)
248 }
249}
250
/// Server-side TLS transport produced by [`NativeTlsAcceptor`].
///
/// Wraps an `async_native_tls` [`TlsStream`] around the underlying transport `T`,
/// and forwards reads, writes, and `peer_addr` through it.
#[derive(Debug)]
pub struct NativeTlsServerTransport<T>(TlsStream<T>);
256
257impl<T: AsyncWrite + AsyncRead + Unpin> AsRef<T> for NativeTlsServerTransport<T> {
258 fn as_ref(&self) -> &T {
259 self.0.get_ref()
260 }
261}
262impl<T: AsyncWrite + AsyncRead + Unpin> AsMut<T> for NativeTlsServerTransport<T> {
263 fn as_mut(&mut self) -> &mut T {
264 self.0.get_mut()
265 }
266}
267
268impl<T> AsRef<TlsStream<T>> for NativeTlsServerTransport<T> {
269 fn as_ref(&self) -> &TlsStream<T> {
270 &self.0
271 }
272}
273impl<T> AsMut<TlsStream<T>> for NativeTlsServerTransport<T> {
274 fn as_mut(&mut self) -> &mut TlsStream<T> {
275 &mut self.0
276 }
277}
278
279impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsServerTransport<T> {
280 fn poll_read(
281 mut self: Pin<&mut Self>,
282 cx: &mut Context<'_>,
283 buf: &mut [u8],
284 ) -> Poll<io::Result<usize>> {
285 Pin::new(&mut self.0).poll_read(cx, buf)
286 }
287
288 fn poll_read_vectored(
289 mut self: Pin<&mut Self>,
290 cx: &mut Context<'_>,
291 bufs: &mut [IoSliceMut<'_>],
292 ) -> Poll<io::Result<usize>> {
293 Pin::new(&mut self.0).poll_read_vectored(cx, bufs)
294 }
295}
296
297impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for NativeTlsServerTransport<T> {
298 fn poll_write(
299 mut self: Pin<&mut Self>,
300 cx: &mut Context<'_>,
301 buf: &[u8],
302 ) -> Poll<io::Result<usize>> {
303 Pin::new(&mut self.0).poll_write(cx, buf)
304 }
305
306 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
307 Pin::new(&mut self.0).poll_flush(cx)
308 }
309
310 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
311 Pin::new(&mut self.0).poll_close(cx)
312 }
313
314 fn poll_write_vectored(
315 mut self: Pin<&mut Self>,
316 cx: &mut Context<'_>,
317 bufs: &[IoSlice<'_>],
318 ) -> Poll<io::Result<usize>> {
319 Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
320 }
321}
322
323impl<T: Transport> Transport for NativeTlsServerTransport<T> {
324 fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
325 self.0.get_ref().peer_addr()
326 }
327
328 }
334
#[cfg(test)]
mod tests {
    use super::{
        EC_PUBLIC_KEY_OID, RSA_ENCRYPTION_OID, extract_cert_chain_der, normalize_key_to_pkcs8_der,
    };
    use pkcs8::{ObjectIdentifier, PrivateKeyInfo};

    const RSA_CERT: &[u8] = include_bytes!("../tests/fixtures/rsa.crt");
    const RSA_PKCS1: &[u8] = include_bytes!("../tests/fixtures/rsa-pkcs1.key");
    const EC_CERT: &[u8] = include_bytes!("../tests/fixtures/ec.crt");
    const EC_SEC1: &[u8] = include_bytes!("../tests/fixtures/ec-sec1.key");
    const EC_PKCS8: &[u8] = include_bytes!("../tests/fixtures/ec-pkcs8.key");

    /// Decodes `der` as PKCS#8, failing the test if it is malformed.
    fn parse_pkcs8_der(der: &[u8]) -> PrivateKeyInfo<'_> {
        PrivateKeyInfo::try_from(der).expect("output not parseable as PKCS#8")
    }

    /// Normalizes PEM key input and returns the algorithm OID of the PKCS#8 result.
    fn normalized_algorithm_oid(pem_input: &[u8]) -> ObjectIdentifier {
        let der = normalize_key_to_pkcs8_der(pem_input);
        parse_pkcs8_der(&der).algorithm.oid
    }

    #[test]
    fn pkcs1_wraps_to_pkcs8_with_rsa_oid() {
        assert_eq!(normalized_algorithm_oid(RSA_PKCS1), RSA_ENCRYPTION_OID);
    }

    #[test]
    fn sec1_wraps_to_pkcs8_with_ec_oid_and_curve_param() {
        let der = normalize_key_to_pkcs8_der(EC_SEC1);
        let info = parse_pkcs8_der(&der);
        assert_eq!(info.algorithm.oid, EC_PUBLIC_KEY_OID);
        assert!(
            info.algorithm.parameters.is_some(),
            "EC PKCS#8 must carry namedCurve OID in algorithm parameters"
        );
    }

    #[test]
    fn pkcs8_pass_through_preserves_algorithm() {
        assert_eq!(normalized_algorithm_oid(EC_PKCS8), EC_PUBLIC_KEY_OID);
    }

    #[test]
    fn cert_extracted_from_concatenated_bundle() {
        let bundle = [EC_CERT, EC_SEC1].concat();
        let expected: Vec<Vec<u8>> = pem::parse_many(EC_CERT)
            .unwrap()
            .into_iter()
            .map(pem::Pem::into_contents)
            .collect();
        assert_eq!(extract_cert_chain_der(&bundle), expected);
    }

    #[test]
    fn key_extracted_from_concatenated_bundle() {
        let bundle = [RSA_CERT, RSA_PKCS1].concat();
        assert_eq!(normalized_algorithm_oid(&bundle), RSA_ENCRYPTION_OID);
    }
}