Skip to main content

trillium_native_tls/
server.rs

1use crate::Identity;
2use async_native_tls::{Error, TlsAcceptor, TlsStream};
3use pem::Pem;
4use pkcs8::{
5    AlgorithmIdentifierRef, ObjectIdentifier, PrivateKeyInfo,
6    der::{Decode, Encode, asn1::AnyRef},
7};
8use std::{
9    io::{self, IoSlice, IoSliceMut},
10    net::SocketAddr,
11    pin::Pin,
12    task::{Context, Poll},
13};
14use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
15
16/// trillium [`Acceptor`] for native-tls
17
18#[derive(Clone, Debug)]
19pub struct NativeTlsAcceptor(TlsAcceptor);
20
21impl NativeTlsAcceptor {
22    /// constructs a NativeTlsAcceptor from a [`native_tls::TlsAcceptor`],
23    /// an [`async_native_tls::TlsAcceptor`], or an [`Identity`]
24    pub fn new(t: impl Into<Self>) -> Self {
25        t.into()
26    }
27
28    /// Construct a `NativeTlsAcceptor` from a PEM-encoded certificate chain
29    /// and a PEM-encoded private key.
30    ///
31    /// This is the recommended entrypoint and matches the input format used by
32    /// `trillium-rustls` and `trillium-openssl`. The cert input may contain one
33    /// or more `CERTIFICATE` blocks (the leaf followed by any intermediates).
34    /// The key input is accepted in any of the three common PEM key forms:
35    ///
36    /// - `-----BEGIN PRIVATE KEY-----` (PKCS#8)
37    /// - `-----BEGIN RSA PRIVATE KEY-----` (PKCS#1)
38    /// - `-----BEGIN EC PRIVATE KEY-----` (SEC1)
39    ///
40    /// Either argument may also be a single concatenated bundle containing
41    /// both the cert chain and the key; the relevant blocks are extracted from
42    /// each input. Encrypted keys are not supported here — decrypt first or
43    /// use [`Self::from_pkcs12`].
44    ///
45    /// Internally we first try [`Identity::from_pkcs8`] with the normalized
46    /// PEM inputs; on backends that reject that import path (notably macOS
47    /// Secure Transport, which refuses EC keys this way with
48    /// `errSecUnknownFormat`), we fall back to packaging the cert chain and
49    /// key into an in-memory PKCS#12 archive and calling
50    /// [`Identity::from_pkcs12`]. The fallback only runs when the first
51    /// attempt fails, so OpenSSL-backed platforms never hit it.
52    ///
53    /// **Windows + EC keys:** SChannel rejects EC keys via both paths — its
54    /// PKCS#8 PEM import is strict, and our fallback archive omits the
55    /// `LocalKeyId` attribute SChannel uses to pair cert and key. For EC
56    /// keys on Windows, prefer `trillium-rustls`, or supply a pre-built
57    /// PKCS#12 archive (e.g. from `openssl pkcs12 -export`) via
58    /// [`Self::from_pkcs12`]. RSA keys work on Windows.
59    ///
60    /// # Example
61    ///
62    /// ```rust,no_run
63    /// use trillium_native_tls::NativeTlsAcceptor;
64    /// const CERT: &[u8] = include_bytes!("../tests/fixtures/rsa.crt");
65    /// const KEY: &[u8] = include_bytes!("../tests/fixtures/rsa-pkcs8.key");
66    /// let acceptor = NativeTlsAcceptor::from_cert_and_key(CERT, KEY);
67    /// ```
68    pub fn from_cert_and_key(cert: &[u8], key: &[u8]) -> Self {
69        let cert_chain_der = extract_cert_chain_der(cert);
70        let key_pkcs8_der = normalize_key_to_pkcs8_der(key);
71
72        let cert_chain_pem = encode_cert_chain_pem(&cert_chain_der);
73        let key_pkcs8_pem = encode_pkcs8_pem(&key_pkcs8_der);
74        let pkcs8_err = match Identity::from_pkcs8(&cert_chain_pem, &key_pkcs8_pem) {
75            Ok(identity) => return identity.into(),
76            Err(e) => e,
77        };
78
79        let p12_der = build_pkcs12_der(&cert_chain_der, &key_pkcs8_der);
80        match Identity::from_pkcs12(&p12_der, INTERNAL_P12_PASSWORD) {
81            Ok(identity) => identity.into(),
82            Err(p12_err) => panic!(
83                "could not build Identity from provided cert and key.\n  from_pkcs8 error: \
84                 {pkcs8_err}\n  from_pkcs12 fallback error: {p12_err}"
85            ),
86        }
87    }
88
89    /// Construct a `NativeTlsAcceptor` from a PKCS#12 archive and password.
90    ///
91    /// PKCS#12 (`.p12`/`.pfx`) bundles a certificate chain and a private key
92    /// in a single password-protected archive. Prefer
93    /// [`Self::from_cert_and_key`] when you have separate cert and key PEM
94    /// files, which is by far the more common provisioning format.
95    pub fn from_pkcs12(der: &[u8], password: &str) -> Self {
96        Identity::from_pkcs12(der, password)
97            .expect("could not build Identity from provided pkcs12 key and password")
98            .into()
99    }
100
101    /// Construct a `NativeTlsAcceptor` directly from PKCS#8 PEM cert and key
102    /// inputs, without normalization.
103    ///
104    /// Prefer [`Self::from_cert_and_key`], which accepts the same inputs plus
105    /// PKCS#1 and SEC1 keys. This constructor is retained for backwards
106    /// compatibility and forwards directly to [`Identity::from_pkcs8`].
107    pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Self {
108        Identity::from_pkcs8(pem, key)
109            .expect("could not build Identity from provided pem and key")
110            .into()
111    }
112}
113
// PEM labels (the text following `-----BEGIN `) recognized by
// `from_cert_and_key`.
// Unencrypted PKCS#8 `PrivateKeyInfo`.
const PEM_TAG_PKCS8: &str = "PRIVATE KEY";
// PKCS#1 `RSAPrivateKey`.
const PEM_TAG_PKCS1: &str = "RSA PRIVATE KEY";
// SEC1 `ECPrivateKey`.
const PEM_TAG_SEC1: &str = "EC PRIVATE KEY";
// X.509 certificate (leaf or intermediate).
const PEM_TAG_CERT: &str = "CERTIFICATE";
118
// `rsaEncryption` (RFC 8017, appendix A.1): the PKCS#8 algorithm identifier
// used when wrapping a PKCS#1 `RSAPrivateKey`.
const RSA_ENCRYPTION_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
// `id-ecPublicKey` (RFC 5480): the PKCS#8 algorithm identifier used when
// wrapping a SEC1 `ECPrivateKey`; the curve OID goes in the algorithm
// parameters.
const EC_PUBLIC_KEY_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");
121
// Fixed password for the PKCS#12 archive that `from_cert_and_key` builds in
// memory as its fallback path. The archive never leaves this process, so the
// value protects nothing; it exists only because the PKCS#12 import call
// takes a password.
const INTERNAL_P12_PASSWORD: &str = "trillium";
126
127fn parse_pem_blocks(input: &[u8]) -> Vec<Pem> {
128    pem::parse_many(input).expect("could not parse PEM input")
129}
130
131fn normalize_key_to_pkcs8_der(input: &[u8]) -> Vec<u8> {
132    let blocks = parse_pem_blocks(input);
133    let key = blocks
134        .iter()
135        .find(|b| matches!(b.tag(), PEM_TAG_PKCS8 | PEM_TAG_PKCS1 | PEM_TAG_SEC1))
136        .expect(
137            "no private key block found in key input (expected PRIVATE KEY, RSA PRIVATE KEY, or \
138             EC PRIVATE KEY)",
139        );
140
141    match key.tag() {
142        PEM_TAG_PKCS8 => key.contents().to_vec(),
143        PEM_TAG_PKCS1 => wrap_pkcs1_in_pkcs8(key.contents()),
144        PEM_TAG_SEC1 => wrap_sec1_in_pkcs8(key.contents()),
145        _ => unreachable!(),
146    }
147}
148
149fn wrap_pkcs1_in_pkcs8(pkcs1_der: &[u8]) -> Vec<u8> {
150    let algorithm = AlgorithmIdentifierRef {
151        oid: RSA_ENCRYPTION_OID,
152        parameters: Some(AnyRef::NULL),
153    };
154    PrivateKeyInfo::new(algorithm, pkcs1_der)
155        .to_der()
156        .expect("could not encode PKCS#1 key as PKCS#8")
157}
158
159fn wrap_sec1_in_pkcs8(sec1_der: &[u8]) -> Vec<u8> {
160    let parsed =
161        sec1::EcPrivateKey::from_der(sec1_der).expect("could not parse SEC1 EC private key");
162    let curve_oid = parsed
163        .parameters
164        .and_then(|p| p.named_curve())
165        .expect("EC private key is missing namedCurve parameters");
166    let curve_param: AnyRef<'_> = (&curve_oid).into();
167    let algorithm = AlgorithmIdentifierRef {
168        oid: EC_PUBLIC_KEY_OID,
169        parameters: Some(curve_param),
170    };
171    PrivateKeyInfo::new(algorithm, sec1_der)
172        .to_der()
173        .expect("could not encode SEC1 key as PKCS#8")
174}
175
176fn encode_cert_chain_pem(cert_chain_der: &[Vec<u8>]) -> Vec<u8> {
177    let blocks: Vec<Pem> = cert_chain_der
178        .iter()
179        .map(|d| Pem::new(PEM_TAG_CERT, d.clone()))
180        .collect();
181    pem::encode_many(&blocks).into_bytes()
182}
183
184fn encode_pkcs8_pem(key_pkcs8_der: &[u8]) -> Vec<u8> {
185    pem::encode(&Pem::new(PEM_TAG_PKCS8, key_pkcs8_der.to_vec())).into_bytes()
186}
187
188fn extract_cert_chain_der(input: &[u8]) -> Vec<Vec<u8>> {
189    let certs: Vec<Vec<u8>> = parse_pem_blocks(input)
190        .into_iter()
191        .filter(|b| b.tag() == PEM_TAG_CERT)
192        .map(|b| b.into_contents())
193        .collect();
194    assert!(
195        !certs.is_empty(),
196        "no CERTIFICATE blocks found in cert input"
197    );
198    certs
199}
200
201fn build_pkcs12_der(cert_chain_der: &[Vec<u8>], key_pkcs8_der: &[u8]) -> Vec<u8> {
202    let leaf = cert_chain_der.first().expect("cert chain was empty");
203    let intermediates: Vec<&[u8]> = cert_chain_der.iter().skip(1).map(Vec::as_slice).collect();
204    let pfx = p12::PFX::new_with_cas(
205        leaf,
206        key_pkcs8_der,
207        &intermediates,
208        INTERNAL_P12_PASSWORD,
209        "",
210    )
211    .expect("could not build PKCS#12 archive from cert and key");
212    pfx.to_der()
213}
214
215impl From<Identity> for NativeTlsAcceptor {
216    fn from(i: Identity) -> Self {
217        native_tls::TlsAcceptor::new(i).unwrap().into()
218    }
219}
220
221impl From<native_tls::TlsAcceptor> for NativeTlsAcceptor {
222    fn from(i: native_tls::TlsAcceptor) -> Self {
223        Self(i.into())
224    }
225}
226
227impl From<TlsAcceptor> for NativeTlsAcceptor {
228    fn from(i: TlsAcceptor) -> Self {
229        Self(i)
230    }
231}
232
233impl From<(&[u8], &str)> for NativeTlsAcceptor {
234    fn from(i: (&[u8], &str)) -> Self {
235        Self::from_pkcs12(i.0, i.1)
236    }
237}
238
239impl<Input> Acceptor<Input> for NativeTlsAcceptor
240where
241    Input: Transport,
242{
243    type Error = Error;
244    type Output = NativeTlsServerTransport<Input>;
245
246    async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
247        self.0.accept(input).await.map(NativeTlsServerTransport)
248    }
249}
250
251/// Server Tls Transport
252///
253/// A wrapper type around [`TlsStream`] that also implements [`Transport`]
254#[derive(Debug)]
255pub struct NativeTlsServerTransport<T>(TlsStream<T>);
256
257impl<T: AsyncWrite + AsyncRead + Unpin> AsRef<T> for NativeTlsServerTransport<T> {
258    fn as_ref(&self) -> &T {
259        self.0.get_ref()
260    }
261}
262impl<T: AsyncWrite + AsyncRead + Unpin> AsMut<T> for NativeTlsServerTransport<T> {
263    fn as_mut(&mut self) -> &mut T {
264        self.0.get_mut()
265    }
266}
267
268impl<T> AsRef<TlsStream<T>> for NativeTlsServerTransport<T> {
269    fn as_ref(&self) -> &TlsStream<T> {
270        &self.0
271    }
272}
273impl<T> AsMut<TlsStream<T>> for NativeTlsServerTransport<T> {
274    fn as_mut(&mut self) -> &mut TlsStream<T> {
275        &mut self.0
276    }
277}
278
279impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsServerTransport<T> {
280    fn poll_read(
281        mut self: Pin<&mut Self>,
282        cx: &mut Context<'_>,
283        buf: &mut [u8],
284    ) -> Poll<io::Result<usize>> {
285        Pin::new(&mut self.0).poll_read(cx, buf)
286    }
287
288    fn poll_read_vectored(
289        mut self: Pin<&mut Self>,
290        cx: &mut Context<'_>,
291        bufs: &mut [IoSliceMut<'_>],
292    ) -> Poll<io::Result<usize>> {
293        Pin::new(&mut self.0).poll_read_vectored(cx, bufs)
294    }
295}
296
297impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for NativeTlsServerTransport<T> {
298    fn poll_write(
299        mut self: Pin<&mut Self>,
300        cx: &mut Context<'_>,
301        buf: &[u8],
302    ) -> Poll<io::Result<usize>> {
303        Pin::new(&mut self.0).poll_write(cx, buf)
304    }
305
306    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
307        Pin::new(&mut self.0).poll_flush(cx)
308    }
309
310    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
311        Pin::new(&mut self.0).poll_close(cx)
312    }
313
314    fn poll_write_vectored(
315        mut self: Pin<&mut Self>,
316        cx: &mut Context<'_>,
317        bufs: &[IoSlice<'_>],
318    ) -> Poll<io::Result<usize>> {
319        Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
320    }
321}
322
323impl<T: Transport> Transport for NativeTlsServerTransport<T> {
324    fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
325        self.0.get_ref().peer_addr()
326    }
327
328    // `negotiated_alpn` is left at the trait default (`None`). Server-side ALPN advertisement in
329    // `native-tls` lives behind the `alpn-accept` cargo feature, which `async-native-tls` 0.6 does
330    // not enable, and the wrapper's `TlsStream` does not expose `negotiated_alpn` either — so
331    // `trillium-native-tls` cannot perform ALPN-based h2 dispatch today. Revisit once the upstream
332    // wrapper grows the missing surface.
333}
334
#[cfg(test)]
mod tests {
    use super::{
        EC_PUBLIC_KEY_OID, PEM_TAG_PKCS1, PEM_TAG_PKCS8, PEM_TAG_SEC1, RSA_ENCRYPTION_OID,
        extract_cert_chain_der, normalize_key_to_pkcs8_der,
    };
    use pkcs8::{
        PrivateKeyInfo,
        der::{Decode, asn1::AnyRef},
    };

    const RSA_CERT: &[u8] = include_bytes!("../tests/fixtures/rsa.crt");
    const RSA_PKCS1: &[u8] = include_bytes!("../tests/fixtures/rsa-pkcs1.key");
    const EC_CERT: &[u8] = include_bytes!("../tests/fixtures/ec.crt");
    const EC_SEC1: &[u8] = include_bytes!("../tests/fixtures/ec-sec1.key");
    const EC_PKCS8: &[u8] = include_bytes!("../tests/fixtures/ec-pkcs8.key");

    fn parse_pkcs8_der(der: &[u8]) -> PrivateKeyInfo<'_> {
        PrivateKeyInfo::try_from(der).expect("output not parseable as PKCS#8")
    }

    /// Returns the DER contents of the first PEM block in `input` labelled
    /// `tag`. Other blocks in the fixture are skipped.
    fn pem_block_contents(input: &[u8], tag: &str) -> Vec<u8> {
        pem::parse_many(input)
            .expect("fixture is not valid PEM")
            .into_iter()
            .find(|block| block.tag() == tag)
            .unwrap_or_else(|| panic!("fixture has no {tag} block"))
            .into_contents()
    }

    #[test]
    fn pkcs1_wraps_to_pkcs8_with_rsa_oid() {
        let der = normalize_key_to_pkcs8_der(RSA_PKCS1);
        let pki = parse_pkcs8_der(&der);
        assert_eq!(pki.algorithm.oid, RSA_ENCRYPTION_OID);
        // rsaEncryption requires explicit NULL parameters.
        assert_eq!(pki.algorithm.parameters, Some(AnyRef::NULL));
        // The PKCS#1 key must be carried through byte-for-byte.
        assert_eq!(
            pki.private_key,
            pem_block_contents(RSA_PKCS1, PEM_TAG_PKCS1).as_slice()
        );
    }

    #[test]
    fn sec1_wraps_to_pkcs8_with_ec_oid_and_curve_param() {
        let sec1_der = pem_block_contents(EC_SEC1, PEM_TAG_SEC1);
        let curve_oid = sec1::EcPrivateKey::from_der(&sec1_der)
            .expect("fixture is not a valid SEC1 key")
            .parameters
            .and_then(|params| params.named_curve())
            .expect("fixture has no namedCurve parameter");

        let der = normalize_key_to_pkcs8_der(EC_SEC1);
        let pki = parse_pkcs8_der(&der);
        assert_eq!(pki.algorithm.oid, EC_PUBLIC_KEY_OID);
        assert_eq!(
            pki.algorithm.parameters,
            Some(AnyRef::from(&curve_oid)),
            "EC PKCS#8 must carry the key's namedCurve OID in algorithm parameters"
        );
        // The SEC1 key must be carried through byte-for-byte.
        assert_eq!(pki.private_key, sec1_der.as_slice());
    }

    #[test]
    fn pkcs8_pass_through_preserves_algorithm() {
        let der = normalize_key_to_pkcs8_der(EC_PKCS8);
        assert_eq!(parse_pkcs8_der(&der).algorithm.oid, EC_PUBLIC_KEY_OID);
        // PKCS#8 input is returned unchanged, not re-encoded.
        assert_eq!(der, pem_block_contents(EC_PKCS8, PEM_TAG_PKCS8));
    }

    #[test]
    fn cert_extracted_from_concatenated_bundle() {
        let mut bundle = Vec::new();
        bundle.extend_from_slice(EC_CERT);
        bundle.extend_from_slice(EC_SEC1);

        let extracted = extract_cert_chain_der(&bundle);
        let original: Vec<Vec<u8>> = pem::parse_many(EC_CERT)
            .unwrap()
            .into_iter()
            .map(pem::Pem::into_contents)
            .collect();
        assert_eq!(extracted, original);
    }

    #[test]
    fn key_extracted_from_concatenated_bundle() {
        let mut bundle = Vec::new();
        bundle.extend_from_slice(RSA_CERT);
        bundle.extend_from_slice(RSA_PKCS1);
        let der = normalize_key_to_pkcs8_der(&bundle);
        let pki = parse_pkcs8_der(&der);
        assert_eq!(pki.algorithm.oid, RSA_ENCRYPTION_OID);
        assert_eq!(
            pki.private_key,
            pem_block_contents(RSA_PKCS1, PEM_TAG_PKCS1).as_slice()
        );
    }

    #[test]
    #[should_panic(expected = "no private key block found")]
    fn key_input_without_key_block_panics() {
        normalize_key_to_pkcs8_der(RSA_CERT);
    }

    #[test]
    #[should_panic(expected = "no CERTIFICATE blocks found")]
    fn cert_input_without_certificate_panics() {
        extract_cert_chain_der(RSA_PKCS1);
    }
}