Skip to main content

trillium_proxy/
forward_proxy_connect.rs

1use crate::bytes;
2use full_duplex_async_copy::full_duplex_copy;
3use std::fmt::Debug;
4use trillium::{Conn, Handler, Method, Status, Transport, Upgrade};
5use trillium_client::{ArcedConnector, Connector};
6use url::Url;
7
8#[derive(Debug)]
9/// trillium handler to implement Connect proxying
10pub struct ForwardProxyConnect(ArcedConnector);
11
12struct ForwardUpgrade(Box<dyn Transport>);
13
14impl Debug for ForwardUpgrade {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        f.debug_tuple("ForwardUpgrade").finish_non_exhaustive()
17    }
18}
19
20impl ForwardProxyConnect {
21    /// construct a new ForwardProxyConnect
22    pub fn new(connector: impl Connector) -> Self {
23        Self(ArcedConnector::new(connector))
24    }
25}
26
27impl Handler for ForwardProxyConnect {
28    async fn run(&self, conn: Conn) -> Conn {
29        if conn.method() == Method::Connect {
30            let authority = {
31                let conn: &trillium_http::Conn<Box<dyn Transport>> = conn.as_ref();
32                conn.authority()
33            };
34
35            let Some(authority) = authority else {
36                return conn.with_status(Status::BadRequest).halt();
37            };
38
39            let Ok(url) = Url::parse(&format!("http://{}", authority)) else {
40                return conn.with_status(Status::BadRequest).halt();
41            };
42
43            if url.cannot_be_a_base() {
44                return conn.with_status(Status::BadRequest).halt();
45            }
46
47            let Ok(tcp) = Connector::connect(&self.0, &url).await else {
48                return conn.with_status(Status::BadGateway).halt();
49            };
50
51            conn.with_status(Status::Ok)
52                .with_state(ForwardUpgrade(tcp))
53                .halt()
54        } else {
55            conn
56        }
57    }
58
59    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
60        upgrade.state().contains::<ForwardUpgrade>()
61    }
62
63    async fn upgrade(&self, mut upgrade: Upgrade) {
64        let Some(ForwardUpgrade(upstream)) = upgrade.state_mut().take() else {
65            return;
66        };
67
68        let downstream = upgrade;
69        match full_duplex_copy(upstream, downstream).await {
70            Err(e) => log::error!("upgrade stream error: {:?}", e),
71            Ok((up, down)) => {
72                log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
73            }
74        }
75    }
76}