trillium_client/
websocket.rs1use crate::{Conn, WebSocketConfig, WebSocketConn};
4use std::{
5 borrow::Cow,
6 error::Error,
7 fmt::{self, Display},
8 ops::{Deref, DerefMut},
9};
10use trillium_http::{
11 KnownHeaderName::{
12 Connection, SecWebsocketAccept, SecWebsocketKey, SecWebsocketVersion,
13 Upgrade as UpgradeHeader,
14 },
15 Method, Status, Upgrade, Version,
16};
17pub use trillium_websockets::Message;
18use trillium_websockets::{Role, websocket_accept_hash, websocket_key};
19
20impl Conn {
21 fn set_websocket_upgrade_headers_h1(&mut self) {
22 let headers = self.request_headers_mut();
23 headers.try_insert(UpgradeHeader, "websocket");
24 headers.try_insert(Connection, "upgrade");
25 headers.try_insert(SecWebsocketVersion, "13");
26 headers.try_insert(SecWebsocketKey, websocket_key());
27 }
28
29 pub async fn into_websocket(self) -> Result<WebSocketConn, WebSocketUpgradeError> {
44 self.into_websocket_with_config(WebSocketConfig::default())
45 .await
46 }
47
48 pub async fn into_websocket_with_config(
50 self,
51 config: WebSocketConfig,
52 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
53 if self.status().is_some() {
54 return Err(WebSocketUpgradeError::new(self, ErrorKind::AlreadyExecuted));
55 }
56
57 match self.http_version() {
58 Version::Http2 => self.into_websocket_extended_connect(config).await,
59 Version::Http3 => Err(WebSocketUpgradeError::new(
60 self,
61 ErrorKind::ExtendedConnectUnsupported,
62 )),
63 _ => self.into_websocket_h1(config).await,
64 }
65 }
66
67 async fn into_websocket_h1(
68 mut self,
69 config: WebSocketConfig,
70 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
71 self.set_websocket_upgrade_headers_h1();
72 if let Err(e) = (&mut self).await {
73 return Err(WebSocketUpgradeError::new(self, e.into()));
74 }
75 let status = self.status().expect("Response did not include status");
76 if status != Status::SwitchingProtocols {
77 return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
78 }
79 let key = self
80 .request_headers()
81 .get_str(SecWebsocketKey)
82 .expect("Request did not include Sec-WebSocket-Key");
83 let accept_key = websocket_accept_hash(key);
84 if self.response_headers().get_str(SecWebsocketAccept) != Some(&accept_key) {
85 return Err(WebSocketUpgradeError::new(self, ErrorKind::InvalidAccept));
86 }
87 let peer_ip = self.peer_addr().map(|addr| addr.ip());
88 let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
89 conn.set_peer_ip(peer_ip);
90 Ok(conn)
91 }
92
93 async fn into_websocket_extended_connect(
94 mut self,
95 config: WebSocketConfig,
96 ) -> Result<WebSocketConn, WebSocketUpgradeError> {
97 self.request_headers_mut()
103 .try_insert(SecWebsocketVersion, "13");
104 self.set_method(Method::Connect);
105 self.protocol = Some(Cow::Borrowed("websocket"));
106
107 if let Err(e) = (&mut self).await {
113 let kind = match e {
114 trillium_http::Error::ExtendedConnectUnsupported => {
115 ErrorKind::ExtendedConnectUnsupported
116 }
117 other => other.into(),
118 };
119 return Err(WebSocketUpgradeError::new(self, kind));
120 }
121
122 let status = self.status().expect("Response did not include status");
123 if status != Status::Ok {
124 return Err(WebSocketUpgradeError::new(self, ErrorKind::Status(status)));
125 }
126
127 let peer_ip = self.peer_addr().map(|addr| addr.ip());
128 let mut conn = WebSocketConn::new(Upgrade::from(self), Some(config), Role::Client).await;
129 conn.set_peer_ip(peer_ip);
130 Ok(conn)
131 }
132}
133
134#[derive(thiserror::Error, Debug)]
136#[non_exhaustive]
137pub enum ErrorKind {
138 #[error(transparent)]
140 Http(#[from] trillium_http::Error),
141
142 #[error("Unexpected response status {0} for websocket upgrade")]
145 Status(Status),
146
147 #[error("Response Sec-WebSocket-Accept was missing or invalid")]
149 InvalidAccept,
150
151 #[error(
155 "Conn::into_websocket called after execution — build the conn and await into_websocket \
156 instead of awaiting the conn separately"
157 )]
158 AlreadyExecuted,
159
160 #[error("peer does not support extended CONNECT, or h3 client websocket framing is missing")]
167 ExtendedConnectUnsupported,
168}
169
170#[derive(Debug)]
175pub struct WebSocketUpgradeError {
176 pub kind: ErrorKind,
178 conn: Box<Conn>,
179}
180
181impl WebSocketUpgradeError {
182 fn new(conn: Conn, kind: ErrorKind) -> Self {
183 let conn = Box::new(conn);
184 Self { conn, kind }
185 }
186}
187
188impl From<WebSocketUpgradeError> for Conn {
189 fn from(value: WebSocketUpgradeError) -> Self {
190 *value.conn
191 }
192}
193
194impl Deref for WebSocketUpgradeError {
195 type Target = Conn;
196
197 fn deref(&self) -> &Self::Target {
198 &self.conn
199 }
200}
201impl DerefMut for WebSocketUpgradeError {
202 fn deref_mut(&mut self) -> &mut Self::Target {
203 &mut self.conn
204 }
205}
206
207impl Error for WebSocketUpgradeError {}
208
209impl Display for WebSocketUpgradeError {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 self.kind.fmt(f)
212 }
213}