1use crate::parse_utils::{parse_quoted_string, parse_token};
2use std::{
3 borrow::Cow,
4 fmt::{Display, Write},
5 net::IpAddr,
6};
7use trillium::{
8 Headers,
9 KnownHeaderName::{
10 Forwarded as ForwardedHeader, XforwardedBy, XforwardedFor, XforwardedHost, XforwardedProto,
11 XforwardedSsl,
12 },
13};
14
/// Forwarding information for a request: the `by`, `for`, `host` and `proto`
/// parameters of a `Forwarded` header, or the equivalent values taken from
/// `X-Forwarded-*` headers.
///
/// String values are stored as `Cow`s so that parsing can borrow from the
/// header it was read from; use [`Forwarded::into_owned`] to obtain a
/// `Forwarded<'static>` that no longer borrows.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Forwarded<'a> {
    /// The `by=` value, if one was provided.
    by: Option<Cow<'a, str>>,
    /// Every `for=` value, in the order it was parsed or added.
    forwarded_for: Vec<Cow<'a, str>>,
    /// The `host=` value, if one was provided.
    host: Option<Cow<'a, str>>,
    /// The `proto=` value, if one was provided.
    proto: Option<Cow<'a, str>>,
}
24
25impl<'a> Forwarded<'a> {
26 pub fn from_headers(headers: &'a Headers) -> Result<Option<Self>, ParseError> {
86 if let Some(forwarded) = Self::from_forwarded_header(headers)? {
87 Ok(Some(forwarded))
88 } else {
89 Self::from_x_headers(headers)
90 }
91 }
92
93 pub fn from_forwarded_header(headers: &'a Headers) -> Result<Option<Self>, ParseError> {
124 if let Some(headers) = headers.get_str(ForwardedHeader) {
125 Ok(Some(Self::parse(headers)?))
126 } else {
127 Ok(None)
128 }
129 }
130
131 pub fn from_x_headers(headers: &'a Headers) -> Result<Option<Self>, ParseError> {
162 let forwarded_for: Vec<Cow<'a, str>> = headers
163 .get_str(XforwardedFor)
164 .map(|hv| {
165 hv.split(',')
166 .map(|v| {
167 let v = v.trim();
168 match v.parse::<IpAddr>().ok() {
169 Some(IpAddr::V6(v6)) => Cow::Owned(format!(r#"[{v6}]"#)),
170 _ => Cow::Borrowed(v),
171 }
172 })
173 .collect()
174 })
175 .unwrap_or_default();
176
177 let by = headers.get_str(XforwardedBy).map(Cow::Borrowed);
178
179 let proto = headers
180 .get_str(XforwardedProto)
181 .map(Cow::Borrowed)
182 .or_else(|| {
183 if headers.eq_ignore_ascii_case(XforwardedSsl, "on") {
184 Some(Cow::Borrowed("https"))
185 } else {
186 None
187 }
188 });
189
190 let host = headers.get_str(XforwardedHost).map(Cow::Borrowed);
191
192 if !forwarded_for.is_empty() || by.is_some() || proto.is_some() || host.is_some() {
193 Ok(Some(Self {
194 forwarded_for,
195 by,
196 proto,
197 host,
198 }))
199 } else {
200 Ok(None)
201 }
202 }
203
204 pub fn parse(input: &'a str) -> Result<Self, ParseError> {
226 let mut input = input;
227 let mut forwarded = Forwarded::new();
228
229 while !input.is_empty() {
230 input = if starts_with_ignore_case("for=", input) {
231 forwarded.parse_for(input)?
232 } else {
233 forwarded.parse_forwarded_pair(input)?
234 }
235 }
236
237 Ok(forwarded)
238 }
239
240 fn parse_forwarded_pair(&mut self, input: &'a str) -> Result<&'a str, ParseError> {
241 let (key, value, rest) = match parse_token(input) {
242 (Some(key), rest) if rest.starts_with('=') => match parse_value(&rest[1..]) {
243 (Some(value), rest) => Some((key, value, rest)),
244 (None, _) => None,
245 },
246 _ => None,
247 }
248 .ok_or_else(|| ParseError::new("parse error in forwarded-pair"))?;
249
250 match key {
251 "by" => {
252 if self.by.is_some() {
253 return Err(ParseError::new("parse error, duplicate `by` key"));
254 }
255 self.by = Some(value);
256 }
257
258 "host" => {
259 if self.host.is_some() {
260 return Err(ParseError::new("parse error, duplicate `host` key"));
261 }
262 self.host = Some(value);
263 }
264
265 "proto" => {
266 if self.proto.is_some() {
267 return Err(ParseError::new("parse error, duplicate `proto` key"));
268 }
269 self.proto = Some(value);
270 }
271
272 _ => { }
273 }
274
275 match rest.strip_prefix(';') {
276 Some(rest) => Ok(rest),
277 None => Ok(rest),
278 }
279 }
280
281 fn parse_for(&mut self, input: &'a str) -> Result<&'a str, ParseError> {
282 let mut rest = input;
283
284 loop {
285 rest = match match_ignore_case("for=", rest) {
286 (true, rest) => rest,
287 (false, _) => return Err(ParseError::new("http list must start with for=")),
288 };
289
290 let (value, rest_) = parse_value(rest);
291 rest = rest_;
292
293 if let Some(value) = value {
294 self.forwarded_for.push(value);
296 } else {
297 return Err(ParseError::new("for= without valid value"));
298 }
299
300 match rest.chars().next() {
301 Some(',') => {
303 rest = rest[1..].trim_start();
304 }
305
306 Some(';') => return Ok(&rest[1..]),
308
309 None => return Ok(rest),
311
312 _ => return Err(ParseError::new("unexpected character after for= section")),
314 }
315 }
316 }
317
318 pub fn into_owned(self) -> Forwarded<'static> {
321 Forwarded {
322 by: self.by.map(|by| Cow::Owned(by.into_owned())),
323 forwarded_for: self
324 .forwarded_for
325 .into_iter()
326 .map(|ff| Cow::Owned(ff.into_owned()))
327 .collect(),
328 host: self.host.map(|h| Cow::Owned(h.into_owned())),
329 proto: self.proto.map(|p| Cow::Owned(p.into_owned())),
330 }
331 }
332
333 pub fn new() -> Self {
335 Self::default()
336 }
337
338 pub fn add_for(&mut self, forwarded_for: impl Into<Cow<'a, str>>) {
340 self.forwarded_for.push(forwarded_for.into());
341 }
342
343 pub fn forwarded_for(&self) -> Vec<&str> {
345 self.forwarded_for.iter().map(|x| x.as_ref()).collect()
346 }
347
348 pub fn set_host(&mut self, host: impl Into<Cow<'a, str>>) {
350 self.host = Some(host.into());
351 }
352
353 pub fn host(&self) -> Option<&str> {
355 self.host.as_deref()
356 }
357
358 pub fn set_proto(&mut self, proto: impl Into<Cow<'a, str>>) {
360 self.proto = Some(proto.into())
361 }
362
363 pub fn proto(&self) -> Option<&str> {
365 self.proto.as_deref()
366 }
367
368 pub fn set_by(&mut self, by: impl Into<Cow<'a, str>>) {
370 self.by = Some(by.into());
371 }
372
373 pub fn by(&self) -> Option<&str> {
375 self.by.as_deref()
376 }
377}
378
379impl Display for Forwarded<'_> {
380 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
381 let mut needs_semi = false;
382 if let Some(by) = self.by() {
383 needs_semi = true;
384 write!(f, "by={}", format_value(by))?;
385 }
386
387 if !self.forwarded_for.is_empty() {
388 if needs_semi {
389 f.write_char(';')?;
390 }
391 needs_semi = true;
392 f.write_str(
393 &self
394 .forwarded_for
395 .iter()
396 .map(|f| format!("for={}", format_value(f)))
397 .collect::<Vec<_>>()
398 .join(", "),
399 )?;
400 }
401
402 if let Some(host) = self.host() {
403 if needs_semi {
404 f.write_char(';')?;
405 }
406 needs_semi = true;
407 write!(f, "host={}", format_value(host))?
408 }
409
410 if let Some(proto) = self.proto() {
411 if needs_semi {
412 f.write_char(';')?;
413 }
414 write!(f, "proto={}", format_value(proto))?
415 }
416
417 Ok(())
418 }
419}
420
421fn parse_value(input: &str) -> (Option<Cow<'_, str>>, &str) {
422 match parse_token(input) {
423 (Some(token), rest) => (Some(Cow::Borrowed(token)), rest),
424 (None, rest) => parse_quoted_string(rest),
425 }
426}
427
428fn format_value(input: &str) -> Cow<'_, str> {
429 match parse_token(input) {
430 (_, "") => input.into(),
431 _ => {
432 let mut string = String::from("\"");
433 for ch in input.chars() {
434 if let '\\' | '"' = ch {
435 string.push('\\');
436 }
437 string.push(ch);
438 }
439 string.push('"');
440 string.into()
441 }
442 }
443}
444
/// Strips the prefix `start` from `input`, ignoring ASCII case.
///
/// Returns `(true, remainder)` when `input` begins with `start` and
/// `(false, input)` otherwise. An input shorter than `start`, or one whose
/// first `start.len()` bytes do not end on a UTF-8 character boundary, does
/// not match. Such inputs used to panic, and they are reachable from
/// untrusted header values such as `for=a,` or `for=a,b`.
fn match_ignore_case<'a>(start: &'static str, input: &'a str) -> (bool, &'a str) {
    match input.get(..start.len()) {
        // `get` succeeded, so `start.len()` is a valid char boundary to slice at.
        Some(head) if head.eq_ignore_ascii_case(start) => (true, &input[start.len()..]),
        _ => (false, input),
    }
}
453
/// Returns whether `input` begins with `start`, ignoring ASCII case.
///
/// An input that is shorter than `start`, or whose first `start.len()` bytes
/// end in the middle of a multi-byte character (e.g. `abcé` against `for=`),
/// is reported as not matching instead of panicking.
fn starts_with_ignore_case(start: &'static str, input: &str) -> bool {
    matches!(input.get(..start.len()), Some(head) if head.eq_ignore_ascii_case(start))
}
462
/// Error returned when a `Forwarded` header value cannot be parsed.
///
/// It carries a static description of the problem. Its [`Display`] output is
/// that description prefixed with `unable to parse forwarded header: `.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParseError(&'static str);

impl ParseError {
    /// Creates a parse error with the given static description.
    pub const fn new(msg: &'static str) -> Self {
        Self(msg)
    }
}

impl std::error::Error for ParseError {}

impl Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "unable to parse forwarded header: {}", self.0)
    }
}
477
478impl<'a> TryFrom<&'a str> for Forwarded<'a> {
479 type Error = ParseError;
480
481 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
482 Self::parse(value)
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    type Result = std::result::Result<(), ParseError>;

    #[test]
    fn starts_with_ignore_case_can_handle_short_inputs() {
        assert!(!starts_with_ignore_case("helloooooo", "h"));
    }

    #[test]
    fn parsing_for() -> Result {
        assert_eq!(
            Forwarded::parse(r#"for="_gazonk""#)?.forwarded_for(),
            vec!["_gazonk"]
        );
        assert_eq!(
            Forwarded::parse(r#"For="[2001:db8:cafe::17]:4711""#)?.forwarded_for(),
            vec!["[2001:db8:cafe::17]:4711"]
        );

        assert_eq!(
            Forwarded::parse("for=192.0.2.60;proto=http;by=203.0.113.43")?.forwarded_for(),
            vec!["192.0.2.60"]
        );

        assert_eq!(
            Forwarded::parse("for=192.0.2.43, for=198.51.100.17")?.forwarded_for(),
            vec!["192.0.2.43", "198.51.100.17"]
        );

        assert_eq!(
            Forwarded::parse(r#"for=192.0.2.43,for="[2001:db8:cafe::17]",for=unknown"#)?
                .forwarded_for(),
            Forwarded::parse(r#"for=192.0.2.43, for="[2001:db8:cafe::17]", for=unknown"#)?
                .forwarded_for()
        );

        assert_eq!(
            Forwarded::parse(
                r#"for=192.0.2.43,for="this is a valid quoted-string, \" \\",for=unknown"#
            )?
            .forwarded_for(),
            vec![
                "192.0.2.43",
                r#"this is a valid quoted-string, " \"#,
                "unknown"
            ]
        );

        Ok(())
    }

    #[test]
    fn basic_parse() -> Result {
        let forwarded = Forwarded::parse("for=client.com;by=proxy.com;host=host.com;proto=https")?;

        assert_eq!(forwarded.by(), Some("proxy.com"));
        assert_eq!(forwarded.forwarded_for(), vec!["client.com"]);
        assert_eq!(forwarded.host(), Some("host.com"));
        assert_eq!(forwarded.proto(), Some("https"));

        // The parsed value must be indistinguishable from one built by hand.
        let mut expected = Forwarded::new();
        expected.add_for("client.com");
        expected.set_by("proxy.com");
        expected.set_host("host.com");
        expected.set_proto("https");
        assert_eq!(forwarded, expected);
        Ok(())
    }

    #[test]
    fn bad_parse() {
        let err = Forwarded::parse("by=proxy.com;for=client;host=example.com;host").unwrap_err();
        assert_eq!(
            err.to_string(),
            "unable to parse forwarded header: parse error in forwarded-pair"
        );

        let err = Forwarded::parse("by;for;host;proto").unwrap_err();
        assert_eq!(
            err.to_string(),
            "unable to parse forwarded header: parse error in forwarded-pair"
        );

        let err = Forwarded::parse("for=for, key=value").unwrap_err();
        assert_eq!(
            err.to_string(),
            "unable to parse forwarded header: http list must start with for="
        );

        let err = Forwarded::parse(r#"for="unterminated string"#).unwrap_err();
        assert_eq!(
            err.to_string(),
            "unable to parse forwarded header: for= without valid value"
        );

        let err = Forwarded::parse(r#"for=, for=;"#).unwrap_err();
        assert_eq!(
            err.to_string(),
            "unable to parse forwarded header: for= without valid value"
        );
    }

    #[test]
    fn bad_parse_from_headers() -> Result {
        let mut headers = Headers::new();
        headers.append("forwarded", "uh oh");
        assert_eq!(
            Forwarded::from_headers(&headers).unwrap_err().to_string(),
            "unable to parse forwarded header: parse error in forwarded-pair"
        );

        let headers = Headers::new();
        assert!(Forwarded::from_headers(&headers)?.is_none());
        Ok(())
    }

    #[test]
    fn from_x_headers() -> Result {
        let mut headers = Headers::new();
        headers.append(XforwardedFor, "192.0.2.43, 2001:db8:cafe::17");
        headers.append(XforwardedProto, "gopher");
        headers.append(XforwardedHost, "example.com");
        let forwarded = Forwarded::from_headers(&headers)?.unwrap();
        assert_eq!(
            forwarded.to_string(),
            r#"for=192.0.2.43, for="[2001:db8:cafe::17]";host=example.com;proto=gopher"#
        );
        Ok(())
    }

    #[test]
    fn from_x_headers_with_ssl_on() -> Result {
        let mut headers = Headers::new();
        headers.append(XforwardedFor, "192.0.2.43, 2001:db8:cafe::17");
        headers.append(XforwardedHost, "example.com");
        headers.append(XforwardedSsl, "on");
        let forwarded = Forwarded::from_headers(&headers)?.unwrap();
        assert_eq!(
            forwarded.to_string(),
            r#"for=192.0.2.43, for="[2001:db8:cafe::17]";host=example.com;proto=https"#
        );
        Ok(())
    }

    #[test]
    fn x_forwarded_proto_wins_over_ssl() -> Result {
        let mut headers = Headers::new();
        headers.append(XforwardedProto, "http");
        headers.append(XforwardedSsl, "on");
        let forwarded = Forwarded::from_x_headers(&headers)?.unwrap();
        assert_eq!(forwarded.proto(), Some("http"));
        Ok(())
    }

    #[test]
    fn forwarded_header_takes_precedence_over_x_headers() -> Result {
        let mut headers = Headers::new();
        headers.append("Forwarded", "for=a");
        headers.append(XforwardedFor, "b");
        let forwarded = Forwarded::from_headers(&headers)?.unwrap();
        assert_eq!(forwarded.forwarded_for(), vec!["a"]);
        Ok(())
    }

    #[test]
    fn formatting_edge_cases() {
        let mut forwarded = Forwarded::new();
        forwarded.add_for(r#"quote: " backslash: \"#);
        forwarded.add_for(";proto=https");
        assert_eq!(
            forwarded.to_string(),
            r#"for="quote: \" backslash: \\", for=";proto=https""#
        );

        let mut forwarded = Forwarded::new();
        forwarded.set_host("localhost:8080");
        forwarded.set_proto("not:normal");
        forwarded.set_by("localhost:8081");
        assert_eq!(
            forwarded.to_string(),
            r#"by="localhost:8081";host="localhost:8080";proto="not:normal""#
        );
    }

    #[test]
    fn parse_edge_cases() -> Result {
        let forwarded =
            Forwarded::parse(r#"for=";", for=",", for="\"", for=unquoted;by=";proto=https""#)?;
        assert_eq!(forwarded.forwarded_for(), vec![";", ",", "\"", "unquoted"]);
        assert_eq!(forwarded.by(), Some(";proto=https"));
        assert!(forwarded.proto().is_none());

        let forwarded = Forwarded::parse("proto=https")?;
        assert_eq!(forwarded.proto(), Some("https"));
        Ok(())
    }

    #[test]
    fn empty_input() -> Result {
        assert_eq!(Forwarded::parse("")?, Forwarded::new());
        assert_eq!(Forwarded::new().to_string(), "");
        Ok(())
    }

    #[test]
    fn owned_parse() -> Result {
        let forwarded =
            Forwarded::parse("for=client;by=proxy.com;host=example.com;proto=https")?.into_owned();

        assert_eq!(forwarded.by(), Some("proxy.com"));
        assert_eq!(forwarded.forwarded_for(), vec!["client"]);
        assert_eq!(forwarded.host(), Some("example.com"));
        assert_eq!(forwarded.proto(), Some("https"));

        // `into_owned` must not leave any borrowed data behind.
        assert!(matches!(forwarded.by, Some(Cow::Owned(_))));
        assert!(matches!(forwarded.host, Some(Cow::Owned(_))));
        assert!(matches!(forwarded.proto, Some(Cow::Owned(_))));
        assert!(forwarded
            .forwarded_for
            .iter()
            .all(|ff| matches!(ff, Cow::Owned(_))));
        Ok(())
    }

    #[test]
    fn from_headers() -> Result {
        let mut headers = Headers::new();
        headers.append("Forwarded", "for=for");

        let forwarded = Forwarded::from_headers(&headers)?.unwrap();
        assert_eq!(forwarded.forwarded_for(), vec!["for"]);

        Ok(())
    }

    #[test]
    fn owned_can_outlive_headers() -> Result {
        let forwarded = {
            let mut headers = Headers::new();
            headers.append("Forwarded", "for=for;by=by;host=host;proto=proto");
            Forwarded::from_headers(&headers)?.unwrap().into_owned()
        };
        assert_eq!(forwarded.by(), Some("by"));
        Ok(())
    }

    #[test]
    fn round_trip() -> Result {
        let inputs = [
            "for=client,for=b,for=c;by=proxy.com;host=example.com;proto=https",
            "by=proxy.com;proto=https;host=example.com;for=a,for=b",
            "by=proxy.com",
            "proto=https",
            "host=example.com",
            "for=a,for=b",
            r#"by="localhost:8081";host="localhost:8080";proto="not:normal""#,
        ];
        for input in inputs {
            let forwarded = Forwarded::parse(input)?;
            let header = forwarded.to_string();
            let parsed = Forwarded::parse(header.as_str())?;
            assert_eq!(forwarded, parsed);
        }
        Ok(())
    }
}