trillium_websockets/
lib.rs1#![forbid(unsafe_code)]
2#![deny(
3 clippy::dbg_macro,
4 missing_copy_implementations,
5 rustdoc::missing_crate_level_docs,
6 missing_debug_implementations,
7 missing_docs,
8 nonstandard_style,
9 unused_qualifications
10)]
11
12#[cfg(test)]
// Attaches README.md as this (empty) module's documentation, presumably so the README's
// code examples get compiled and checked as doctests — NOTE(review): confirm the cfg gating
// on this item actually lets rustdoc collect them.
#[doc = include_str!("../README.md")]
mod readme {}
57
58mod bidirectional_stream;
59mod websocket_connection;
60mod websocket_handler;
61
62pub use async_tungstenite::{
63 self,
64 tungstenite::{
65 self, Message,
66 protocol::{Role, WebSocketConfig},
67 },
68};
69use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
70use bidirectional_stream::{BidirectionalStream, Direction};
71use futures_lite::stream::StreamExt;
72use sha1::{Digest, Sha1};
73use std::{
74 net::IpAddr,
75 ops::{Deref, DerefMut},
76};
77use trillium::{
78 Conn, Handler, Info,
79 KnownHeaderName::{
80 Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketProtocol, SecWebsocketVersion,
81 Upgrade as UpgradeHeader,
82 },
83 Method, Status, Upgrade, Version,
84};
85pub use websocket_connection::WebSocketConn;
86pub use websocket_handler::WebSocketHandler;
87
/// The fixed GUID from RFC 6455. It is appended to the client's `Sec-WebSocket-Key` before
/// SHA-1 hashing to produce `Sec-WebSocket-Accept` (see `websocket_accept_hash`).
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
89
/// Errors produced by this crate.
///
/// Either an error from the underlying tungstenite implementation or, when the `json` feature
/// is enabled, a `serde_json` error. Marked `#[non_exhaustive]`, so match with a wildcard arm.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum Error {
    /// An error reported by tungstenite (protocol or transport failure).
    #[error(transparent)]
    WebSocket(#[from] tungstenite::Error),

    /// A JSON serialization or deserialization error (only with the `json` feature).
    #[cfg(feature = "json")]
    #[error(transparent)]
    Json(#[from] serde_json::Error),
}
104
/// Convenience alias for results whose error type is this crate's [`Error`].
///
/// The success type defaults to [`Message`].
pub type Result<T = Message> = std::result::Result<T, Error>;
107
108#[cfg(feature = "json")]
109mod json;
110
111#[cfg(feature = "json")]
112pub use json::{JsonHandler, JsonWebSocketHandler, json_websocket};
113
/// A trillium [`Handler`] that performs the websocket handshake.
///
/// Once the connection is upgraded, it is handed to the wrapped [`WebSocketHandler`] `H`.
#[derive(Debug)]
pub struct WebSocket<H> {
    // The user-supplied websocket handler; also reachable through `Deref`/`DerefMut`.
    handler: H,
    // Subprotocols this server is willing to negotiate via `Sec-WebSocket-Protocol`.
    protocols: Vec<String>,
    // Optional tungstenite config passed to `WebSocketConn::new` when upgrading.
    config: Option<WebSocketConfig>,
    // When true, requests that are not websocket requests are halted with
    // 426 Upgrade Required instead of being passed along.
    required: bool,
}
123
124impl<H> Deref for WebSocket<H> {
125 type Target = H;
126
127 fn deref(&self) -> &Self::Target {
128 &self.handler
129 }
130}
131
132impl<H> DerefMut for WebSocket<H> {
133 fn deref_mut(&mut self) -> &mut Self::Target {
134 &mut self.handler
135 }
136}
137
138pub fn websocket<H>(websocket_handler: H) -> WebSocket<H>
141where
142 H: WebSocketHandler,
143{
144 WebSocket::new(websocket_handler)
145}
146
147impl<H> WebSocket<H>
148where
149 H: WebSocketHandler,
150{
151 async fn run_h1(&self, mut conn: Conn) -> Conn {
152 if !upgrade_requested(&conn) {
153 if self.required {
154 return conn.with_status(Status::UpgradeRequired).halt();
155 } else {
156 return conn;
157 }
158 }
159
160 let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
161
162 let Some(sec_websocket_key) = conn.request_headers().get_str(SecWebsocketKey) else {
163 return conn.with_status(Status::BadRequest).halt();
164 };
165 let sec_websocket_accept = websocket_accept_hash(sec_websocket_key);
166
167 let protocol = websocket_protocol(&conn, &self.protocols);
168
169 let headers = conn.response_headers_mut();
170
171 headers.extend([
172 (UpgradeHeader, "websocket"),
173 (Connection, "Upgrade"),
174 (SecWebsocketVersion, "13"),
175 ]);
176
177 headers.insert(SecWebsocketAccept, sec_websocket_accept);
178
179 if let Some(protocol) = protocol {
180 headers.insert(SecWebsocketProtocol, protocol);
181 }
182
183 conn.halt()
184 .with_state(websocket_peer_ip)
185 .with_state(IsWebsocket)
186 .with_status(Status::SwitchingProtocols)
187 }
188
189 pub fn new(handler: H) -> Self {
192 Self {
193 handler,
194 protocols: Default::default(),
195 config: None,
196 required: false,
197 }
198 }
199
200 pub fn with_protocols(self, protocols: &[&str]) -> Self {
204 Self {
205 protocols: protocols.iter().map(ToString::to_string).collect(),
206 ..self
207 }
208 }
209
210 pub fn with_protocol_config(self, config: WebSocketConfig) -> Self {
212 Self {
213 config: Some(config),
214 ..self
215 }
216 }
217
218 pub fn required(mut self) -> Self {
221 self.required = true;
222 self
223 }
224}
225
/// Marker stored in conn state when a websocket handshake response has been prepared.
///
/// `has_upgrade` checks for it to decide whether this handler takes over the upgraded
/// connection.
struct IsWebsocket;
227
228#[cfg(test)]
229mod tests;
230
/// Carries the client's peer IP (if known) from the handshake conn's state into `upgrade`.
///
/// There it is taken back out and set on the resulting [`WebSocketConn`].
struct WebsocketPeerIp(Option<IpAddr>);
235
236impl<H> Handler for WebSocket<H>
237where
238 H: WebSocketHandler,
239{
240 async fn run(&self, mut conn: Conn) -> Conn {
241 match conn.http_version() {
242 Version::Http1_0 | Version::Http1_1 => self.run_h1(conn).await,
243 Version::Http2 | Version::Http3 => {
248 if extended_connect_websocket_request(&conn) {
249 let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
250 let protocol = websocket_protocol(&conn, &self.protocols);
251
252 if let Some(protocol) = protocol {
253 conn.response_headers_mut()
254 .insert(SecWebsocketProtocol, protocol);
255 }
256
257 conn.halt()
258 .with_state(websocket_peer_ip)
259 .with_state(IsWebsocket)
260 .with_status(Status::Ok)
261 } else if self.required {
262 conn.with_status(Status::UpgradeRequired).halt()
263 } else {
264 conn
265 }
266 }
267 _ => {
268 if self.required {
269 conn.with_status(Status::UpgradeRequired).halt()
270 } else {
271 conn
272 }
273 }
274 }
275 }
276
277 async fn init(&mut self, info: &mut Info) {
278 info.config_mut().set_extended_connect_enabled(true);
281 }
282
283 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
284 upgrade.state().contains::<IsWebsocket>()
285 }
286
287 async fn upgrade(&self, mut upgrade: Upgrade) {
288 let peer_ip = upgrade
289 .state_mut()
290 .take::<WebsocketPeerIp>()
291 .and_then(|i| i.0);
292 let mut conn = WebSocketConn::new(upgrade, self.config, Role::Server).await;
293 conn.set_peer_ip(peer_ip);
294
295 let Some((mut conn, outbound)) = self.handler.connect(conn).await else {
296 return;
297 };
298
299 let inbound = conn.take_inbound_stream();
300
301 let mut stream = std::pin::pin!(BidirectionalStream { inbound, outbound });
302 while let Some(message) = stream.next().await {
303 match message {
304 Direction::Inbound(Ok(Message::Close(close_frame))) => {
305 self.handler.disconnect(&mut conn, close_frame).await;
306 break;
307 }
308
309 Direction::Inbound(Ok(message)) => {
310 self.handler.inbound(message, &mut conn).await;
311 }
312
313 Direction::Outbound(message) => {
314 if let Err(e) = self.handler.send(message, &mut conn).await {
315 log::warn!("outbound websocket error: {:?}", e);
316 break;
317 }
318 }
319
320 _ => {
321 self.handler.disconnect(&mut conn, None).await;
322 break;
323 }
324 }
325 }
326
327 if let Some(err) = conn.close().await.err() {
328 log::warn!("websocket close error: {:?}", err);
329 };
330 }
331}
332
333fn websocket_protocol(conn: &Conn, protocols: &[String]) -> Option<String> {
334 conn.request_headers()
335 .get_str(SecWebsocketProtocol)
336 .and_then(|value| {
337 value
338 .split(',')
339 .map(str::trim)
340 .find(|req_p| protocols.iter().any(|x| x == req_p))
341 .map(|s| s.to_owned())
342 })
343}
344
345fn connection_is_upgrade(conn: &Conn) -> bool {
346 conn.request_headers()
347 .get_str(Connection)
348 .map(|connection| {
349 connection
350 .split(',')
351 .any(|c| c.trim().eq_ignore_ascii_case("upgrade"))
352 })
353 .unwrap_or(false)
354}
355
356fn upgrade_to_websocket(conn: &Conn) -> bool {
357 conn.request_headers()
358 .eq_ignore_ascii_case(UpgradeHeader, "websocket")
359}
360
361fn upgrade_requested(conn: &Conn) -> bool {
362 connection_is_upgrade(conn) && upgrade_to_websocket(conn)
363}
364
365fn extended_connect_websocket_request(conn: &Conn) -> bool {
370 if conn.method() != Method::Connect {
371 return false;
372 }
373 let inner: &trillium_http::Conn<Box<dyn trillium::Transport>> = conn.as_ref();
374 inner
375 .protocol()
376 .is_some_and(|p| p.eq_ignore_ascii_case("websocket"))
377}
378
379pub fn websocket_key() -> String {
381 BASE64.encode(fastrand::u128(..).to_ne_bytes())
382}
383
384pub fn websocket_accept_hash(websocket_key: &str) -> String {
386 let hash = Sha1::new()
387 .chain_update(websocket_key)
388 .chain_update(WEBSOCKET_GUID)
389 .finalize();
390 BASE64.encode(&hash[..])
391}