Skip to main content

trillium_native_tls/
server.rs

1use crate::Identity;
2use async_native_tls::{Error, TlsAcceptor, TlsStream};
3use pem::Pem;
4use pkcs8::{
5    AlgorithmIdentifierRef, ObjectIdentifier, PrivateKeyInfoRef,
6    der::{
7        Decode, Encode,
8        asn1::{AnyRef, OctetStringRef},
9    },
10};
11use std::{
12    io::{self, IoSlice, IoSliceMut},
13    net::SocketAddr,
14    pin::Pin,
15    task::{Context, Poll},
16};
17use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
18
19/// trillium [`Acceptor`] for native-tls
20
21#[derive(Clone, Debug)]
22pub struct NativeTlsAcceptor(TlsAcceptor);
23
24impl NativeTlsAcceptor {
25    /// constructs a NativeTlsAcceptor from a [`native_tls::TlsAcceptor`],
26    /// an [`async_native_tls::TlsAcceptor`], or an [`Identity`]
27    pub fn new(t: impl Into<Self>) -> Self {
28        t.into()
29    }
30
31    /// Construct a `NativeTlsAcceptor` from a PEM-encoded certificate chain
32    /// and a PEM-encoded private key.
33    ///
34    /// This is the recommended entrypoint and matches the input format used by
35    /// `trillium-rustls` and `trillium-openssl`. The cert input may contain one
36    /// or more `CERTIFICATE` blocks (the leaf followed by any intermediates).
37    /// The key input is accepted in any of the three common PEM key forms:
38    ///
39    /// - `-----BEGIN PRIVATE KEY-----` (PKCS#8)
40    /// - `-----BEGIN RSA PRIVATE KEY-----` (PKCS#1)
41    /// - `-----BEGIN EC PRIVATE KEY-----` (SEC1)
42    ///
43    /// Either argument may also be a single concatenated bundle containing
44    /// both the cert chain and the key; the relevant blocks are extracted from
45    /// each input. Encrypted keys are not supported here — decrypt first or
46    /// use [`Self::from_pkcs12`].
47    ///
48    /// Internally we first try [`Identity::from_pkcs8`] with the normalized
49    /// PEM inputs; on backends that reject that import path (notably macOS
50    /// Secure Transport, which refuses EC keys this way with
51    /// `errSecUnknownFormat`), we fall back to packaging the cert chain and
52    /// key into an in-memory PKCS#12 archive and calling
53    /// [`Identity::from_pkcs12`]. The fallback only runs when the first
54    /// attempt fails, so OpenSSL-backed platforms never hit it.
55    ///
56    /// **Windows + EC keys:** SChannel rejects EC keys via both paths — its
57    /// PKCS#8 PEM import is strict, and our fallback archive omits the
58    /// `LocalKeyId` attribute SChannel uses to pair cert and key. For EC
59    /// keys on Windows, prefer `trillium-rustls`, or supply a pre-built
60    /// PKCS#12 archive (e.g. from `openssl pkcs12 -export`) via
61    /// [`Self::from_pkcs12`]. RSA keys work on Windows.
62    ///
63    /// # Example
64    ///
65    /// ```rust,no_run
66    /// use trillium_native_tls::NativeTlsAcceptor;
67    /// const CERT: &[u8] = include_bytes!("../tests/fixtures/rsa.crt");
68    /// const KEY: &[u8] = include_bytes!("../tests/fixtures/rsa-pkcs8.key");
69    /// let acceptor = NativeTlsAcceptor::from_cert_and_key(CERT, KEY);
70    /// ```
71    pub fn from_cert_and_key(cert: &[u8], key: &[u8]) -> Self {
72        let cert_chain_der = extract_cert_chain_der(cert);
73        let key_pkcs8_der = normalize_key_to_pkcs8_der(key);
74
75        let cert_chain_pem = encode_cert_chain_pem(&cert_chain_der);
76        let key_pkcs8_pem = encode_pkcs8_pem(&key_pkcs8_der);
77        let pkcs8_err = match Identity::from_pkcs8(&cert_chain_pem, &key_pkcs8_pem) {
78            Ok(identity) => return identity.into(),
79            Err(e) => e,
80        };
81
82        let p12_der = build_pkcs12_der(&cert_chain_der, &key_pkcs8_der);
83        match Identity::from_pkcs12(&p12_der, INTERNAL_P12_PASSWORD) {
84            Ok(identity) => identity.into(),
85            Err(p12_err) => panic!(
86                "could not build Identity from provided cert and key.\n  from_pkcs8 error: \
87                 {pkcs8_err}\n  from_pkcs12 fallback error: {p12_err}"
88            ),
89        }
90    }
91
92    /// Construct a `NativeTlsAcceptor` from a PKCS#12 archive and password.
93    ///
94    /// PKCS#12 (`.p12`/`.pfx`) bundles a certificate chain and a private key
95    /// in a single password-protected archive. Prefer
96    /// [`Self::from_cert_and_key`] when you have separate cert and key PEM
97    /// files, which is by far the more common provisioning format.
98    pub fn from_pkcs12(der: &[u8], password: &str) -> Self {
99        Identity::from_pkcs12(der, password)
100            .expect("could not build Identity from provided pkcs12 key and password")
101            .into()
102    }
103
104    /// Construct a `NativeTlsAcceptor` directly from PKCS#8 PEM cert and key
105    /// inputs, without normalization.
106    ///
107    /// Prefer [`Self::from_cert_and_key`], which accepts the same inputs plus
108    /// PKCS#1 and SEC1 keys. This constructor is retained for backwards
109    /// compatibility and forwards directly to [`Identity::from_pkcs8`].
110    pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Self {
111        Identity::from_pkcs8(pem, key)
112            .expect("could not build Identity from provided pem and key")
113            .into()
114    }
115}
116
/// PEM label of an X.509 certificate block.
const PEM_TAG_CERT: &str = "CERTIFICATE";
/// PEM label of an unencrypted PKCS#8 `PrivateKeyInfo` block.
const PEM_TAG_PKCS8: &str = "PRIVATE KEY";
/// PEM label of a PKCS#1 `RSAPrivateKey` block.
const PEM_TAG_PKCS1: &str = "RSA PRIVATE KEY";
/// PEM label of a SEC1 `ECPrivateKey` block.
const PEM_TAG_SEC1: &str = "EC PRIVATE KEY";
121
122const RSA_ENCRYPTION_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
123const EC_PUBLIC_KEY_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");
124
/// Password for the throwaway PKCS#12 archive that `from_cert_and_key`
/// builds in memory as its fallback import path.
///
/// The archive never leaves this process, so this is intentionally a fixed,
/// well-known value: it only has to match between the builder and the
/// importer and provides no secrecy.
const INTERNAL_P12_PASSWORD: &str = "trillium";
129
130fn parse_pem_blocks(input: &[u8]) -> Vec<Pem> {
131    pem::parse_many(input).expect("could not parse PEM input")
132}
133
134fn normalize_key_to_pkcs8_der(input: &[u8]) -> Vec<u8> {
135    let blocks = parse_pem_blocks(input);
136    let key = blocks
137        .iter()
138        .find(|b| matches!(b.tag(), PEM_TAG_PKCS8 | PEM_TAG_PKCS1 | PEM_TAG_SEC1))
139        .expect(
140            "no private key block found in key input (expected PRIVATE KEY, RSA PRIVATE KEY, or \
141             EC PRIVATE KEY)",
142        );
143
144    match key.tag() {
145        PEM_TAG_PKCS8 => key.contents().to_vec(),
146        PEM_TAG_PKCS1 => wrap_pkcs1_in_pkcs8(key.contents()),
147        PEM_TAG_SEC1 => wrap_sec1_in_pkcs8(key.contents()),
148        _ => unreachable!(),
149    }
150}
151
152fn wrap_pkcs1_in_pkcs8(pkcs1_der: &[u8]) -> Vec<u8> {
153    let algorithm = AlgorithmIdentifierRef {
154        oid: RSA_ENCRYPTION_OID,
155        parameters: Some(AnyRef::NULL),
156    };
157    let private_key =
158        OctetStringRef::new(pkcs1_der).expect("could not wrap PKCS#1 key as OCTET STRING");
159    PrivateKeyInfoRef::new(algorithm, private_key)
160        .to_der()
161        .expect("could not encode PKCS#1 key as PKCS#8")
162}
163
164fn wrap_sec1_in_pkcs8(sec1_der: &[u8]) -> Vec<u8> {
165    let parsed =
166        sec1::EcPrivateKey::from_der(sec1_der).expect("could not parse SEC1 EC private key");
167    let curve_oid = parsed
168        .parameters
169        .and_then(|p| p.named_curve())
170        .expect("EC private key is missing namedCurve parameters");
171    let curve_param: AnyRef<'_> = (&curve_oid).into();
172    let algorithm = AlgorithmIdentifierRef {
173        oid: EC_PUBLIC_KEY_OID,
174        parameters: Some(curve_param),
175    };
176    let private_key =
177        OctetStringRef::new(sec1_der).expect("could not wrap SEC1 key as OCTET STRING");
178    PrivateKeyInfoRef::new(algorithm, private_key)
179        .to_der()
180        .expect("could not encode SEC1 key as PKCS#8")
181}
182
183fn encode_cert_chain_pem(cert_chain_der: &[Vec<u8>]) -> Vec<u8> {
184    let blocks: Vec<Pem> = cert_chain_der
185        .iter()
186        .map(|d| Pem::new(PEM_TAG_CERT, d.clone()))
187        .collect();
188    pem::encode_many(&blocks).into_bytes()
189}
190
191fn encode_pkcs8_pem(key_pkcs8_der: &[u8]) -> Vec<u8> {
192    pem::encode(&Pem::new(PEM_TAG_PKCS8, key_pkcs8_der.to_vec())).into_bytes()
193}
194
195fn extract_cert_chain_der(input: &[u8]) -> Vec<Vec<u8>> {
196    let certs: Vec<Vec<u8>> = parse_pem_blocks(input)
197        .into_iter()
198        .filter(|b| b.tag() == PEM_TAG_CERT)
199        .map(|b| b.into_contents())
200        .collect();
201    assert!(
202        !certs.is_empty(),
203        "no CERTIFICATE blocks found in cert input"
204    );
205    certs
206}
207
208fn build_pkcs12_der(cert_chain_der: &[Vec<u8>], key_pkcs8_der: &[u8]) -> Vec<u8> {
209    let leaf = cert_chain_der.first().expect("cert chain was empty");
210    let intermediates: Vec<&[u8]> = cert_chain_der.iter().skip(1).map(Vec::as_slice).collect();
211    let pfx = p12::PFX::new_with_cas(
212        leaf,
213        key_pkcs8_der,
214        &intermediates,
215        INTERNAL_P12_PASSWORD,
216        "",
217    )
218    .expect("could not build PKCS#12 archive from cert and key");
219    pfx.to_der()
220}
221
222impl From<Identity> for NativeTlsAcceptor {
223    fn from(i: Identity) -> Self {
224        native_tls::TlsAcceptor::new(i).unwrap().into()
225    }
226}
227
228impl From<native_tls::TlsAcceptor> for NativeTlsAcceptor {
229    fn from(i: native_tls::TlsAcceptor) -> Self {
230        Self(i.into())
231    }
232}
233
234impl From<TlsAcceptor> for NativeTlsAcceptor {
235    fn from(i: TlsAcceptor) -> Self {
236        Self(i)
237    }
238}
239
240impl From<(&[u8], &str)> for NativeTlsAcceptor {
241    fn from(i: (&[u8], &str)) -> Self {
242        Self::from_pkcs12(i.0, i.1)
243    }
244}
245
246impl<Input> Acceptor<Input> for NativeTlsAcceptor
247where
248    Input: Transport,
249{
250    type Error = Error;
251    type Output = NativeTlsServerTransport<Input>;
252
253    async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
254        self.0.accept(input).await.map(NativeTlsServerTransport)
255    }
256}
257
258/// Server Tls Transport
259///
260/// A wrapper type around [`TlsStream`] that also implements [`Transport`]
261#[derive(Debug)]
262pub struct NativeTlsServerTransport<T>(TlsStream<T>);
263
264impl<T: AsyncWrite + AsyncRead + Unpin> AsRef<T> for NativeTlsServerTransport<T> {
265    fn as_ref(&self) -> &T {
266        self.0.get_ref()
267    }
268}
269impl<T: AsyncWrite + AsyncRead + Unpin> AsMut<T> for NativeTlsServerTransport<T> {
270    fn as_mut(&mut self) -> &mut T {
271        self.0.get_mut()
272    }
273}
274
275impl<T> AsRef<TlsStream<T>> for NativeTlsServerTransport<T> {
276    fn as_ref(&self) -> &TlsStream<T> {
277        &self.0
278    }
279}
280impl<T> AsMut<TlsStream<T>> for NativeTlsServerTransport<T> {
281    fn as_mut(&mut self) -> &mut TlsStream<T> {
282        &mut self.0
283    }
284}
285
286impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsServerTransport<T> {
287    fn poll_read(
288        mut self: Pin<&mut Self>,
289        cx: &mut Context<'_>,
290        buf: &mut [u8],
291    ) -> Poll<io::Result<usize>> {
292        Pin::new(&mut self.0).poll_read(cx, buf)
293    }
294
295    fn poll_read_vectored(
296        mut self: Pin<&mut Self>,
297        cx: &mut Context<'_>,
298        bufs: &mut [IoSliceMut<'_>],
299    ) -> Poll<io::Result<usize>> {
300        Pin::new(&mut self.0).poll_read_vectored(cx, bufs)
301    }
302}
303
304impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for NativeTlsServerTransport<T> {
305    fn poll_write(
306        mut self: Pin<&mut Self>,
307        cx: &mut Context<'_>,
308        buf: &[u8],
309    ) -> Poll<io::Result<usize>> {
310        Pin::new(&mut self.0).poll_write(cx, buf)
311    }
312
313    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
314        Pin::new(&mut self.0).poll_flush(cx)
315    }
316
317    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
318        Pin::new(&mut self.0).poll_close(cx)
319    }
320
321    fn poll_write_vectored(
322        mut self: Pin<&mut Self>,
323        cx: &mut Context<'_>,
324        bufs: &[IoSlice<'_>],
325    ) -> Poll<io::Result<usize>> {
326        Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
327    }
328}
329
330impl<T: Transport> Transport for NativeTlsServerTransport<T> {
331    fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
332        self.0.get_ref().peer_addr()
333    }
334
335    // `negotiated_alpn` is left at the trait default (`None`). Server-side ALPN advertisement in
336    // `native-tls` lives behind the `alpn-accept` cargo feature, which `async-native-tls` 0.6 does
337    // not enable, and the wrapper's `TlsStream` does not expose `negotiated_alpn` either — so
338    // `trillium-native-tls` cannot perform ALPN-based h2 dispatch today. Revisit once the upstream
339    // wrapper grows the missing surface.
340}
341
#[cfg(test)]
mod tests {
    use super::{
        EC_PUBLIC_KEY_OID, RSA_ENCRYPTION_OID, extract_cert_chain_der, normalize_key_to_pkcs8_der,
    };
    use pkcs8::PrivateKeyInfoRef;

    const RSA_CERT: &[u8] = include_bytes!("../tests/fixtures/rsa.crt");
    const RSA_PKCS1: &[u8] = include_bytes!("../tests/fixtures/rsa-pkcs1.key");
    const RSA_PKCS8: &[u8] = include_bytes!("../tests/fixtures/rsa-pkcs8.key");
    const EC_CERT: &[u8] = include_bytes!("../tests/fixtures/ec.crt");
    const EC_SEC1: &[u8] = include_bytes!("../tests/fixtures/ec-sec1.key");
    const EC_PKCS8: &[u8] = include_bytes!("../tests/fixtures/ec-pkcs8.key");

    /// Parse `der` as a PKCS#8 `PrivateKeyInfo`, failing the test if it is malformed.
    fn parse_pkcs8_der(der: &[u8]) -> PrivateKeyInfoRef<'_> {
        PrivateKeyInfoRef::try_from(der).expect("output not parseable as PKCS#8")
    }

    /// DER contents of the first `PRIVATE KEY` block in a fixture.
    fn pkcs8_block_contents(input: &[u8]) -> Vec<u8> {
        pem::parse_many(input)
            .expect("fixture is not valid PEM")
            .into_iter()
            .find(|block| block.tag() == "PRIVATE KEY")
            .expect("fixture has no PRIVATE KEY block")
            .into_contents()
    }

    #[test]
    fn pkcs1_wraps_to_pkcs8_with_rsa_oid() {
        let der = normalize_key_to_pkcs8_der(RSA_PKCS1);
        assert_eq!(parse_pkcs8_der(&der).algorithm.oid, RSA_ENCRYPTION_OID);
    }

    #[test]
    fn sec1_wraps_to_pkcs8_with_ec_oid_and_curve_param() {
        let der = normalize_key_to_pkcs8_der(EC_SEC1);
        let pki = parse_pkcs8_der(&der);
        assert_eq!(pki.algorithm.oid, EC_PUBLIC_KEY_OID);
        assert!(
            pki.algorithm.parameters.is_some(),
            "EC PKCS#8 must carry namedCurve OID in algorithm parameters"
        );
    }

    #[test]
    fn pkcs8_pass_through_preserves_algorithm() {
        let der = normalize_key_to_pkcs8_der(EC_PKCS8);
        assert_eq!(parse_pkcs8_der(&der).algorithm.oid, EC_PUBLIC_KEY_OID);
    }

    #[test]
    fn pkcs8_pass_through_is_byte_identical() {
        assert_eq!(
            normalize_key_to_pkcs8_der(EC_PKCS8),
            pkcs8_block_contents(EC_PKCS8)
        );
        assert_eq!(
            normalize_key_to_pkcs8_der(RSA_PKCS8),
            pkcs8_block_contents(RSA_PKCS8)
        );
    }

    #[test]
    fn rsa_pkcs8_pass_through_keeps_rsa_oid() {
        let der = normalize_key_to_pkcs8_der(RSA_PKCS8);
        assert_eq!(parse_pkcs8_der(&der).algorithm.oid, RSA_ENCRYPTION_OID);
    }

    #[test]
    fn cert_extracted_from_concatenated_bundle() {
        let mut bundle = Vec::new();
        bundle.extend_from_slice(EC_CERT);
        bundle.extend_from_slice(EC_SEC1);

        let extracted = extract_cert_chain_der(&bundle);
        let original: Vec<Vec<u8>> = pem::parse_many(EC_CERT)
            .unwrap()
            .into_iter()
            .map(pem::Pem::into_contents)
            .collect();
        assert_eq!(extracted, original);
    }

    #[test]
    fn key_extracted_from_concatenated_bundle() {
        let mut bundle = Vec::new();
        bundle.extend_from_slice(RSA_CERT);
        bundle.extend_from_slice(RSA_PKCS1);
        let der = normalize_key_to_pkcs8_der(&bundle);
        assert_eq!(parse_pkcs8_der(&der).algorithm.oid, RSA_ENCRYPTION_OID);
    }

    #[test]
    #[should_panic(expected = "no private key block found")]
    fn key_input_without_key_block_panics() {
        let _ = normalize_key_to_pkcs8_der(RSA_CERT);
    }

    #[test]
    #[should_panic(expected = "no CERTIFICATE blocks found")]
    fn cert_input_without_cert_block_panics() {
        let _ = extract_cert_chain_der(RSA_PKCS1);
    }
}