trillium_websockets/
websocket_connection.rs1use crate::{Result, Role, WebSocketConfig};
2use async_tungstenite::{
3 WebSocketReceiver, WebSocketSender, WebSocketStream,
4 tungstenite::{self, Message},
5};
6use futures_lite::{Stream, StreamExt};
7use std::{
8 borrow::Cow,
9 fmt::Debug,
10 net::IpAddr,
11 pin::Pin,
12 sync::Arc,
13 task::{self, Poll},
14};
15use swansong::{Interrupt, Swansong};
16use trillium::{Headers, Method, Transport, TypeSet, Upgrade};
17use trillium_http::{HttpContext, type_set::entry::Entry};
18
19pub struct WebSocketConn {
28 request_headers: Headers,
29 path: Cow<'static, str>,
30 method: Method,
31 state: TypeSet,
32 peer_ip: Option<IpAddr>,
33 context: Arc<HttpContext>,
34 sink: WebSocketSender<Box<dyn Transport>>,
35 stream: Option<WStream>,
36}
37
38impl Debug for WebSocketConn {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("WebSocketConn")
41 .field("request_headers", &self.request_headers)
42 .field("path", &self.path)
43 .field("method", &self.method)
44 .field("state", &self.state)
45 .field("peer_ip", &self.peer_ip)
46 .field("context", &self.context)
47 .field("stream", &self.stream)
48 .finish_non_exhaustive()
49 }
50}
51
52impl WebSocketConn {
53 pub async fn send_string(&mut self, string: String) -> Result<()> {
55 self.send(Message::text(string)).await
56 }
57
58 pub async fn send_bytes(&mut self, bin: Vec<u8>) -> Result<()> {
60 self.send(Message::binary(bin)).await
61 }
62
63 #[cfg(feature = "json")]
64 pub async fn send_json(&mut self, json: &impl serde::Serialize) -> Result<()> {
67 self.send_string(serde_json::to_string(json)?).await
68 }
69
70 pub async fn send(&mut self, message: Message) -> Result<()> {
72 self.sink.send(message).await.map_err(Into::into)
73 }
74
75 #[doc(hidden)]
80 pub async fn new(
81 upgrade: impl Into<Upgrade>,
82 config: Option<WebSocketConfig>,
83 role: Role,
84 ) -> Self {
85 let mut upgrade = upgrade.into();
86 let request_headers = upgrade.take_request_headers();
87 let path = upgrade.path().to_string().into();
88 let method = upgrade.method();
89 let state = upgrade.take_state();
90 let context = upgrade.context().clone();
91 let peer_ip = upgrade.peer_ip();
92 let (buffer, transport) = upgrade.into_transport();
93
94 let wss = if buffer.is_empty() {
95 WebSocketStream::from_raw_socket(transport, role, config).await
96 } else {
97 WebSocketStream::from_partially_read(transport, buffer, role, config).await
98 };
99
100 let (sink, stream) = wss.split();
101 let stream = Some(WStream {
102 stream: context.swansong().interrupt(stream),
103 });
104
105 Self {
106 request_headers,
107 path,
108 method,
109 state,
110 peer_ip,
111 sink,
112 stream,
113 context,
114 }
115 }
116
117 pub fn swansong(&self) -> Swansong {
119 self.context.swansong().clone()
120 }
121
122 pub async fn close(&mut self) -> Result<()> {
124 self.send(Message::Close(None)).await
125 }
126
127 pub fn headers(&self) -> &Headers {
129 &self.request_headers
130 }
131
132 pub fn peer_ip(&self) -> Option<IpAddr> {
134 self.peer_ip
135 }
136
137 pub fn set_peer_ip(&mut self, peer_ip: Option<IpAddr>) -> &mut Self {
139 self.peer_ip = peer_ip;
140 self
141 }
142
143 pub fn path(&self) -> &str {
146 self.path.split('?').next().unwrap_or_default()
147 }
148
149 pub fn querystring(&self) -> &str {
152 self.path
153 .split_once('?')
154 .map(|(_, query)| query)
155 .unwrap_or_default()
156 }
157
158 pub fn method(&self) -> Method {
160 self.method
161 }
162
163 pub fn state<T: Send + Sync + 'static>(&self) -> Option<&T> {
168 self.state.get()
169 }
170
171 pub fn state_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
173 self.state.get_mut()
174 }
175
176 pub fn insert_state<T: Send + Sync + 'static>(&mut self, state: T) -> Option<T> {
180 self.state.insert(state)
181 }
182
183 pub fn state_entry<T: Send + Sync + 'static>(&mut self) -> Entry<'_, T> {
186 self.state.entry()
187 }
188
189 pub fn take_state<T: Send + Sync + 'static>(&mut self) -> Option<T> {
194 self.state.take()
195 }
196
197 pub fn take_inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + use<>> {
199 self.stream.take()
200 }
201
202 pub fn inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + '_> {
204 self.stream.as_mut()
205 }
206}
207
208type MessageResult = std::result::Result<Message, tungstenite::Error>;
209
210pub struct WStream {
211 stream: Interrupt<WebSocketReceiver<Box<dyn Transport>>>,
212}
213
214impl Debug for WStream {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 f.debug_struct("WStream").finish_non_exhaustive()
217 }
218}
219
220impl Stream for WStream {
221 type Item = MessageResult;
222
223 fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
224 self.stream.poll_next(cx)
225 }
226}
227
228impl AsMut<TypeSet> for WebSocketConn {
229 fn as_mut(&mut self) -> &mut TypeSet {
230 &mut self.state
231 }
232}
233
234impl AsRef<TypeSet> for WebSocketConn {
235 fn as_ref(&self) -> &TypeSet {
236 &self.state
237 }
238}
239
240impl Stream for WebSocketConn {
241 type Item = MessageResult;
242
243 fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
244 match self.stream.as_mut() {
245 Some(stream) => stream.poll_next(cx),
246 None => Poll::Ready(None),
247 }
248 }
249}