trillium_proxy/
forward_proxy_connect.rs1use crate::bytes;
2use full_duplex_async_copy::full_duplex_copy;
3use std::fmt::Debug;
4use trillium::{Conn, Handler, Method, Status, Transport, Upgrade};
5use trillium_client::{ArcedConnector, Connector};
6use url::Url;
7
8#[derive(Debug)]
9pub struct ForwardProxyConnect(ArcedConnector);
11
12struct ForwardUpgrade(Box<dyn Transport>);
13
14impl Debug for ForwardUpgrade {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 f.debug_tuple("ForwardUpgrade").finish_non_exhaustive()
17 }
18}
19
20impl ForwardProxyConnect {
21 pub fn new(connector: impl Connector) -> Self {
23 Self(ArcedConnector::new(connector))
24 }
25}
26
27impl Handler for ForwardProxyConnect {
28 async fn run(&self, conn: Conn) -> Conn {
29 if conn.method() == Method::Connect {
30 let authority = {
31 let conn: &trillium_http::Conn<Box<dyn Transport>> = conn.as_ref();
32 conn.authority()
33 };
34
35 let Some(authority) = authority else {
36 return conn.with_status(Status::BadRequest).halt();
37 };
38
39 let Ok(url) = Url::parse(&format!("http://{}", authority)) else {
40 return conn.with_status(Status::BadRequest).halt();
41 };
42
43 if url.cannot_be_a_base() {
44 return conn.with_status(Status::BadRequest).halt();
45 }
46
47 let Ok(tcp) = Connector::connect(&self.0, &url).await else {
48 return conn.with_status(Status::BadGateway).halt();
49 };
50
51 conn.with_status(Status::Ok)
52 .with_state(ForwardUpgrade(tcp))
53 .halt()
54 } else {
55 conn
56 }
57 }
58
59 fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
60 upgrade.state().contains::<ForwardUpgrade>()
61 }
62
63 async fn upgrade(&self, mut upgrade: Upgrade) {
64 let Some(ForwardUpgrade(upstream)) = upgrade.state_mut().take() else {
65 return;
66 };
67
68 let downstream = upgrade;
69 match full_duplex_copy(upstream, downstream).await {
70 Err(e) => log::error!("upgrade stream error: {:?}", e),
71 Ok((up, down)) => {
72 log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
73 }
74 }
75 }
76}