Skip to main content

trillium_quinn/
connection.rs

1use async_compat::Compat;
2use futures_lite::{AsyncRead, AsyncWrite};
3use quinn::VarInt;
4use std::{
5    fmt::{self, Debug, Formatter},
6    io,
7    net::SocketAddr,
8};
9use trillium_macros::{AsyncRead, AsyncWrite};
10use trillium_server_common::{
11    QuicConnectionTrait, QuicTransportBidi, QuicTransportReceive, QuicTransportSend, Transport,
12};
13
14/// A bidirectional QUIC stream, combining quinn's split send/recv
15/// into a single [`Transport`].
16#[derive(AsyncRead, AsyncWrite)]
17pub struct QuinnTransport {
18    #[async_read]
19    recv: Compat<quinn::RecvStream>,
20    #[async_write]
21    send: Compat<quinn::SendStream>,
22}
23
24impl QuinnTransport {
25    fn new(recv: quinn::RecvStream, send: quinn::SendStream) -> Self {
26        Self {
27            recv: Compat::new(recv),
28            send: Compat::new(send),
29        }
30    }
31}
32
33impl QuicTransportReceive for QuinnTransport {
34    fn stop(&mut self, code: u64) {
35        let error_code = VarInt::from_u64(code).unwrap_or_default();
36        let _ = self.recv.get_mut().stop(error_code);
37    }
38}
39
40impl QuicTransportSend for QuinnTransport {
41    fn reset(&mut self, code: u64) {
42        let error_code = VarInt::from_u64(code).unwrap_or_default();
43        let _ = self.send.get_mut().reset(error_code);
44    }
45}
46
47impl QuicTransportBidi for QuinnTransport {}
48
49// `negotiated_alpn` is left at the trait default (`None`). trillium-quinn is positioned as the
50// QUIC adapter for trillium-http's HTTP/3 support, where the ALPN value is always `h3`. Nothing
51// in the framework currently needs to read it back per stream, and h1-vs-h2 dispatch only ever
52// runs on a TCP listener.
53impl Transport for QuinnTransport {}
54
55/// A QUIC connection backed by quinn, implementing [`QuicConnectionTrait`].
56#[derive(Clone, Debug)]
57pub struct QuinnConnection(quinn::Connection);
58
59impl QuinnConnection {
60    pub(crate) fn new(connection: quinn::Connection) -> Self {
61        Self(connection)
62    }
63}
64
65#[derive(AsyncRead)]
66pub struct QuinnRecv(Compat<quinn::RecvStream>);
67impl Debug for QuinnRecv {
68    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69        f.debug_tuple("QuinnRecv").finish_non_exhaustive()
70    }
71}
72impl From<quinn::RecvStream> for QuinnRecv {
73    fn from(value: quinn::RecvStream) -> Self {
74        Self(Compat::new(value))
75    }
76}
77impl QuicTransportReceive for QuinnRecv {
78    fn stop(&mut self, code: u64) {
79        let error_code = VarInt::from_u64(code).unwrap_or_default();
80        let _ = self.0.get_mut().stop(error_code);
81    }
82}
83
84#[derive(AsyncWrite)]
85pub struct QuinnSend(Compat<quinn::SendStream>);
86
87impl Debug for QuinnSend {
88    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
89        f.debug_tuple("QuinnSend").finish_non_exhaustive()
90    }
91}
92impl From<quinn::SendStream> for QuinnSend {
93    fn from(value: quinn::SendStream) -> Self {
94        Self(Compat::new(value))
95    }
96}
97impl QuicTransportSend for QuinnSend {
98    fn reset(&mut self, code: u64) {
99        let error_code = VarInt::from_u64(code).unwrap_or_default();
100        let _ = self.0.get_mut().reset(error_code);
101    }
102}
103
104impl QuicConnectionTrait for QuinnConnection {
105    type BidiStream = QuinnTransport;
106    type RecvStream = QuinnRecv;
107    type SendStream = QuinnSend;
108
109    async fn accept_bidi(&self) -> io::Result<(u64, Self::BidiStream)> {
110        let (send, recv) = self.0.accept_bi().await.map_err(conn_err)?;
111        let stream_id = VarInt::from(recv.id()).into_inner();
112        Ok((stream_id, QuinnTransport::new(recv, send)))
113    }
114
115    async fn accept_uni(&self) -> io::Result<(u64, Self::RecvStream)> {
116        let recv = self.0.accept_uni().await.map_err(conn_err)?;
117        let stream_id = VarInt::from(recv.id()).into_inner();
118        Ok((stream_id, recv.into()))
119    }
120
121    async fn open_uni(&self) -> io::Result<(u64, Self::SendStream)> {
122        let send = self.0.open_uni().await.map_err(conn_err)?;
123        let stream_id = VarInt::from(send.id()).into_inner();
124        Ok((stream_id, send.into()))
125    }
126
127    async fn open_bidi(&self) -> io::Result<(u64, Self::BidiStream)> {
128        let (send, recv) = self.0.open_bi().await.map_err(conn_err)?;
129        let stream_id = VarInt::from(recv.id()).into_inner();
130        Ok((stream_id, QuinnTransport::new(recv, send)))
131    }
132
133    fn remote_address(&self) -> SocketAddr {
134        self.0.remote_address()
135    }
136
137    fn close(&self, error_code: u64, reason: &[u8]) {
138        self.0
139            .close(VarInt::from_u64(error_code).unwrap_or(VarInt::MAX), reason);
140    }
141
142    fn send_datagram(&self, data: &[u8]) -> io::Result<()> {
143        self.0
144            .send_datagram(data.to_vec().into())
145            .map_err(io::Error::other)
146    }
147
148    async fn recv_datagram<F: FnOnce(&[u8]) + Send>(&self, callback: F) -> io::Result<()> {
149        self.0
150            .read_datagram()
151            .await
152            .map(|d| callback(&d))
153            .map_err(conn_err)
154    }
155
156    fn max_datagram_size(&self) -> Option<usize> {
157        self.0.max_datagram_size()
158    }
159}
160
161fn conn_err(e: quinn::ConnectionError) -> io::Error {
162    io::Error::new(io::ErrorKind::ConnectionReset, e)
163}