1use crate::{Conn, ResponseBody};
19use futures_lite::{AsyncRead, stream::Stream};
20use std::{
21 collections::VecDeque,
22 error::Error,
23 fmt::{self, Debug, Display, Formatter},
24 ops::{Deref, DerefMut},
25 pin::Pin,
26 task::{Context, Poll, ready},
27 time::Duration,
28};
29use trillium_http::{KnownHeaderName, Status};
30
/// Length in bytes (8 KiB) of the scratch buffer that [`EventStream`] allocates once and passes to
/// every read of the response body.
const READ_BUF_LEN: usize = 8 * 1024;
32
33impl Conn {
34 pub async fn into_sse(mut self) -> Result<EventStream, SseError> {
49 if self.status().is_some() {
50 return Err(SseError::new(self, SseErrorKind::AlreadyExecuted));
51 }
52
53 self.request_headers_mut()
54 .try_insert(KnownHeaderName::Accept, "text/event-stream");
55
56 if let Err(e) = (&mut self).await {
57 return Err(SseError::new(self, e.into()));
58 }
59
60 let status = self.status().expect("Response did not include status");
61 if !status.is_success() {
62 return Err(SseError::new(self, SseErrorKind::Status(status)));
63 }
64
65 if !is_event_stream(
66 self.response_headers()
67 .get_str(KnownHeaderName::ContentType),
68 ) {
69 let content_type = self
70 .response_headers()
71 .get_str(KnownHeaderName::ContentType)
72 .map(String::from);
73 return Err(SseError::new(
74 self,
75 SseErrorKind::UnexpectedContentType(content_type),
76 ));
77 }
78
79 match self.take_response_body() {
80 Some(body) => Ok(EventStream::new(body)),
81 None => Err(SseError::new(self, SseErrorKind::NoBody)),
82 }
83 }
84}
85
/// Returns `true` when `content_type` names the `text/event-stream` media type.
///
/// Only the portion before the first `;` is considered, so parameters such as `charset` are
/// ignored. Surrounding whitespace is trimmed and the comparison is ASCII case-insensitive. A
/// missing header (`None`) is never an event stream.
fn is_event_stream(content_type: Option<&str>) -> bool {
    match content_type {
        None => false,
        Some(header) => {
            let media_type = header
                .split_once(';')
                .map_or(header, |(before_params, _)| before_params);
            media_type.trim().eq_ignore_ascii_case("text/event-stream")
        }
    }
}
95
/// A single server-sent event, as dispatched by the decoder at a blank line.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Event {
    /// The `data` field values joined with `\n`, without a trailing newline.
    data: String,
    /// The value of the `event` field, if one was sent and was not empty.
    event_type: Option<String>,
    /// The most recent `id` field value seen on the stream so far.
    id: Option<String>,
    /// The `retry` field value, in milliseconds, that preceded this event.
    retry: Option<Duration>,
}

impl Event {
    /// The event payload: all `data` lines of the event joined with `\n`.
    #[must_use]
    pub fn data(&self) -> &str {
        self.data.as_str()
    }

    /// The event type from the `event` field, or `None` when it was absent or empty.
    #[must_use]
    pub fn event_type(&self) -> Option<&str> {
        let Self { event_type, .. } = self;
        event_type.as_deref()
    }

    /// The last event id in effect when this event was dispatched, if any.
    #[must_use]
    pub fn id(&self) -> Option<&str> {
        let Self { id, .. } = self;
        id.as_deref()
    }

    /// The duration parsed from a `retry` field (milliseconds) that preceded this event.
    ///
    /// Unlike the id, this value is attached to a single event and is not repeated on later
    /// ones.
    #[must_use]
    pub fn retry(&self) -> Option<Duration> {
        let Self { retry, .. } = self;
        *retry
    }
}
140
141#[derive(Debug)]
148pub struct EventStream {
149 body: ResponseBody<'static>,
150 decoder: Decoder,
151 pending: VecDeque<Event>,
152 read_buf: Box<[u8]>,
153 done: bool,
154}
155
156impl EventStream {
157 fn new(body: ResponseBody<'static>) -> Self {
158 Self {
159 body,
160 decoder: Decoder::default(),
161 pending: VecDeque::new(),
162 read_buf: vec![0; READ_BUF_LEN].into_boxed_slice(),
163 done: false,
164 }
165 }
166}
167
168impl Stream for EventStream {
169 type Item = trillium_http::Result<Event>;
170
171 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
172 let this = self.get_mut();
173 loop {
174 if let Some(event) = this.pending.pop_front() {
175 return Poll::Ready(Some(Ok(event)));
176 }
177 if this.done {
178 return Poll::Ready(None);
179 }
180 match ready!(Pin::new(&mut this.body).poll_read(cx, &mut this.read_buf)) {
181 Ok(0) => {
183 this.done = true;
184 return Poll::Ready(None);
185 }
186 Ok(n) => this.decoder.push(&this.read_buf[..n], &mut this.pending),
187 Err(e) => {
188 this.done = true;
189 return Poll::Ready(Some(Err(e.into())));
190 }
191 }
192 }
193 }
194}
195
196#[derive(Debug, Default)]
202struct Decoder {
203 line: Vec<u8>,
204 last_char_was_cr: bool,
205 data: String,
206 event_type: Option<String>,
207 id: Option<String>,
208 retry: Option<Duration>,
209 has_data: bool,
210}
211
212impl Decoder {
213 fn push(&mut self, bytes: &[u8], out: &mut VecDeque<Event>) {
214 for &byte in bytes {
215 match byte {
216 b'\r' => {
217 self.line_done(out);
218 self.last_char_was_cr = true;
219 }
220 b'\n' if self.last_char_was_cr => self.last_char_was_cr = false,
221 b'\n' => self.line_done(out),
222 _ => {
223 self.last_char_was_cr = false;
224 self.line.push(byte);
225 }
226 }
227 }
228 }
229
230 fn line_done(&mut self, out: &mut VecDeque<Event>) {
231 if self.line.is_empty() {
232 self.dispatch(out);
233 } else {
234 let mut line = std::mem::take(&mut self.line);
235 self.process_field(&line);
236 line.clear();
237 self.line = line;
238 }
239 }
240
241 fn process_field(&mut self, line: &[u8]) {
242 let (field, value) = match line.iter().position(|&b| b == b':') {
243 Some(0) => return, Some(colon) => {
245 let value = &line[colon + 1..];
246 let value = value.strip_prefix(b" ").unwrap_or(value);
247 (&line[..colon], value)
248 }
249 None => (line, &b""[..]),
250 };
251
252 match field {
253 b"event" => self.event_type = Some(String::from_utf8_lossy(value).into_owned()),
254 b"data" => {
255 self.data.push_str(&String::from_utf8_lossy(value));
256 self.data.push('\n');
257 self.has_data = true;
258 }
259 b"id" => {
260 if !value.contains(&0) {
261 self.id = Some(String::from_utf8_lossy(value).into_owned());
262 }
263 }
264 b"retry" => {
265 if !value.is_empty()
266 && value.iter().all(u8::is_ascii_digit)
267 && let Ok(ms) = std::str::from_utf8(value).unwrap_or_default().parse()
268 {
269 self.retry = Some(Duration::from_millis(ms));
270 }
271 }
272 _ => {}
273 }
274 }
275
276 fn dispatch(&mut self, out: &mut VecDeque<Event>) {
277 if !self.has_data {
278 self.data.clear();
281 self.event_type = None;
282 return;
283 }
284
285 if self.data.ends_with('\n') {
286 self.data.pop();
287 }
288
289 out.push_back(Event {
290 data: std::mem::take(&mut self.data),
291 event_type: self.event_type.take().filter(|s| !s.is_empty()),
292 id: self.id.clone(),
293 retry: self.retry.take(),
294 });
295 self.has_data = false;
296 }
297}
298
#[derive(thiserror::Error, Debug)]
/// The reason a call to `Conn::into_sse` failed.
#[non_exhaustive]
pub enum SseErrorKind {
    #[error(transparent)]
    /// Awaiting the conn returned an error; its message is displayed unchanged.
    Http(#[from] trillium_http::Error),

    #[error("Unexpected response status {0} for SSE request")]
    /// The response arrived with a status that is not a success status.
    Status(Status),

    #[error("Unexpected content-type for SSE request: {0:?}")]
    /// The response content type was missing (`None`) or was not `text/event-stream`; the
    /// received header value is carried along.
    UnexpectedContentType(Option<String>),

    #[error(
        "Conn::into_sse called after execution — build the conn and await into_sse instead of \
         awaiting the conn separately"
    )]
    /// The conn already had a response status when `into_sse` was called.
    AlreadyExecuted,

    #[error("SSE response had no body")]
    /// No response body could be taken from the conn.
    NoBody,
}
328
329#[derive(Debug)]
334pub struct SseError {
335 pub kind: SseErrorKind,
337 conn: Box<Conn>,
338}
339
340impl SseError {
341 fn new(conn: Conn, kind: SseErrorKind) -> Self {
342 Self {
343 kind,
344 conn: Box::new(conn),
345 }
346 }
347}
348
349impl From<SseError> for Conn {
350 fn from(value: SseError) -> Self {
351 *value.conn
352 }
353}
354
355impl Deref for SseError {
356 type Target = Conn;
357
358 fn deref(&self) -> &Self::Target {
359 &self.conn
360 }
361}
362
363impl DerefMut for SseError {
364 fn deref_mut(&mut self) -> &mut Self::Target {
365 &mut self.conn
366 }
367}
368
369impl Error for SseError {}
370
371impl Display for SseError {
372 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
373 Display::fmt(&self.kind, f)
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `input` twice, once in a single `push` and once one byte per `push`, asserts
    /// that both runs produce the same events, and returns them.
    fn decode(input: &[u8]) -> Vec<Event> {
        let mut at_once = VecDeque::new();
        Decoder::default().push(input, &mut at_once);

        let mut byte_wise = VecDeque::new();
        let mut decoder = Decoder::default();
        input
            .chunks(1)
            .for_each(|chunk| decoder.push(chunk, &mut byte_wise));

        assert_eq!(at_once, byte_wise, "chunked decode diverged from whole");
        Vec::from(at_once)
    }

    #[test]
    fn fields_comments_and_terminators() {
        let decoded =
            decode(b": this is a comment\nevent: greeting\ndata: hello\nid: 42\nretry: 3000\n\n");
        assert_eq!(decoded.len(), 1);

        let greeting = &decoded[0];
        assert_eq!(greeting.data(), "hello");
        assert_eq!(greeting.event_type(), Some("greeting"));
        assert_eq!(greeting.id(), Some("42"));
        assert_eq!(greeting.retry(), Some(Duration::from_millis(3000)));
    }

    #[test]
    fn multiline_data_joins_with_newline() {
        // The last line has no space after the colon and must decode the same way.
        let decoded = decode(b"data: one\ndata: two\ndata:three\n\n");
        assert_eq!(decoded[0].data(), "one\ntwo\nthree");
    }

    #[test]
    fn crlf_and_cr_terminators() {
        let with_crlf = decode(b"data: a\r\n\r\n");
        assert_eq!(with_crlf[0].data(), "a");

        let with_cr = decode(b"data: b\r\r");
        assert_eq!(with_cr[0].data(), "b");
    }

    #[test]
    fn empty_data_line_dispatches_empty_event() {
        // A bare `data` field name counts as a data field with an empty value.
        let decoded = decode(b"data\n\n");
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].data(), "");
    }

    #[test]
    fn blank_lines_without_data_dispatch_nothing() {
        let only_blank_lines = decode(b"\n\n\n");
        assert!(only_blank_lines.is_empty());

        let only_comment = decode(b": just a comment\n\n");
        assert!(only_comment.is_empty());
    }

    #[test]
    fn incomplete_trailing_event_is_discarded() {
        // No blank line follows, so the event is never dispatched.
        let decoded = decode(b"data: pending\n");
        assert!(decoded.is_empty());
    }

    #[test]
    fn id_persists_across_events_retry_does_not() {
        let decoded = decode(b"id: 1\nretry: 500\ndata: a\n\ndata: b\n\n");
        let (first, second) = (&decoded[0], &decoded[1]);

        assert_eq!(first.id(), Some("1"));
        assert_eq!(first.retry(), Some(Duration::from_millis(500)));
        assert_eq!(second.id(), Some("1"));
        assert_eq!(second.retry(), None);
    }

    #[test]
    fn invalid_retry_is_ignored() {
        let decoded = decode(b"retry: not-a-number\ndata: a\n\n");
        assert_eq!(decoded[0].retry(), None);
    }
}