use crate::{Result, WebSocketConfig, WebsocketPeerIp};
use async_tungstenite::{
tungstenite::{self, protocol::Role, Message},
WebSocketStream,
};
use futures_util::{
stream::{SplitSink, SplitStream, Stream},
SinkExt, StreamExt,
};
use std::{
net::IpAddr,
pin::Pin,
task::{Context, Poll},
};
use stopper::{Stopper, StreamStopper};
use trillium::{Headers, Method, StateSet, Upgrade};
use trillium_http::transport::BoxedTransport;
/// A server-side WebSocket connection created from an HTTP [`Upgrade`].
///
/// It retains the request metadata from the upgrade (headers, path, method, state). It owns the
/// write half of the socket directly. The read half is held as an optional, stopper-wrapped
/// inbound stream that callers may borrow or take.
#[derive(Debug)]
pub struct WebSocketConn {
    /// Headers of the HTTP request that was upgraded.
    request_headers: Headers,
    /// Request path as received, including any `?query` suffix.
    path: String,
    /// HTTP method of the upgraded request.
    method: Method,
    /// Type-keyed state carried over from the upgrade.
    state: StateSet,
    /// Stopper shared with the upgrade; also used to wrap the inbound stream.
    stopper: Stopper,
    /// Write half of the split WebSocket stream.
    sink: SplitSink<Wss, Message>,
    /// Read half of the split WebSocket stream; `None` once taken by `take_inbound_stream`.
    stream: Option<WStream>,
}
/// The async-tungstenite WebSocket stream running over trillium's boxed transport.
type Wss = WebSocketStream<BoxedTransport>;
impl WebSocketConn {
    /// Sends `string` as a text frame.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be sent on the underlying sink.
    pub async fn send_string(&mut self, string: String) -> Result<()> {
        // `send` already yields `crate::Result`, so no further error conversion is needed.
        self.send(Message::Text(string)).await
    }

    /// Sends `bin` as a binary frame.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be sent on the underlying sink.
    pub async fn send_bytes(&mut self, bin: Vec<u8>) -> Result<()> {
        self.send(Message::Binary(bin)).await
    }

    /// Serializes `json` with `serde_json` and sends the result as a text frame.
    ///
    /// # Errors
    ///
    /// Returns an error if serialization fails or if the text frame cannot be sent.
    #[cfg(feature = "json")]
    pub async fn send_json(&mut self, json: &impl serde::Serialize) -> Result<()> {
        self.send_string(serde_json::to_string(json)?).await
    }

    /// Sends an arbitrary [`Message`] on the write half of the socket.
    ///
    /// This uses `SinkExt::send`, which also flushes the sink before completing.
    ///
    /// # Errors
    ///
    /// Returns the underlying tungstenite error, converted into the crate's error type.
    pub async fn send(&mut self, message: Message) -> Result<()> {
        self.sink.send(message).await.map_err(Into::into)
    }

    /// Builds a server-role WebSocket connection from an HTTP [`Upgrade`].
    ///
    /// `config` is forwarded unchanged to async-tungstenite.
    pub(crate) async fn new(upgrade: Upgrade, config: Option<WebSocketConfig>) -> Self {
        let Upgrade {
            request_headers,
            path,
            method,
            state,
            buffer,
            transport,
            stopper,
        } = upgrade;

        // If the upgrade carries bytes that were already read, build the stream with
        // `from_partially_read` so those bytes are processed before the transport.
        let wss = if let Some(vec) = buffer {
            WebSocketStream::from_partially_read(transport, vec, Role::Server, config).await
        } else {
            WebSocketStream::from_raw_socket(transport, Role::Server, config).await
        };

        // Split into independent write and read halves. The read half is wrapped with the
        // stopper (via `stop_stream`) so that inbound reading can be halted through it.
        let (sink, stream) = wss.split();
        let stream = Some(WStream {
            stream: stopper.stop_stream(stream),
        });

        Self {
            request_headers,
            path,
            method,
            state,
            sink,
            stream,
            stopper,
        }
    }

    /// Returns a clone of this connection's [`Stopper`].
    pub fn stopper(&self) -> Stopper {
        self.stopper.clone()
    }

    /// Sends a close frame with no close code or reason.
    ///
    /// This only sends the frame; it does not wait for the peer's reply.
    ///
    /// # Errors
    ///
    /// Returns an error if the close frame cannot be sent.
    pub async fn close(&mut self) -> Result<()> {
        self.send(Message::Close(None)).await
    }

    /// Returns the headers of the upgraded HTTP request.
    pub fn headers(&self) -> &Headers {
        &self.request_headers
    }

    /// Returns the peer IP stored in state as a `WebsocketPeerIp`.
    ///
    /// Returns `None` if no `WebsocketPeerIp` is in state, or if it holds no address.
    pub fn peer_ip(&self) -> Option<IpAddr> {
        self.state.get::<WebsocketPeerIp>().and_then(|i| i.0)
    }

    /// Returns the request path with any query string (from the first `?` on) removed.
    pub fn path(&self) -> &str {
        self.path
            .split_once('?')
            .map_or(self.path.as_str(), |(path, _)| path)
    }

    /// Returns the text after the first `?` in the request path.
    ///
    /// Returns `""` when the path has no `?`.
    pub fn querystring(&self) -> &str {
        self.path.split_once('?').map_or("", |(_, query)| query)
    }

    /// Returns the HTTP method of the upgraded request.
    pub fn method(&self) -> Method {
        self.method
    }

    /// Borrows the state value of type `T`, if one is present.
    pub fn state<T: 'static>(&self) -> Option<&T> {
        self.state.get()
    }

    /// Mutably borrows the state value of type `T`, if one is present.
    pub fn state_mut<T: 'static>(&mut self) -> Option<&mut T> {
        self.state.get_mut()
    }

    /// Stores `val` in the connection's state, keyed by its type.
    ///
    /// Any value previously returned by the insert is discarded.
    pub fn set_state<T: Send + Sync + 'static>(&mut self, val: T) {
        self.state.insert(val);
    }

    /// Removes and returns the state value of type `T`, if one is present.
    pub fn take_state<T: 'static>(&mut self) -> Option<T> {
        self.state.take()
    }

    /// Takes ownership of the inbound message stream.
    ///
    /// Returns `None` if it was already taken. After this call, `inbound_stream` returns `None`
    /// and polling the connection itself yields no further messages.
    pub fn take_inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult>> {
        self.stream.take()
    }

    /// Mutably borrows the inbound message stream.
    ///
    /// Returns `None` if it has been taken with `take_inbound_stream`.
    pub fn inbound_stream(&mut self) -> Option<impl Stream<Item = MessageResult> + '_> {
        self.stream.as_mut()
    }
}
/// Item yielded by inbound streams: either a received message or a tungstenite error.
///
/// The full `std::result::Result` path is spelled out because `Result` in this module refers to
/// the imported `crate::Result`.
type MessageResult = std::result::Result<Message, tungstenite::Error>;
/// Inbound half of a `WebSocketConn`.
///
/// It is the read half of the split WebSocket stream, wrapped by the connection's stopper via
/// `stop_stream`.
#[derive(Debug)]
pub struct WStream {
    /// Stopper-wrapped read half of the split socket.
    stream: StreamStopper<SplitStream<Wss>>,
}
impl Stream for WStream {
    type Item = MessageResult;

    /// Forwards polling to the wrapped, stopper-aware read half of the socket.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let inner = &mut self.get_mut().stream;
        Pin::new(inner).poll_next(cx)
    }
}
impl AsMut<StateSet> for WebSocketConn {
    /// Exposes the connection's state for mutation.
    fn as_mut(&mut self) -> &mut StateSet {
        let Self { state, .. } = self;
        state
    }
}
impl AsRef<StateSet> for WebSocketConn {
    /// Exposes the connection's state for reading.
    fn as_ref(&self) -> &StateSet {
        let Self { state, .. } = self;
        state
    }
}
/// Polling the connection directly reads from its inbound stream.
impl Stream for WebSocketConn {
    type Item = MessageResult;

    /// Polls the inbound stream, or reports end-of-stream if it has been taken.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Some(inbound) = self.get_mut().stream.as_mut() {
            inbound.poll_next_unpin(cx)
        } else {
            Poll::Ready(None)
        }
    }
}