trillium_client/
websocket.rs1use crate::{Conn, WebSocketConfig, WebSocketConn};
4use std::{
5 error::Error,
6 fmt::{self, Display},
7 ops::{Deref, DerefMut},
8};
9use trillium_http::{
10 KnownHeaderName::{
11 Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion,
12 Upgrade as UpgradeHeader,
13 },
14 Status, Upgrade,
15};
16pub use trillium_websockets::Message;
17use trillium_websockets::{Role, websocket_accept_hash, websocket_key};
18
19impl Conn {
20 fn set_websocket_upgrade_headers(&mut self) {
21 let headers = self.request_headers_mut();
22 headers.try_insert(UpgradeHeader, "websocket");
23 headers.try_insert(Connection, "upgrade");
24 headers.try_insert(SecWebsocketVersion, "13");
25 headers.try_insert(SecWebsocketKey, websocket_key());
26 }
27
28 pub async fn into_websocket(self) -> Result<WebSocketConn, WebSocketUpgradeError> {
38 self.into_websocket_with_config(WebSocketConfig::default())
39 .await
40 }
41
42 pub async fn into_websocket_with_config(
52 mut self,
53 config: WebSocketConfig,
54 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
55 let status = match self.status() {
56 Some(status) => status,
57 None => {
58 self.set_websocket_upgrade_headers();
59 if let Err(e) = (&mut self).await {
60 return Err(WebSocketUpgradeError::new(self, e.into()));
61 }
62 self.status().expect("Response did not include status")
63 }
64 };
65 if status != Status::SwitchingProtocols {
66 return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
67 }
68 let key = self
69 .request_headers()
70 .get_str(SecWebsocketKey)
71 .expect("Request did not include Sec-WebSocket-Key");
72 let accept_key = websocket_accept_hash(key);
73 if self.response_headers().get_str(SecWebsocketAccept) != Some(&accept_key) {
74 return Err(WebSocketUpgradeError::new(self, ErrorKind::InvalidAccept));
75 }
76 let peer_ip = self.peer_addr().map(|addr| addr.ip());
77 let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
78 conn.set_peer_ip(peer_ip);
79 Ok(conn)
80 }
81}
82
83#[derive(thiserror::Error, Debug)]
85#[non_exhaustive]
86pub enum ErrorKind {
87 #[error(transparent)]
89 Http(#[from] trillium_http::Error),
90
91 #[error("Expected status 101 (Switching Protocols), got {0}")]
93 Status(Status),
94
95 #[error("Response Sec-WebSocket-Accept was missing or invalid")]
97 InvalidAccept,
98}
99
100#[derive(Debug)]
105pub struct WebSocketUpgradeError {
106 pub kind: ErrorKind,
108 conn: Box<Conn>,
109}
110
111impl WebSocketUpgradeError {
112 fn new(conn: Conn, kind: ErrorKind) -> Self {
113 let conn = Box::new(conn);
114 Self { conn, kind }
115 }
116}
117
118impl From<WebSocketUpgradeError> for Conn {
119 fn from(value: WebSocketUpgradeError) -> Self {
120 *value.conn
121 }
122}
123
124impl Deref for WebSocketUpgradeError {
125 type Target = Conn;
126
127 fn deref(&self) -> &Self::Target {
128 &self.conn
129 }
130}
131impl DerefMut for WebSocketUpgradeError {
132 fn deref_mut(&mut self) -> &mut Self::Target {
133 &mut self.conn
134 }
135}
136
137impl Error for WebSocketUpgradeError {}
138
139impl Display for WebSocketUpgradeError {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 self.kind.fmt(f)
142 }
143}