Skip to main content

trillium_websockets/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11
12//! # A websocket trillium handler
13//!
//! There are three primary ways to use this crate:
15//!
16//! ## With an async function that receives a [`WebSocketConn`]
17//!
//! This is the simplest way to use trillium websockets, but it does not
//! provide any of the affordances that implementing the
//! [`WebSocketHandler`] trait does. It is best for very simple websockets
//! or for uses that require moving the [`WebSocketConn`] elsewhere in an
//! application. The `WebSocketConn` is fully owned at this point, and will
//! disconnect when dropped, not when the async function passed to
//! `websocket` completes.
25//!
26//! ```
27//! use futures_lite::stream::StreamExt;
28//! use trillium_websockets::{Message, WebSocketConn, websocket};
29//!
30//! let handler = websocket(|mut conn: WebSocketConn| async move {
31//!     while let Some(Ok(Message::Text(input))) = conn.next().await {
32//!         conn.send_string(format!("received your message: {}", &input))
33//!             .await;
34//!     }
35//! });
36//! # // tests at tests/tests.rs for example simplicity
37//! ```
38//!
39//!
40//! ## Implementing [`WebSocketHandler`]
41//!
42//! [`WebSocketHandler`] provides support for sending outbound messages as a
43//! stream, and simplifies common patterns like executing async code on
44//! received messages.
45//!
46//! ## Using [`JsonWebSocketHandler`]
47//!
//! [`JsonWebSocketHandler`] provides a thin serialization and
//! deserialization layer on top of [`WebSocketHandler`] for this common
//! use case. See the [`JsonWebSocketHandler`] documentation for example
//! usage. In order to use this trait, the `json` cargo feature must be
//! enabled.
53
// Runs the README's code blocks as doctests. rustdoc sets `cfg(doctest)` (and not
// `cfg(test)`) while collecting doctests, so gating this module on `test` would keep it,
// and therefore the README examples, out of the doctest run entirely.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
57
58mod bidirectional_stream;
59mod websocket_connection;
60mod websocket_handler;
61
62pub use async_tungstenite::{
63    self,
64    tungstenite::{
65        self, Message,
66        protocol::{Role, WebSocketConfig},
67    },
68};
69use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
70use bidirectional_stream::{BidirectionalStream, Direction};
71use futures_lite::stream::StreamExt;
72use sha1::{Digest, Sha1};
73use std::{
74    net::IpAddr,
75    ops::{Deref, DerefMut},
76};
77use trillium::{
78    Conn, Handler,
79    KnownHeaderName::{
80        Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketProtocol, SecWebsocketVersion,
81        Upgrade as UpgradeHeader,
82    },
83    Status, Upgrade,
84};
85pub use websocket_connection::WebSocketConn;
86pub use websocket_handler::WebSocketHandler;
87
/// The fixed GUID that the websocket opening handshake (RFC 6455, section 1.3) appends to
/// the client's `Sec-WebSocket-Key` before SHA-1 hashing to derive `Sec-WebSocket-Accept`.
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
89
90#[derive(thiserror::Error, Debug)]
91#[non_exhaustive]
92/// An Error type that represents all exceptional conditions that can be encoutered in the operation
93/// of this crate
94pub enum Error {
95    #[error(transparent)]
96    /// an error in the underlying websocket implementation
97    WebSocket(#[from] tungstenite::Error),
98
99    #[cfg(feature = "json")]
100    #[error(transparent)]
101    /// an error in json serialization or deserialization
102    Json(#[from] serde_json::Error),
103}
104
105/// a Result type for this crate
106pub type Result<T = Message> = std::result::Result<T, Error>;
107
108#[cfg(feature = "json")]
109mod json;
110
111#[cfg(feature = "json")]
112pub use json::{JsonHandler, JsonWebSocketHandler, json_websocket};
113
114/// The trillium handler.
115/// See crate-level docs for example usage.
116#[derive(Debug)]
117pub struct WebSocket<H> {
118    handler: H,
119    protocols: Vec<String>,
120    config: Option<WebSocketConfig>,
121    required: bool,
122}
123
124impl<H> Deref for WebSocket<H> {
125    type Target = H;
126
127    fn deref(&self) -> &Self::Target {
128        &self.handler
129    }
130}
131
132impl<H> DerefMut for WebSocket<H> {
133    fn deref_mut(&mut self) -> &mut Self::Target {
134        &mut self.handler
135    }
136}
137
138/// Builds a new trillium handler from the provided
139/// WebSocketHandler. Alias for [`WebSocket::new`]
140pub fn websocket<H>(websocket_handler: H) -> WebSocket<H>
141where
142    H: WebSocketHandler,
143{
144    WebSocket::new(websocket_handler)
145}
146
147impl<H> WebSocket<H>
148where
149    H: WebSocketHandler,
150{
151    /// Build a new WebSocket with an async handler function that
152    /// receives a [`WebSocketConn`]
153    pub fn new(handler: H) -> Self {
154        Self {
155            handler,
156            protocols: Default::default(),
157            config: None,
158            required: false,
159        }
160    }
161
162    /// `protocols` is a sequence of known protocols. On successful handshake,
163    /// the returned response headers contain the first protocol in this list
164    /// which the server also knows.
165    pub fn with_protocols(self, protocols: &[&str]) -> Self {
166        Self {
167            protocols: protocols.iter().map(ToString::to_string).collect(),
168            ..self
169        }
170    }
171
172    /// configure the websocket protocol
173    pub fn with_protocol_config(self, config: WebSocketConfig) -> Self {
174        Self {
175            config: Some(config),
176            ..self
177        }
178    }
179
180    /// configure this handler to halt and send back a [`426 Upgrade
181    /// Required`][Status::UpgradeRequired] if a websocket cannot be negotiated
182    pub fn required(mut self) -> Self {
183        self.required = true;
184        self
185    }
186}
187
/// State marker that `run` attaches to a conn once the websocket handshake headers have been
/// written. `has_upgrade` looks for it to claim the subsequent upgrade.
struct IsWebsocket;
189
190#[cfg(test)]
191mod tests;
192
/// The peer ip of the conn being upgraded, stashed in conn state by `run` and taken back
/// out in `upgrade`.
///
/// This is a workaround: `Upgrade` is a public struct, so adding a `peer_ip` field to it
/// would be a breaking change.
struct WebsocketPeerIp(Option<IpAddr>);
197
198impl<H> Handler for WebSocket<H>
199where
200    H: WebSocketHandler,
201{
202    async fn run(&self, mut conn: Conn) -> Conn {
203        if !upgrade_requested(&conn) {
204            if self.required {
205                return conn.with_status(Status::UpgradeRequired).halt();
206            } else {
207                return conn;
208            }
209        }
210
211        let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
212
213        let Some(sec_websocket_key) = conn.request_headers().get_str(SecWebsocketKey) else {
214            return conn.with_status(Status::BadRequest).halt();
215        };
216        let sec_websocket_accept = websocket_accept_hash(sec_websocket_key);
217
218        let protocol = websocket_protocol(&conn, &self.protocols);
219
220        let headers = conn.response_headers_mut();
221
222        headers.extend([
223            (UpgradeHeader, "websocket"),
224            (Connection, "Upgrade"),
225            (SecWebsocketVersion, "13"),
226        ]);
227
228        headers.insert(SecWebsocketAccept, sec_websocket_accept);
229
230        if let Some(protocol) = protocol {
231            headers.insert(SecWebsocketProtocol, protocol);
232        }
233
234        conn.halt()
235            .with_state(websocket_peer_ip)
236            .with_state(IsWebsocket)
237            .with_status(Status::SwitchingProtocols)
238    }
239
240    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
241        upgrade.state().contains::<IsWebsocket>()
242    }
243
244    async fn upgrade(&self, mut upgrade: Upgrade) {
245        let peer_ip = upgrade
246            .state_mut()
247            .take::<WebsocketPeerIp>()
248            .and_then(|i| i.0);
249        let mut conn = WebSocketConn::new(upgrade, self.config, Role::Server).await;
250        conn.set_peer_ip(peer_ip);
251
252        let Some((mut conn, outbound)) = self.handler.connect(conn).await else {
253            return;
254        };
255
256        let inbound = conn.take_inbound_stream();
257
258        let mut stream = std::pin::pin!(BidirectionalStream { inbound, outbound });
259        while let Some(message) = stream.next().await {
260            match message {
261                Direction::Inbound(Ok(Message::Close(close_frame))) => {
262                    self.handler.disconnect(&mut conn, close_frame).await;
263                    break;
264                }
265
266                Direction::Inbound(Ok(message)) => {
267                    self.handler.inbound(message, &mut conn).await;
268                }
269
270                Direction::Outbound(message) => {
271                    if let Err(e) = self.handler.send(message, &mut conn).await {
272                        log::warn!("outbound websocket error: {:?}", e);
273                        break;
274                    }
275                }
276
277                _ => {
278                    self.handler.disconnect(&mut conn, None).await;
279                    break;
280                }
281            }
282        }
283
284        if let Some(err) = conn.close().await.err() {
285            log::warn!("websocket close error: {:?}", err);
286        };
287    }
288}
289
290fn websocket_protocol(conn: &Conn, protocols: &[String]) -> Option<String> {
291    conn.request_headers()
292        .get_str(SecWebsocketProtocol)
293        .and_then(|value| {
294            value
295                .split(',')
296                .map(str::trim)
297                .find(|req_p| protocols.iter().any(|x| x == req_p))
298                .map(|s| s.to_owned())
299        })
300}
301
302fn connection_is_upgrade(conn: &Conn) -> bool {
303    conn.request_headers()
304        .get_str(Connection)
305        .map(|connection| {
306            connection
307                .split(',')
308                .any(|c| c.trim().eq_ignore_ascii_case("upgrade"))
309        })
310        .unwrap_or(false)
311}
312
313fn upgrade_to_websocket(conn: &Conn) -> bool {
314    conn.request_headers()
315        .eq_ignore_ascii_case(UpgradeHeader, "websocket")
316}
317
318fn upgrade_requested(conn: &Conn) -> bool {
319    connection_is_upgrade(conn) && upgrade_to_websocket(conn)
320}
321
322/// Generate a random key suitable for Sec-WebSocket-Key
323pub fn websocket_key() -> String {
324    BASE64.encode(fastrand::u128(..).to_ne_bytes())
325}
326
327/// Generate the expected Sec-WebSocket-Accept hash from the Sec-WebSocket-Key
328pub fn websocket_accept_hash(websocket_key: &str) -> String {
329    let hash = Sha1::new()
330        .chain_update(websocket_key)
331        .chain_update(WEBSOCKET_GUID)
332        .finalize();
333    BASE64.encode(&hash[..])
334}