Skip to main content

trillium_websockets/
websocket_connection.rs

1use crate::{Result, Role, WebSocketConfig};
2use async_tungstenite::{
3    WebSocketReceiver, WebSocketSender, WebSocketStream,
4    tungstenite::{self, Message},
5};
6use futures_lite::{Stream, StreamExt};
7use std::{
8    borrow::Cow,
9    fmt::Debug,
10    net::IpAddr,
11    pin::Pin,
12    sync::Arc,
13    task::{self, Poll},
14};
15use swansong::{Interrupt, Swansong};
16use trillium::{Headers, Method, Transport, TypeSet, Upgrade};
17use trillium_http::{HttpContext, type_set::entry::Entry};
18
19/// A struct that represents an specific websocket connection.
20///
21/// This can be thought of as a combination of a [`async_tungstenite::WebSocketStream`] and a
22/// [`trillium::Conn`], as it contains a combination of their fields and
23/// associated functions.
24///
25/// The WebSocketConn implements `Stream<Item=Result<Message, Error>>`,
26/// and can be polled with `StreamExt::next`
27pub struct WebSocketConn {
28    request_headers: Headers,
29    path: Cow<'static, str>,
30    method: Method,
31    state: TypeSet,
32    peer_ip: Option<IpAddr>,
33    context: Arc<HttpContext>,
34    sink: WebSocketSender<Box<dyn Transport>>,
35    stream: Option<WStream>,
36}
37
38impl Debug for WebSocketConn {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("WebSocketConn")
41            .field("request_headers", &self.request_headers)
42            .field("path", &self.path)
43            .field("method", &self.method)
44            .field("state", &self.state)
45            .field("peer_ip", &self.peer_ip)
46            .field("context", &self.context)
47            .field("stream", &self.stream)
48            .finish_non_exhaustive()
49    }
50}
51
52impl WebSocketConn {
53    /// send a [`Message::Text`] variant
54    pub async fn send_string(&mut self, string: String) -> Result<()> {
55        self.send(Message::text(string)).await
56    }
57
58    /// send a [`Message::Binary`] variant
59    pub async fn send_bytes(&mut self, bin: Vec<u8>) -> Result<()> {
60        self.send(Message::binary(bin)).await
61    }
62
63    #[cfg(feature = "json")]
64    /// send a [`Message::Text`] that contains json
65    /// note that json messages are not actually part of the websocket specification
66    pub async fn send_json(&mut self, json: &impl serde::Serialize) -> Result<()> {
67        self.send_string(serde_json::to_string(json)?).await
68    }
69
70    /// Sends a [`Message`] to the client
71    pub async fn send(&mut self, message: Message) -> Result<()> {
72        self.sink.send(message).await.map_err(Into::into)
73    }
74
75    /// Create a `WebSocketConn` from an HTTP upgrade, with optional config and the specified role
76    ///
77    /// You should not typically need to call this; the trillium client and server both provide
78    /// your code with a `WebSocketConn`.
79    #[doc(hidden)]
80    pub async fn new(
81        upgrade: impl Into<Upgrade>,
82        config: Option<WebSocketConfig>,
83        role: Role,
84    ) -> Self {
85        let mut upgrade = upgrade.into();
86        let request_headers = upgrade.take_request_headers();
87        let path = upgrade.path().to_string().into();
88        let method = upgrade.method();
89        let state = upgrade.take_state();
90        let context = upgrade.context().clone();
91        let peer_ip = upgrade.peer_ip();
92        let (buffer, transport) = upgrade.into_transport();
93
94        let wss = if buffer.is_empty() {
95            WebSocketStream::from_raw_socket(transport, role, config).await
96        } else {
97            WebSocketStream::from_partially_read(transport, buffer, role, config).await
98        };
99
100        let (sink, stream) = wss.split();
101        let stream = Some(WStream {
102            stream: context.swansong().interrupt(stream),
103        });
104
105        Self {
106            request_headers,
107            path,
108            method,
109            state,
110            peer_ip,
111            sink,
112            stream,
113            context,
114        }
115    }
116
117    /// retrieve a clone of the server's [`Swansong`]
118    pub fn swansong(&self) -> Swansong {
119        self.context.swansong().clone()
120    }
121
122    /// close the websocket connection gracefully
123    pub async fn close(&mut self) -> Result<()> {
124        self.send(Message::Close(None)).await
125    }
126
127    /// retrieve the request headers for this conn
128    pub fn headers(&self) -> &Headers {
129        &self.request_headers
130    }
131
132    /// retrieves the peer ip for this conn, if available
133    pub fn peer_ip(&self) -> Option<IpAddr> {
134        self.peer_ip
135    }
136
137    /// Sets the peer ip for this conn
138    pub fn set_peer_ip(&mut self, peer_ip: Option<IpAddr>) -> &mut Self {
139        self.peer_ip = peer_ip;
140        self
141    }
142
143    /// retrieves the path part of the request url, up to and excluding
144    /// any query component
145    pub fn path(&self) -> &str {
146        self.path.split('?').next().unwrap_or_default()
147    }
148
149    /// Retrieves the query component of the path, excluding `?`. Returns
150    /// an empty string if there is no query component.
151    pub fn querystring(&self) -> &str {
152        self.path
153            .split_once('?')
154            .map(|(_, query)| query)
155            .unwrap_or_default()
156    }
157
158    /// retrieve the request method for this conn
159    pub fn method(&self) -> Method {
160        self.method
161    }
162
163    /// retrieve state from the state set that has been accumulated by
164    /// trillium handlers run on the [`trillium::Conn`] before it
165    /// became a websocket. see [`trillium::Conn::state`] for more
166    /// information
167    pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
168        self.state.get()
169    }
170
171    /// retrieve a mutable borrow of the state from the state set
172    pub fn state_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
173        self.state.get_mut()
174    }
175
176    /// inserts new state
177    ///
178    /// returns the previously set state of the same type, if any existed
179    pub fn insert_state<T: Send + Sync + 'static>(&mut self, state: T) -> Option<T> {
180        self.state.insert(state)
181    }
182
183    /// Returns an [`Entry`] for the state typeset that can be used with functions like
184    /// [`Entry::or_insert`], [`Entry::or_insert_with`], [`Entry::and_modify`], and others.
185    pub fn state_entry<T: Send + Sync + 'static>(&mut self) -> Entry<'_, T> {
186        self.state.entry()
187    }
188
189    /// take some type T out of the state set that has been
190    /// accumulated by trillium handlers run on the [`trillium::Conn`]
191    /// before it became a websocket. see [`trillium::Conn::take_state`]
192    /// for more information
193    pub fn take_state<T: Send + Sync + 'static>(&mut self) -> Option<T> {
194        self.state.take()
195    }
196
197    /// take the inbound Message stream from this conn
198    pub fn take_inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + use<>> {
199        self.stream.take()
200    }
201
202    /// borrow the inbound Message stream from this conn
203    pub fn inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + '_> {
204        self.stream.as_mut()
205    }
206}
207
208type MessageResult = std::result::Result<Message, tungstenite::Error>;
209
/// The inbound half of a [`WebSocketConn`]'s websocket, wrapped in a swansong [`Interrupt`].
///
/// This is the concrete stream behind the opaque `impl Stream` returned by
/// `WebSocketConn::take_inbound_stream` and `WebSocketConn::inbound_stream`; it yields
/// `Result<Message, tungstenite::Error>` items.
pub struct WStream {
    // Receiver half produced by `WebSocketStream::split`, wrapped via `Swansong::interrupt`.
    stream: Interrupt<WebSocketReceiver<Box<dyn Transport>>>,
}
213
214impl Debug for WStream {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        f.debug_struct("WStream").finish_non_exhaustive()
217    }
218}
219
220impl Stream for WStream {
221    type Item = MessageResult;
222
223    fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
224        self.stream.poll_next(cx)
225    }
226}
227
228impl AsMut<TypeSet> for WebSocketConn {
229    fn as_mut(&mut self) -> &mut TypeSet {
230        &mut self.state
231    }
232}
233
234impl AsRef<TypeSet> for WebSocketConn {
235    fn as_ref(&self) -> &TypeSet {
236        &self.state
237    }
238}
239
240impl Stream for WebSocketConn {
241    type Item = MessageResult;
242
243    fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
244        match self.stream.as_mut() {
245            Some(stream) => stream.poll_next(cx),
246            None => Poll::Ready(None),
247        }
248    }
249}