Skip to main content

trillium_websockets/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11
12//! # A websocket trillium handler
13//!
//! There are three primary ways to use this crate:
15//!
16//! ## With an async function that receives a [`WebSocketConn`]
17//!
18//! This is the simplest way to use trillium websockets, but does not
19//! provide any of the affordances that implementing the
20//! [`WebSocketHandler`] trait does. It is best for very simple websockets
21//! or for usages that require moving the WebSocketConn elsewhere in an
22//! application. The WebSocketConn is fully owned at this point, and will
23//! disconnect when dropped, not when the async function passed to
24//! `websocket` completes.
25//!
26//! ```
27//! use futures_lite::stream::StreamExt;
28//! use trillium_websockets::{Message, WebSocketConn, websocket};
29//!
30//! let handler = websocket(|mut conn: WebSocketConn| async move {
31//!     while let Some(Ok(Message::Text(input))) = conn.next().await {
32//!         conn.send_string(format!("received your message: {}", &input))
33//!             .await;
34//!     }
35//! });
36//! # // tests at tests/tests.rs for example simplicity
37//! ```
38//!
39//!
40//! ## Implementing [`WebSocketHandler`]
41//!
42//! [`WebSocketHandler`] provides support for sending outbound messages as a
43//! stream, and simplifies common patterns like executing async code on
44//! received messages.
45//!
46//! ## Using [`JsonWebSocketHandler`]
47//!
48//! [`JsonWebSocketHandler`] provides a thin serialization and
49//! deserialization layer on top of [`WebSocketHandler`] for this common
50//! use case.  See the [`JsonWebSocketHandler`] documentation for example
51//! usage. In order to use this trait, the `json` cargo feature must be
52//! enabled.
53
// README.md is attached as the documentation of this empty module so that rustdoc collects
// its code blocks as doctests. The gate must be `cfg(doctest)`: rustdoc sets that cfg (and
// not `cfg(test)`) while collecting doctests, so under `cfg(test)` this module -- and with it
// every README example -- would be invisible to `cargo test --doc`.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
57
58mod bidirectional_stream;
59mod websocket_connection;
60mod websocket_handler;
61
62pub use async_tungstenite::{
63    self,
64    tungstenite::{
65        self, Message,
66        protocol::{Role, WebSocketConfig},
67    },
68};
69use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
70use bidirectional_stream::{BidirectionalStream, Direction};
71use futures_lite::stream::StreamExt;
72use sha1::{Digest, Sha1};
73use std::{
74    net::IpAddr,
75    ops::{Deref, DerefMut},
76};
77use trillium::{
78    Conn, Handler, Info,
79    KnownHeaderName::{
80        Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketProtocol, SecWebsocketVersion,
81        Upgrade as UpgradeHeader,
82    },
83    Method, Status, Upgrade, Version,
84};
85pub use websocket_connection::WebSocketConn;
86pub use websocket_handler::WebSocketHandler;
87
/// Fixed GUID from RFC 6455 that is appended to the client's `Sec-WebSocket-Key` before
/// hashing; see `websocket_accept_hash`, which is its only user in this file.
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
89
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
/// An Error type that represents all exceptional conditions that can be encountered in the
/// operation of this crate.
///
/// Every variant displays transparently as its underlying error, and `From` conversions let
/// `?` lift those underlying errors into this type.
pub enum Error {
    #[error(transparent)]
    /// an error in the underlying websocket implementation
    WebSocket(#[from] tungstenite::Error),

    #[cfg(feature = "json")]
    #[error(transparent)]
    /// an error in json serialization or deserialization (only with the `json` feature)
    Json(#[from] serde_json::Error),
}
104
105/// a Result type for this crate
106pub type Result<T = Message> = std::result::Result<T, Error>;
107
108#[cfg(feature = "json")]
109mod json;
110
111#[cfg(feature = "json")]
112pub use json::{JsonHandler, JsonWebSocketHandler, json_websocket};
113
/// The trillium handler.
/// See crate-level docs for example usage.
///
/// Wraps a [`WebSocketHandler`] together with the handshake options set through
/// [`WebSocket::with_protocols`], [`WebSocket::with_protocol_config`] and
/// [`WebSocket::required`]. It derefs to the wrapped handler.
#[derive(Debug)]
pub struct WebSocket<H> {
    // the wrapped handler; exposed through Deref/DerefMut and driven by `Handler::upgrade`
    handler: H,
    // subprotocols this server accepts in `Sec-WebSocket-Protocol` negotiation
    protocols: Vec<String>,
    // optional tungstenite config passed to `WebSocketConn::new` on every upgrade
    config: Option<WebSocketConfig>,
    // when true, requests that cannot become websockets are halted with 426 instead of
    // being returned unhalted
    required: bool,
}
123
124impl<H> Deref for WebSocket<H> {
125    type Target = H;
126
127    fn deref(&self) -> &Self::Target {
128        &self.handler
129    }
130}
131
132impl<H> DerefMut for WebSocket<H> {
133    fn deref_mut(&mut self) -> &mut Self::Target {
134        &mut self.handler
135    }
136}
137
138/// Builds a new trillium handler from the provided
139/// WebSocketHandler. Alias for [`WebSocket::new`]
140pub fn websocket<H>(websocket_handler: H) -> WebSocket<H>
141where
142    H: WebSocketHandler,
143{
144    WebSocket::new(websocket_handler)
145}
146
147impl<H> WebSocket<H>
148where
149    H: WebSocketHandler,
150{
151    async fn run_h1(&self, mut conn: Conn) -> Conn {
152        if !upgrade_requested(&conn) {
153            if self.required {
154                return conn.with_status(Status::UpgradeRequired).halt();
155            } else {
156                return conn;
157            }
158        }
159
160        let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
161
162        let Some(sec_websocket_key) = conn.request_headers().get_str(SecWebsocketKey) else {
163            return conn.with_status(Status::BadRequest).halt();
164        };
165        let sec_websocket_accept = websocket_accept_hash(sec_websocket_key);
166
167        let protocol = websocket_protocol(&conn, &self.protocols);
168
169        let headers = conn.response_headers_mut();
170
171        headers.extend([
172            (UpgradeHeader, "websocket"),
173            (Connection, "Upgrade"),
174            (SecWebsocketVersion, "13"),
175        ]);
176
177        headers.insert(SecWebsocketAccept, sec_websocket_accept);
178
179        if let Some(protocol) = protocol {
180            headers.insert(SecWebsocketProtocol, protocol);
181        }
182
183        conn.halt()
184            .with_state(websocket_peer_ip)
185            .with_state(IsWebsocket)
186            .with_status(Status::SwitchingProtocols)
187    }
188
189    /// Build a new WebSocket with an async handler function that
190    /// receives a [`WebSocketConn`]
191    pub fn new(handler: H) -> Self {
192        Self {
193            handler,
194            protocols: Default::default(),
195            config: None,
196            required: false,
197        }
198    }
199
200    /// `protocols` is a sequence of known protocols. On successful handshake,
201    /// the returned response headers contain the first protocol in this list
202    /// which the server also knows.
203    pub fn with_protocols(self, protocols: &[&str]) -> Self {
204        Self {
205            protocols: protocols.iter().map(ToString::to_string).collect(),
206            ..self
207        }
208    }
209
210    /// configure the websocket protocol
211    pub fn with_protocol_config(self, config: WebSocketConfig) -> Self {
212        Self {
213            config: Some(config),
214            ..self
215        }
216    }
217
218    /// configure this handler to halt and send back a [`426 Upgrade
219    /// Required`][Status::UpgradeRequired] if a websocket cannot be negotiated
220    pub fn required(mut self) -> Self {
221        self.required = true;
222        self
223    }
224}
225
// Marker put into conn state when a websocket handshake is accepted (on both the h1 path and
// the extended-CONNECT path); `has_upgrade` claims the upgrade only when it is present.
struct IsWebsocket;
227
228#[cfg(test)]
229mod tests;
230
// Peer ip captured from the `Conn` when a handshake is accepted. It is carried through conn
// state into `Handler::upgrade`, which takes it out and hands it to
// `WebSocketConn::set_peer_ip`.
//
// this is a workaround for the fact that Upgrade is a public struct,
// so adding peer_ip to that struct would be a breaking change. We
// stash a copy in state for now.
struct WebsocketPeerIp(Option<IpAddr>);
235
236impl<H> Handler for WebSocket<H>
237where
238    H: WebSocketHandler,
239{
240    async fn run(&self, mut conn: Conn) -> Conn {
241        match conn.http_version() {
242            Version::Http1_0 | Version::Http1_1 => self.run_h1(conn).await,
243            // Extended-CONNECT bootstrap of WebSockets — RFC 8441 (h2) and RFC 9220 (h3) define
244            // the same shape: `:method = CONNECT`, `:protocol = websocket`, no SHA1/Key/Accept
245            // handshake. The server replies with status 200 and the stream stays open as a
246            // bidirectional byte channel carrying WebSocket frames.
247            Version::Http2 | Version::Http3 => {
248                if extended_connect_websocket_request(&conn) {
249                    let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
250                    let protocol = websocket_protocol(&conn, &self.protocols);
251
252                    if let Some(protocol) = protocol {
253                        conn.response_headers_mut()
254                            .insert(SecWebsocketProtocol, protocol);
255                    }
256
257                    conn.halt()
258                        .with_state(websocket_peer_ip)
259                        .with_state(IsWebsocket)
260                        .with_status(Status::Ok)
261                } else if self.required {
262                    conn.with_status(Status::UpgradeRequired).halt()
263                } else {
264                    conn
265                }
266            }
267            _ => {
268                if self.required {
269                    conn.with_status(Status::UpgradeRequired).halt()
270                } else {
271                    conn
272                }
273            }
274        }
275    }
276
277    async fn init(&mut self, info: &mut Info) {
278        // Required for h2 (RFC 8441 §3) and h3 (RFC 9220 §3) clients to attempt the extended
279        // CONNECT bootstrap of WebSockets. Harmless on h1.
280        info.config_mut().set_extended_connect_enabled(true);
281    }
282
283    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
284        upgrade.state().contains::<IsWebsocket>()
285    }
286
287    async fn upgrade(&self, mut upgrade: Upgrade) {
288        let peer_ip = upgrade
289            .state_mut()
290            .take::<WebsocketPeerIp>()
291            .and_then(|i| i.0);
292        let mut conn = WebSocketConn::new(upgrade, self.config, Role::Server).await;
293        conn.set_peer_ip(peer_ip);
294
295        let Some((mut conn, outbound)) = self.handler.connect(conn).await else {
296            return;
297        };
298
299        let inbound = conn.take_inbound_stream();
300
301        let mut stream = std::pin::pin!(BidirectionalStream { inbound, outbound });
302        while let Some(message) = stream.next().await {
303            match message {
304                Direction::Inbound(Ok(Message::Close(close_frame))) => {
305                    self.handler.disconnect(&mut conn, close_frame).await;
306                    break;
307                }
308
309                Direction::Inbound(Ok(message)) => {
310                    self.handler.inbound(message, &mut conn).await;
311                }
312
313                Direction::Outbound(message) => {
314                    if let Err(e) = self.handler.send(message, &mut conn).await {
315                        log::warn!("outbound websocket error: {:?}", e);
316                        break;
317                    }
318                }
319
320                _ => {
321                    self.handler.disconnect(&mut conn, None).await;
322                    break;
323                }
324            }
325        }
326
327        if let Some(err) = conn.close().await.err() {
328            log::warn!("websocket close error: {:?}", err);
329        };
330    }
331}
332
333fn websocket_protocol(conn: &Conn, protocols: &[String]) -> Option<String> {
334    conn.request_headers()
335        .get_str(SecWebsocketProtocol)
336        .and_then(|value| {
337            value
338                .split(',')
339                .map(str::trim)
340                .find(|req_p| protocols.iter().any(|x| x == req_p))
341                .map(|s| s.to_owned())
342        })
343}
344
345fn connection_is_upgrade(conn: &Conn) -> bool {
346    conn.request_headers()
347        .get_str(Connection)
348        .map(|connection| {
349            connection
350                .split(',')
351                .any(|c| c.trim().eq_ignore_ascii_case("upgrade"))
352        })
353        .unwrap_or(false)
354}
355
356fn upgrade_to_websocket(conn: &Conn) -> bool {
357    conn.request_headers()
358        .eq_ignore_ascii_case(UpgradeHeader, "websocket")
359}
360
361fn upgrade_requested(conn: &Conn) -> bool {
362    connection_is_upgrade(conn) && upgrade_to_websocket(conn)
363}
364
365/// Detect a WebSocket bootstrap over extended CONNECT (RFC 8441 for h2, RFC 9220 for h3).
366///
367/// The peer must use `CONNECT` and carry a `:protocol` pseudo-header equal to "websocket"
368/// (case-insensitive per the RFCs).
369fn extended_connect_websocket_request(conn: &Conn) -> bool {
370    if conn.method() != Method::Connect {
371        return false;
372    }
373    let inner: &trillium_http::Conn<Box<dyn trillium::Transport>> = conn.as_ref();
374    inner
375        .protocol()
376        .is_some_and(|p| p.eq_ignore_ascii_case("websocket"))
377}
378
379/// Generate a random key suitable for Sec-WebSocket-Key
380pub fn websocket_key() -> String {
381    BASE64.encode(fastrand::u128(..).to_ne_bytes())
382}
383
384/// Generate the expected Sec-WebSocket-Accept hash from the Sec-WebSocket-Key
385pub fn websocket_accept_hash(websocket_key: &str) -> String {
386    let hash = Sha1::new()
387        .chain_update(websocket_key)
388        .chain_update(WEBSOCKET_GUID)
389        .finalize();
390    BASE64.encode(&hash[..])
391}