1use crate::Identity;
2use async_native_tls::{Error, TlsAcceptor, TlsStream};
3use pem::Pem;
4use pkcs8::{
5 AlgorithmIdentifierRef, ObjectIdentifier, PrivateKeyInfoRef,
6 der::{
7 Decode, Encode,
8 asn1::{AnyRef, OctetStringRef},
9 },
10};
11use std::{
12 io::{self, IoSlice, IoSliceMut},
13 net::SocketAddr,
14 pin::Pin,
15 task::{Context, Poll},
16};
17use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
18
19#[derive(Clone, Debug)]
22pub struct NativeTlsAcceptor(TlsAcceptor);
23
24impl NativeTlsAcceptor {
25 pub fn new(t: impl Into<Self>) -> Self {
28 t.into()
29 }
30
31 pub fn from_cert_and_key(cert: &[u8], key: &[u8]) -> Self {
72 let cert_chain_der = extract_cert_chain_der(cert);
73 let key_pkcs8_der = normalize_key_to_pkcs8_der(key);
74
75 let cert_chain_pem = encode_cert_chain_pem(&cert_chain_der);
76 let key_pkcs8_pem = encode_pkcs8_pem(&key_pkcs8_der);
77 let pkcs8_err = match Identity::from_pkcs8(&cert_chain_pem, &key_pkcs8_pem) {
78 Ok(identity) => return identity.into(),
79 Err(e) => e,
80 };
81
82 let p12_der = build_pkcs12_der(&cert_chain_der, &key_pkcs8_der);
83 match Identity::from_pkcs12(&p12_der, INTERNAL_P12_PASSWORD) {
84 Ok(identity) => identity.into(),
85 Err(p12_err) => panic!(
86 "could not build Identity from provided cert and key.\n from_pkcs8 error: \
87 {pkcs8_err}\n from_pkcs12 fallback error: {p12_err}"
88 ),
89 }
90 }
91
92 pub fn from_pkcs12(der: &[u8], password: &str) -> Self {
99 Identity::from_pkcs12(der, password)
100 .expect("could not build Identity from provided pkcs12 key and password")
101 .into()
102 }
103
104 pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Self {
111 Identity::from_pkcs8(pem, key)
112 .expect("could not build Identity from provided pem and key")
113 .into()
114 }
115}
116
// PEM block labels recognized by the key/cert helpers below. They are matched by exact string
// against `Pem::tag()`.
/// PEM label for a PKCS#8 `PrivateKeyInfo`, which is passed through unchanged.
const PEM_TAG_PKCS8: &str = "PRIVATE KEY";
/// PEM label for a PKCS#1 `RSAPrivateKey`, which is wrapped into PKCS#8.
const PEM_TAG_PKCS1: &str = "RSA PRIVATE KEY";
/// PEM label for a SEC1 `ECPrivateKey`, which is wrapped into PKCS#8.
const PEM_TAG_SEC1: &str = "EC PRIVATE KEY";
/// PEM label for an X.509 certificate. Every block with this label forms the chain.
const PEM_TAG_CERT: &str = "CERTIFICATE";

/// `rsaEncryption` algorithm OID, used as the PKCS#8 algorithm when wrapping a PKCS#1 key.
const RSA_ENCRYPTION_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
/// `id-ecPublicKey` algorithm OID, used as the PKCS#8 algorithm when wrapping a SEC1 key.
const EC_PUBLIC_KEY_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");

/// Password used both to build and to immediately reopen the transient PKCS#12 archive in the
/// `from_cert_and_key` fallback. The archive is a local value that is only handed to
/// `Identity::from_pkcs12`, so this value is not a secret.
const INTERNAL_P12_PASSWORD: &str = "trillium";
129
130fn parse_pem_blocks(input: &[u8]) -> Vec<Pem> {
131 pem::parse_many(input).expect("could not parse PEM input")
132}
133
134fn normalize_key_to_pkcs8_der(input: &[u8]) -> Vec<u8> {
135 let blocks = parse_pem_blocks(input);
136 let key = blocks
137 .iter()
138 .find(|b| matches!(b.tag(), PEM_TAG_PKCS8 | PEM_TAG_PKCS1 | PEM_TAG_SEC1))
139 .expect(
140 "no private key block found in key input (expected PRIVATE KEY, RSA PRIVATE KEY, or \
141 EC PRIVATE KEY)",
142 );
143
144 match key.tag() {
145 PEM_TAG_PKCS8 => key.contents().to_vec(),
146 PEM_TAG_PKCS1 => wrap_pkcs1_in_pkcs8(key.contents()),
147 PEM_TAG_SEC1 => wrap_sec1_in_pkcs8(key.contents()),
148 _ => unreachable!(),
149 }
150}
151
152fn wrap_pkcs1_in_pkcs8(pkcs1_der: &[u8]) -> Vec<u8> {
153 let algorithm = AlgorithmIdentifierRef {
154 oid: RSA_ENCRYPTION_OID,
155 parameters: Some(AnyRef::NULL),
156 };
157 let private_key =
158 OctetStringRef::new(pkcs1_der).expect("could not wrap PKCS#1 key as OCTET STRING");
159 PrivateKeyInfoRef::new(algorithm, private_key)
160 .to_der()
161 .expect("could not encode PKCS#1 key as PKCS#8")
162}
163
164fn wrap_sec1_in_pkcs8(sec1_der: &[u8]) -> Vec<u8> {
165 let parsed =
166 sec1::EcPrivateKey::from_der(sec1_der).expect("could not parse SEC1 EC private key");
167 let curve_oid = parsed
168 .parameters
169 .and_then(|p| p.named_curve())
170 .expect("EC private key is missing namedCurve parameters");
171 let curve_param: AnyRef<'_> = (&curve_oid).into();
172 let algorithm = AlgorithmIdentifierRef {
173 oid: EC_PUBLIC_KEY_OID,
174 parameters: Some(curve_param),
175 };
176 let private_key =
177 OctetStringRef::new(sec1_der).expect("could not wrap SEC1 key as OCTET STRING");
178 PrivateKeyInfoRef::new(algorithm, private_key)
179 .to_der()
180 .expect("could not encode SEC1 key as PKCS#8")
181}
182
183fn encode_cert_chain_pem(cert_chain_der: &[Vec<u8>]) -> Vec<u8> {
184 let blocks: Vec<Pem> = cert_chain_der
185 .iter()
186 .map(|d| Pem::new(PEM_TAG_CERT, d.clone()))
187 .collect();
188 pem::encode_many(&blocks).into_bytes()
189}
190
191fn encode_pkcs8_pem(key_pkcs8_der: &[u8]) -> Vec<u8> {
192 pem::encode(&Pem::new(PEM_TAG_PKCS8, key_pkcs8_der.to_vec())).into_bytes()
193}
194
195fn extract_cert_chain_der(input: &[u8]) -> Vec<Vec<u8>> {
196 let certs: Vec<Vec<u8>> = parse_pem_blocks(input)
197 .into_iter()
198 .filter(|b| b.tag() == PEM_TAG_CERT)
199 .map(|b| b.into_contents())
200 .collect();
201 assert!(
202 !certs.is_empty(),
203 "no CERTIFICATE blocks found in cert input"
204 );
205 certs
206}
207
208fn build_pkcs12_der(cert_chain_der: &[Vec<u8>], key_pkcs8_der: &[u8]) -> Vec<u8> {
209 let leaf = cert_chain_der.first().expect("cert chain was empty");
210 let intermediates: Vec<&[u8]> = cert_chain_der.iter().skip(1).map(Vec::as_slice).collect();
211 let pfx = p12::PFX::new_with_cas(
212 leaf,
213 key_pkcs8_der,
214 &intermediates,
215 INTERNAL_P12_PASSWORD,
216 "",
217 )
218 .expect("could not build PKCS#12 archive from cert and key");
219 pfx.to_der()
220}
221
222impl From<Identity> for NativeTlsAcceptor {
223 fn from(i: Identity) -> Self {
224 native_tls::TlsAcceptor::new(i).unwrap().into()
225 }
226}
227
228impl From<native_tls::TlsAcceptor> for NativeTlsAcceptor {
229 fn from(i: native_tls::TlsAcceptor) -> Self {
230 Self(i.into())
231 }
232}
233
234impl From<TlsAcceptor> for NativeTlsAcceptor {
235 fn from(i: TlsAcceptor) -> Self {
236 Self(i)
237 }
238}
239
240impl From<(&[u8], &str)> for NativeTlsAcceptor {
241 fn from(i: (&[u8], &str)) -> Self {
242 Self::from_pkcs12(i.0, i.1)
243 }
244}
245
246impl<Input> Acceptor<Input> for NativeTlsAcceptor
247where
248 Input: Transport,
249{
250 type Error = Error;
251 type Output = NativeTlsServerTransport<Input>;
252
253 async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
254 self.0.accept(input).await.map(NativeTlsServerTransport)
255 }
256}
257
258#[derive(Debug)]
262pub struct NativeTlsServerTransport<T>(TlsStream<T>);
263
264impl<T: AsyncWrite + AsyncRead + Unpin> AsRef<T> for NativeTlsServerTransport<T> {
265 fn as_ref(&self) -> &T {
266 self.0.get_ref()
267 }
268}
269impl<T: AsyncWrite + AsyncRead + Unpin> AsMut<T> for NativeTlsServerTransport<T> {
270 fn as_mut(&mut self) -> &mut T {
271 self.0.get_mut()
272 }
273}
274
275impl<T> AsRef<TlsStream<T>> for NativeTlsServerTransport<T> {
276 fn as_ref(&self) -> &TlsStream<T> {
277 &self.0
278 }
279}
280impl<T> AsMut<TlsStream<T>> for NativeTlsServerTransport<T> {
281 fn as_mut(&mut self) -> &mut TlsStream<T> {
282 &mut self.0
283 }
284}
285
286impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsServerTransport<T> {
287 fn poll_read(
288 mut self: Pin<&mut Self>,
289 cx: &mut Context<'_>,
290 buf: &mut [u8],
291 ) -> Poll<io::Result<usize>> {
292 Pin::new(&mut self.0).poll_read(cx, buf)
293 }
294
295 fn poll_read_vectored(
296 mut self: Pin<&mut Self>,
297 cx: &mut Context<'_>,
298 bufs: &mut [IoSliceMut<'_>],
299 ) -> Poll<io::Result<usize>> {
300 Pin::new(&mut self.0).poll_read_vectored(cx, bufs)
301 }
302}
303
304impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for NativeTlsServerTransport<T> {
305 fn poll_write(
306 mut self: Pin<&mut Self>,
307 cx: &mut Context<'_>,
308 buf: &[u8],
309 ) -> Poll<io::Result<usize>> {
310 Pin::new(&mut self.0).poll_write(cx, buf)
311 }
312
313 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
314 Pin::new(&mut self.0).poll_flush(cx)
315 }
316
317 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
318 Pin::new(&mut self.0).poll_close(cx)
319 }
320
321 fn poll_write_vectored(
322 mut self: Pin<&mut Self>,
323 cx: &mut Context<'_>,
324 bufs: &[IoSlice<'_>],
325 ) -> Poll<io::Result<usize>> {
326 Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
327 }
328}
329
330impl<T: Transport> Transport for NativeTlsServerTransport<T> {
331 fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
332 self.0.get_ref().peer_addr()
333 }
334
335 }
341
#[cfg(test)]
mod tests {
    use super::{
        EC_PUBLIC_KEY_OID, RSA_ENCRYPTION_OID, extract_cert_chain_der, normalize_key_to_pkcs8_der,
    };
    use pkcs8::PrivateKeyInfoRef;

    // PEM fixtures, embedded at compile time.
    const RSA_CERT: &[u8] = include_bytes!("../tests/fixtures/rsa.crt");
    const RSA_PKCS1: &[u8] = include_bytes!("../tests/fixtures/rsa-pkcs1.key");
    const EC_CERT: &[u8] = include_bytes!("../tests/fixtures/ec.crt");
    const EC_SEC1: &[u8] = include_bytes!("../tests/fixtures/ec-sec1.key");
    const EC_PKCS8: &[u8] = include_bytes!("../tests/fixtures/ec-pkcs8.key");

    /// Decodes `der` as a PKCS#8 `PrivateKeyInfo`, failing the test if it is malformed.
    fn parse_pkcs8_der(der: &[u8]) -> PrivateKeyInfoRef<'_> {
        PrivateKeyInfoRef::try_from(der).expect("output not parseable as PKCS#8")
    }

    /// Concatenates PEM fixtures into a single combined cert+key bundle.
    fn concat_pem(parts: &[&[u8]]) -> Vec<u8> {
        parts.concat()
    }

    #[test]
    fn pkcs1_wraps_to_pkcs8_with_rsa_oid() {
        let pkcs8 = normalize_key_to_pkcs8_der(RSA_PKCS1);
        let info = parse_pkcs8_der(&pkcs8);
        assert_eq!(info.algorithm.oid, RSA_ENCRYPTION_OID);
    }

    #[test]
    fn sec1_wraps_to_pkcs8_with_ec_oid_and_curve_param() {
        let pkcs8 = normalize_key_to_pkcs8_der(EC_SEC1);
        let info = parse_pkcs8_der(&pkcs8);
        assert_eq!(info.algorithm.oid, EC_PUBLIC_KEY_OID);
        assert!(
            info.algorithm.parameters.is_some(),
            "EC PKCS#8 must carry namedCurve OID in algorithm parameters"
        );
    }

    #[test]
    fn pkcs8_pass_through_preserves_algorithm() {
        let pkcs8 = normalize_key_to_pkcs8_der(EC_PKCS8);
        assert_eq!(parse_pkcs8_der(&pkcs8).algorithm.oid, EC_PUBLIC_KEY_OID);
    }

    #[test]
    fn cert_extracted_from_concatenated_bundle() {
        let bundle = concat_pem(&[EC_CERT, EC_SEC1]);
        let extracted = extract_cert_chain_der(&bundle);

        let expected: Vec<Vec<u8>> = pem::parse_many(EC_CERT)
            .unwrap()
            .into_iter()
            .map(pem::Pem::into_contents)
            .collect();
        assert_eq!(extracted, expected);
    }

    #[test]
    fn key_extracted_from_concatenated_bundle() {
        let bundle = concat_pem(&[RSA_CERT, RSA_PKCS1]);
        let pkcs8 = normalize_key_to_pkcs8_der(&bundle);
        assert_eq!(parse_pkcs8_der(&pkcs8).algorithm.oid, RSA_ENCRYPTION_OID);
    }
}