Skip to main content

trillium_quinn/
connection.rs

1use async_compat::Compat;
2use futures_lite::{AsyncRead, AsyncWrite};
3use quinn::VarInt;
4use std::{
5    fmt::{self, Debug, Formatter},
6    io,
7    net::SocketAddr,
8};
9use trillium_macros::{AsyncRead, AsyncWrite};
10use trillium_server_common::{
11    QuicConnectionTrait, QuicTransportBidi, QuicTransportReceive, QuicTransportSend, Transport,
12};
13
14/// A bidirectional QUIC stream, combining quinn's split send/recv
15/// into a single [`Transport`].
16#[derive(AsyncRead, AsyncWrite)]
17pub struct QuinnTransport {
18    #[async_read]
19    recv: Compat<quinn::RecvStream>,
20    #[async_write]
21    send: Compat<quinn::SendStream>,
22}
23
24impl QuinnTransport {
25    fn new(recv: quinn::RecvStream, send: quinn::SendStream) -> Self {
26        Self {
27            recv: Compat::new(recv),
28            send: Compat::new(send),
29        }
30    }
31}
32
33impl QuicTransportReceive for QuinnTransport {
34    fn stop(&mut self, code: u64) {
35        let error_code = VarInt::from_u64(code).unwrap_or_default();
36        let _ = self.recv.get_mut().stop(error_code);
37    }
38}
39
40impl QuicTransportSend for QuinnTransport {
41    fn reset(&mut self, code: u64) {
42        let error_code = VarInt::from_u64(code).unwrap_or_default();
43        let _ = self.send.get_mut().reset(error_code);
44    }
45
46    fn set_priority(&mut self, priority: i32) {
47        // Errors only when the stream is already gone, in which case there's nothing to
48        // prioritize.
49        let _ = self.send.get_mut().set_priority(priority);
50    }
51}
52
53impl QuicTransportBidi for QuinnTransport {}
54
55// `negotiated_alpn` is left at the trait default (`None`). trillium-quinn is positioned as the
56// QUIC adapter for trillium-http's HTTP/3 support, where the ALPN value is always `h3`. Nothing
57// in the framework currently needs to read it back per stream, and h1-vs-h2 dispatch only ever
58// runs on a TCP listener.
59impl Transport for QuinnTransport {}
60
61/// A QUIC connection backed by quinn, implementing [`QuicConnectionTrait`].
62#[derive(Clone, Debug)]
63pub struct QuinnConnection(quinn::Connection);
64
65impl QuinnConnection {
66    pub(crate) fn new(connection: quinn::Connection) -> Self {
67        Self(connection)
68    }
69}
70
71#[derive(AsyncRead)]
72pub struct QuinnRecv(Compat<quinn::RecvStream>);
73impl Debug for QuinnRecv {
74    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
75        f.debug_tuple("QuinnRecv").finish_non_exhaustive()
76    }
77}
78impl From<quinn::RecvStream> for QuinnRecv {
79    fn from(value: quinn::RecvStream) -> Self {
80        Self(Compat::new(value))
81    }
82}
83impl QuicTransportReceive for QuinnRecv {
84    fn stop(&mut self, code: u64) {
85        let error_code = VarInt::from_u64(code).unwrap_or_default();
86        let _ = self.0.get_mut().stop(error_code);
87    }
88}
89
90#[derive(AsyncWrite)]
91pub struct QuinnSend(Compat<quinn::SendStream>);
92
93impl Debug for QuinnSend {
94    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
95        f.debug_tuple("QuinnSend").finish_non_exhaustive()
96    }
97}
98impl From<quinn::SendStream> for QuinnSend {
99    fn from(value: quinn::SendStream) -> Self {
100        Self(Compat::new(value))
101    }
102}
103impl QuicTransportSend for QuinnSend {
104    fn reset(&mut self, code: u64) {
105        let error_code = VarInt::from_u64(code).unwrap_or_default();
106        let _ = self.0.get_mut().reset(error_code);
107    }
108
109    fn set_priority(&mut self, priority: i32) {
110        let _ = self.0.get_mut().set_priority(priority);
111    }
112}
113
114impl QuicConnectionTrait for QuinnConnection {
115    type BidiStream = QuinnTransport;
116    type RecvStream = QuinnRecv;
117    type SendStream = QuinnSend;
118
119    async fn accept_bidi(&self) -> io::Result<(u64, Self::BidiStream)> {
120        let (send, recv) = self.0.accept_bi().await.map_err(conn_err)?;
121        let stream_id = VarInt::from(recv.id()).into_inner();
122        Ok((stream_id, QuinnTransport::new(recv, send)))
123    }
124
125    async fn accept_uni(&self) -> io::Result<(u64, Self::RecvStream)> {
126        let recv = self.0.accept_uni().await.map_err(conn_err)?;
127        let stream_id = VarInt::from(recv.id()).into_inner();
128        Ok((stream_id, recv.into()))
129    }
130
131    async fn open_uni(&self) -> io::Result<(u64, Self::SendStream)> {
132        let send = self.0.open_uni().await.map_err(conn_err)?;
133        let stream_id = VarInt::from(send.id()).into_inner();
134        Ok((stream_id, send.into()))
135    }
136
137    async fn open_bidi(&self) -> io::Result<(u64, Self::BidiStream)> {
138        let (send, recv) = self.0.open_bi().await.map_err(conn_err)?;
139        let stream_id = VarInt::from(recv.id()).into_inner();
140        Ok((stream_id, QuinnTransport::new(recv, send)))
141    }
142
143    fn remote_address(&self) -> SocketAddr {
144        self.0.remote_address()
145    }
146
147    fn close(&self, error_code: u64, reason: &[u8]) {
148        self.0
149            .close(VarInt::from_u64(error_code).unwrap_or(VarInt::MAX), reason);
150    }
151
152    fn send_datagram(&self, data: &[u8]) -> io::Result<()> {
153        self.0
154            .send_datagram(data.to_vec().into())
155            .map_err(io::Error::other)
156    }
157
158    async fn recv_datagram<F: FnOnce(&[u8]) + Send>(&self, callback: F) -> io::Result<()> {
159        self.0
160            .read_datagram()
161            .await
162            .map(|d| callback(&d))
163            .map_err(conn_err)
164    }
165
166    fn max_datagram_size(&self) -> Option<usize> {
167        self.0.max_datagram_size()
168    }
169}
170
171fn conn_err(e: quinn::ConnectionError) -> io::Error {
172    io::Error::new(io::ErrorKind::ConnectionReset, e)
173}