trillium_websockets/
lib.rs1#![forbid(unsafe_code)]
2#![deny(
3 clippy::dbg_macro,
4 missing_copy_implementations,
5 rustdoc::missing_crate_level_docs,
6 missing_debug_implementations,
7 missing_docs,
8 nonstandard_style,
9 unused_qualifications
10)]
11
// Pulls the README in as the docs of an empty module so that its code blocks
// are compiled and run as doctests. rustdoc sets `cfg(doctest)` (not
// `cfg(test)`) while collecting doctests, so this must be gated on `doctest`;
// under `cfg(test)` the module was stripped before rustdoc ever saw it and the
// README examples were silently never checked.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
57
58mod bidirectional_stream;
59mod websocket_connection;
60mod websocket_handler;
61
62pub use async_tungstenite::{
63 self,
64 tungstenite::{
65 self, Message,
66 protocol::{Role, WebSocketConfig},
67 },
68};
69use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
70use bidirectional_stream::{BidirectionalStream, Direction};
71use futures_lite::stream::StreamExt;
72use sha1::{Digest, Sha1};
73use std::{
74 net::IpAddr,
75 ops::{Deref, DerefMut},
76};
77use trillium::{
78 Conn, Handler,
79 KnownHeaderName::{
80 Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketProtocol, SecWebsocketVersion,
81 Upgrade as UpgradeHeader,
82 },
83 Status, Upgrade,
84};
85pub use websocket_connection::WebSocketConn;
86pub use websocket_handler::WebSocketHandler;
87
/// Fixed GUID that RFC 6455 §4.2.2 has the server append to the client's
/// `Sec-WebSocket-Key` before hashing; used by [`websocket_accept_hash`].
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
89
90#[derive(thiserror::Error, Debug)]
91#[non_exhaustive]
92pub enum Error {
95 #[error(transparent)]
96 WebSocket(#[from] tungstenite::Error),
98
99 #[cfg(feature = "json")]
100 #[error(transparent)]
101 Json(#[from] serde_json::Error),
103}
104
/// Convenience result type for this crate: the success type defaults to
/// [`Message`] and the error type is always this crate's [`Error`].
pub type Result<T = Message> = std::result::Result<T, Error>;
107
108#[cfg(feature = "json")]
109mod json;
110
111#[cfg(feature = "json")]
112pub use json::{JsonHandler, JsonWebSocketHandler, json_websocket};
113
#[derive(Debug)]
/// A trillium [`Handler`] that answers websocket upgrade requests and then
/// hands each upgraded connection to the wrapped [`WebSocketHandler`].
///
/// Construct it with [`websocket`] or [`WebSocket::new`]. It derefs to the
/// wrapped handler.
pub struct WebSocket<H> {
    // The user-supplied handler that is driven once the connection is upgraded.
    handler: H,
    // Subprotocols this server supports. The first entry of the client's
    // `Sec-WebSocket-Protocol` list that also appears here is echoed back.
    protocols: Vec<String>,
    // Optional tungstenite config, passed to `WebSocketConn::new` on upgrade.
    config: Option<WebSocketConfig>,
    // When true, requests that don't ask for a websocket upgrade are halted
    // with 426 instead of being passed on to later handlers.
    required: bool,
}
123
124impl<H> Deref for WebSocket<H> {
125 type Target = H;
126
127 fn deref(&self) -> &Self::Target {
128 &self.handler
129 }
130}
131
132impl<H> DerefMut for WebSocket<H> {
133 fn deref_mut(&mut self) -> &mut Self::Target {
134 &mut self.handler
135 }
136}
137
138pub fn websocket<H>(websocket_handler: H) -> WebSocket<H>
141where
142 H: WebSocketHandler,
143{
144 WebSocket::new(websocket_handler)
145}
146
147impl<H> WebSocket<H>
148where
149 H: WebSocketHandler,
150{
151 pub fn new(handler: H) -> Self {
154 Self {
155 handler,
156 protocols: Default::default(),
157 config: None,
158 required: false,
159 }
160 }
161
162 pub fn with_protocols(self, protocols: &[&str]) -> Self {
166 Self {
167 protocols: protocols.iter().map(ToString::to_string).collect(),
168 ..self
169 }
170 }
171
172 pub fn with_protocol_config(self, config: WebSocketConfig) -> Self {
174 Self {
175 config: Some(config),
176 ..self
177 }
178 }
179
180 pub fn required(mut self) -> Self {
183 self.required = true;
184 self
185 }
186}
187
// Marker stored in conn state by `Handler::run` once the 101 response is
// prepared. `Handler::has_upgrade` checks for it to claim the upgrade.
struct IsWebsocket;
189
190#[cfg(test)]
191mod tests;
192
// Carries the value of `Conn::peer_ip` captured in `Handler::run` through conn
// state to `Handler::upgrade`, which passes it to `WebSocketConn::set_peer_ip`.
struct WebsocketPeerIp(Option<IpAddr>);
197
198impl<H> Handler for WebSocket<H>
199where
200 H: WebSocketHandler,
201{
202 async fn run(&self, mut conn: Conn) -> Conn {
203 if !upgrade_requested(&conn) {
204 if self.required {
205 return conn.with_status(Status::UpgradeRequired).halt();
206 } else {
207 return conn;
208 }
209 }
210
211 let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
212
213 let Some(sec_websocket_key) = conn.request_headers().get_str(SecWebsocketKey) else {
214 return conn.with_status(Status::BadRequest).halt();
215 };
216 let sec_websocket_accept = websocket_accept_hash(sec_websocket_key);
217
218 let protocol = websocket_protocol(&conn, &self.protocols);
219
220 let headers = conn.response_headers_mut();
221
222 headers.extend([
223 (UpgradeHeader, "websocket"),
224 (Connection, "Upgrade"),
225 (SecWebsocketVersion, "13"),
226 ]);
227
228 headers.insert(SecWebsocketAccept, sec_websocket_accept);
229
230 if let Some(protocol) = protocol {
231 headers.insert(SecWebsocketProtocol, protocol);
232 }
233
234 conn.halt()
235 .with_state(websocket_peer_ip)
236 .with_state(IsWebsocket)
237 .with_status(Status::SwitchingProtocols)
238 }
239
240 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
241 upgrade.state().contains::<IsWebsocket>()
242 }
243
244 async fn upgrade(&self, mut upgrade: Upgrade) {
245 let peer_ip = upgrade
246 .state_mut()
247 .take::<WebsocketPeerIp>()
248 .and_then(|i| i.0);
249 let mut conn = WebSocketConn::new(upgrade, self.config, Role::Server).await;
250 conn.set_peer_ip(peer_ip);
251
252 let Some((mut conn, outbound)) = self.handler.connect(conn).await else {
253 return;
254 };
255
256 let inbound = conn.take_inbound_stream();
257
258 let mut stream = std::pin::pin!(BidirectionalStream { inbound, outbound });
259 while let Some(message) = stream.next().await {
260 match message {
261 Direction::Inbound(Ok(Message::Close(close_frame))) => {
262 self.handler.disconnect(&mut conn, close_frame).await;
263 break;
264 }
265
266 Direction::Inbound(Ok(message)) => {
267 self.handler.inbound(message, &mut conn).await;
268 }
269
270 Direction::Outbound(message) => {
271 if let Err(e) = self.handler.send(message, &mut conn).await {
272 log::warn!("outbound websocket error: {:?}", e);
273 break;
274 }
275 }
276
277 _ => {
278 self.handler.disconnect(&mut conn, None).await;
279 break;
280 }
281 }
282 }
283
284 if let Some(err) = conn.close().await.err() {
285 log::warn!("websocket close error: {:?}", err);
286 };
287 }
288}
289
290fn websocket_protocol(conn: &Conn, protocols: &[String]) -> Option<String> {
291 conn.request_headers()
292 .get_str(SecWebsocketProtocol)
293 .and_then(|value| {
294 value
295 .split(',')
296 .map(str::trim)
297 .find(|req_p| protocols.iter().any(|x| x == req_p))
298 .map(|s| s.to_owned())
299 })
300}
301
302fn connection_is_upgrade(conn: &Conn) -> bool {
303 conn.request_headers()
304 .get_str(Connection)
305 .map(|connection| {
306 connection
307 .split(',')
308 .any(|c| c.trim().eq_ignore_ascii_case("upgrade"))
309 })
310 .unwrap_or(false)
311}
312
313fn upgrade_to_websocket(conn: &Conn) -> bool {
314 conn.request_headers()
315 .eq_ignore_ascii_case(UpgradeHeader, "websocket")
316}
317
318fn upgrade_requested(conn: &Conn) -> bool {
319 connection_is_upgrade(conn) && upgrade_to_websocket(conn)
320}
321
322pub fn websocket_key() -> String {
324 BASE64.encode(fastrand::u128(..).to_ne_bytes())
325}
326
327pub fn websocket_accept_hash(websocket_key: &str) -> String {
329 let hash = Sha1::new()
330 .chain_update(websocket_key)
331 .chain_update(WEBSOCKET_GUID)
332 .finalize();
333 BASE64.encode(&hash[..])
334}