trillium_websockets/
lib.rs1#![forbid(unsafe_code)]
2#![deny(
3 clippy::dbg_macro,
4 missing_copy_implementations,
5 rustdoc::missing_crate_level_docs,
6 missing_debug_implementations,
7 missing_docs,
8 nonstandard_style,
9 unused_qualifications
10)]
11
// Pulls README.md into the crate so its code samples are collected as doctests.
// rustdoc sets `cfg(doctest)` (not `cfg(test)`) while collecting doctests; under
// `cfg(test)` this module was stripped before collection and the README never ran.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
57
58mod bidirectional_stream;
59mod websocket_connection;
60mod websocket_handler;
61
62pub use async_tungstenite::{
63 self,
64 tungstenite::{
65 self, Message,
66 protocol::{Role, WebSocketConfig},
67 },
68};
69use base64::{Engine, engine::general_purpose::STANDARD as BASE64};
70use bidirectional_stream::{BidirectionalStream, Direction};
71use futures_lite::stream::StreamExt;
72use sha1::{Digest, Sha1};
73use std::{
74 net::IpAddr,
75 ops::{Deref, DerefMut},
76};
77use trillium::{
78 Conn, Handler, Info,
79 KnownHeaderName::{
80 Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketProtocol, SecWebsocketVersion,
81 Upgrade as UpgradeHeader,
82 },
83 Method, Status, Upgrade, Version,
84};
85pub use websocket_connection::WebSocketConn;
86pub use websocket_handler::WebSocketHandler;
87
/// Fixed GUID defined by RFC 6455, concatenated with the client's
/// `Sec-WebSocket-Key` before hashing (see `websocket_accept_hash`).
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
89
90#[derive(thiserror::Error, Debug)]
91#[non_exhaustive]
92pub enum Error {
95 #[error(transparent)]
96 WebSocket(#[from] tungstenite::Error),
98
99 #[cfg(feature = "json")]
100 #[error(transparent)]
101 Json(#[from] serde_json::Error),
103}
104
105pub type Result<T = Message> = std::result::Result<T, Error>;
107
108#[cfg(feature = "json")]
109mod json;
110
111#[cfg(feature = "json")]
112pub use json::{JsonHandler, JsonWebSocketHandler, json_websocket};
113
114#[derive(Debug)]
117pub struct WebSocket<H> {
118 handler: H,
119 protocols: Vec<String>,
120 config: Option<WebSocketConfig>,
121 required: bool,
122}
123
124impl<H> Deref for WebSocket<H> {
125 type Target = H;
126
127 fn deref(&self) -> &Self::Target {
128 &self.handler
129 }
130}
131
132impl<H> DerefMut for WebSocket<H> {
133 fn deref_mut(&mut self) -> &mut Self::Target {
134 &mut self.handler
135 }
136}
137
138pub fn websocket<H>(websocket_handler: H) -> WebSocket<H>
141where
142 H: WebSocketHandler,
143{
144 WebSocket::new(websocket_handler)
145}
146
147impl<H> WebSocket<H>
148where
149 H: WebSocketHandler,
150{
151 async fn run_h1(&self, mut conn: Conn) -> Conn {
152 if !upgrade_requested(&conn) {
153 if self.required {
154 return conn.with_status(Status::UpgradeRequired).halt();
155 } else {
156 return conn;
157 }
158 }
159
160 if !supported_websocket_version(&conn) {
161 return reject_unsupported_version(conn);
162 }
163
164 let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
165
166 let Some(sec_websocket_key) = conn.request_headers().get_str(SecWebsocketKey) else {
167 return conn.with_status(Status::BadRequest).halt();
168 };
169 let sec_websocket_accept = websocket_accept_hash(sec_websocket_key);
170
171 let protocol = websocket_protocol(&conn, &self.protocols);
172
173 let headers = conn.response_headers_mut();
174
175 headers.extend([
176 (UpgradeHeader, "websocket"),
177 (Connection, "Upgrade"),
178 (SecWebsocketVersion, "13"),
179 ]);
180
181 headers.insert(SecWebsocketAccept, sec_websocket_accept);
182
183 if let Some(protocol) = protocol {
184 headers.insert(SecWebsocketProtocol, protocol);
185 }
186
187 conn.halt()
188 .with_state(websocket_peer_ip)
189 .with_state(IsWebsocket)
190 .with_status(Status::SwitchingProtocols)
191 }
192
193 pub fn new(handler: H) -> Self {
196 Self {
197 handler,
198 protocols: Default::default(),
199 config: None,
200 required: false,
201 }
202 }
203
204 pub fn with_protocols(self, protocols: &[&str]) -> Self {
208 Self {
209 protocols: protocols.iter().map(ToString::to_string).collect(),
210 ..self
211 }
212 }
213
214 pub fn with_protocol_config(self, config: WebSocketConfig) -> Self {
216 Self {
217 config: Some(config),
218 ..self
219 }
220 }
221
222 pub fn required(mut self) -> Self {
225 self.required = true;
226 self
227 }
228}
229
/// Marker placed in conn state when a handshake is accepted.
///
/// `has_upgrade` checks for it to decide whether this handler owns the upgrade.
struct IsWebsocket;
231
232#[cfg(test)]
233mod tests;
234
235struct WebsocketPeerIp(Option<IpAddr>);
239
240impl<H> Handler for WebSocket<H>
241where
242 H: WebSocketHandler,
243{
244 async fn run(&self, mut conn: Conn) -> Conn {
245 match conn.http_version() {
246 Version::Http1_0 | Version::Http1_1 => self.run_h1(conn).await,
247 Version::Http2 | Version::Http3 => {
252 if extended_connect_websocket_request(&conn) {
253 if !supported_websocket_version(&conn) {
254 return reject_unsupported_version(conn);
255 }
256
257 let websocket_peer_ip = WebsocketPeerIp(conn.peer_ip());
258 let protocol = websocket_protocol(&conn, &self.protocols);
259
260 if let Some(protocol) = protocol {
261 conn.response_headers_mut()
262 .insert(SecWebsocketProtocol, protocol);
263 }
264
265 conn.halt()
266 .with_state(websocket_peer_ip)
267 .with_state(IsWebsocket)
268 .with_status(Status::Ok)
269 } else if self.required {
270 conn.with_status(Status::UpgradeRequired).halt()
271 } else {
272 conn
273 }
274 }
275 _ => {
276 if self.required {
277 conn.with_status(Status::UpgradeRequired).halt()
278 } else {
279 conn
280 }
281 }
282 }
283 }
284
285 async fn init(&mut self, info: &mut Info) {
286 info.config_mut().set_extended_connect_enabled(true);
289 }
290
291 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
292 upgrade.state().contains::<IsWebsocket>()
293 }
294
295 async fn upgrade(&self, mut upgrade: Upgrade) {
296 let peer_ip = upgrade
297 .state_mut()
298 .take::<WebsocketPeerIp>()
299 .and_then(|i| i.0);
300 let mut conn = WebSocketConn::new(upgrade, self.config, Role::Server).await;
301 conn.set_peer_ip(peer_ip);
302
303 let Some((mut conn, outbound)) = self.handler.connect(conn).await else {
304 return;
305 };
306
307 let inbound = conn.take_inbound_stream();
308
309 let mut stream = std::pin::pin!(BidirectionalStream { inbound, outbound });
310 while let Some(message) = stream.next().await {
311 match message {
312 Direction::Inbound(Ok(Message::Close(close_frame))) => {
313 self.handler.disconnect(&mut conn, close_frame).await;
314 break;
315 }
316
317 Direction::Inbound(Ok(message)) => {
318 self.handler.inbound(message, &mut conn).await;
319 }
320
321 Direction::Outbound(message) => {
322 if let Err(e) = self.handler.send(message, &mut conn).await {
323 log::warn!("outbound websocket error: {:?}", e);
324 break;
325 }
326 }
327
328 _ => {
329 self.handler.disconnect(&mut conn, None).await;
330 break;
331 }
332 }
333 }
334
335 if let Some(err) = conn.close().await.err() {
336 log::warn!("websocket close error: {:?}", err);
337 };
338 }
339}
340
341fn websocket_protocol(conn: &Conn, protocols: &[String]) -> Option<String> {
342 conn.request_headers()
343 .get_str(SecWebsocketProtocol)
344 .and_then(|value| {
345 value
346 .split(',')
347 .map(str::trim)
348 .find(|req_p| protocols.iter().any(|x| x == req_p))
349 .map(|s| s.to_owned())
350 })
351}
352
353fn connection_is_upgrade(conn: &Conn) -> bool {
354 conn.request_headers()
355 .get_str(Connection)
356 .map(|connection| {
357 connection
358 .split(',')
359 .any(|c| c.trim().eq_ignore_ascii_case("upgrade"))
360 })
361 .unwrap_or(false)
362}
363
364fn upgrade_to_websocket(conn: &Conn) -> bool {
365 conn.request_headers()
366 .eq_ignore_ascii_case(UpgradeHeader, "websocket")
367}
368
369fn supported_websocket_version(conn: &Conn) -> bool {
370 conn.request_headers().get_str(SecWebsocketVersion) == Some("13")
371}
372
373fn reject_unsupported_version(conn: Conn) -> Conn {
374 conn.with_status(Status::UpgradeRequired)
375 .with_response_header(SecWebsocketVersion, "13")
376 .halt()
377}
378
379fn upgrade_requested(conn: &Conn) -> bool {
380 conn.method() == Method::Get
381 && conn.http_version() == Version::Http1_1
382 && connection_is_upgrade(conn)
383 && upgrade_to_websocket(conn)
384}
385
386fn extended_connect_websocket_request(conn: &Conn) -> bool {
391 if conn.method() != Method::Connect {
392 return false;
393 }
394 let inner: &trillium_http::Conn<Box<dyn trillium::Transport>> = conn.as_ref();
395 inner
396 .protocol()
397 .is_some_and(|p| p.eq_ignore_ascii_case("websocket"))
398}
399
400pub fn websocket_key() -> String {
402 BASE64.encode(fastrand::u128(..).to_ne_bytes())
403}
404
405pub fn websocket_accept_hash(websocket_key: &str) -> String {
407 let hash = Sha1::new()
408 .chain_update(websocket_key)
409 .chain_update(WEBSOCKET_GUID)
410 .finalize();
411 BASE64.encode(&hash[..])
412}