Skip to main content

trillium_tokio/
client.rs

1use crate::{TokioRuntime, TokioTransport};
2use async_compat::Compat;
3use std::{
4    io::{Error, ErrorKind, Result},
5    net::SocketAddr,
6    time::Duration,
7};
8use tokio::net::TcpStream;
9use trillium_server_common::{
10    Connector, Transport,
11    url::{Host, Url},
12};
13
/// Socket configuration for the tokio tcp [`Connector`].
///
/// Every field is optional: a `None` value means the corresponding socket option is not
/// touched when a connection is opened, so the operating system default applies.
#[derive(Default, Debug, Clone, Copy)]
pub struct ClientConfig {
    /// `TCP_NODELAY`: `Some(true)` disables
    /// [nagle's algorithm](https://en.wikipedia.org/wiki/Nagle%27s_algorithm), `Some(false)`
    /// leaves it enabled. see [`TcpStream::set_nodelay`] for more info
    pub nodelay: Option<bool>,

    /// IP time-to-live (`IP_TTL`) for packets sent on the connection.
    /// see [`TcpStream::set_ttl`] for more info
    pub ttl: Option<u32>,

    /// `SO_LINGER`: controls what happens on close when the stream still has unsent data.
    ///
    /// - `Some(Some(duration))` enables linger with the given timeout.
    /// - `Some(None)` disables linger.
    /// - `None` leaves the option unset.
    ///
    /// see [`TcpStream::set_linger`] for more info
    pub linger: Option<Option<Duration>>,
}
28
29impl ClientConfig {
30    /// constructs a default ClientConfig
31    pub const fn new() -> Self {
32        Self {
33            nodelay: None,
34            ttl: None,
35            linger: None,
36        }
37    }
38
39    /// chainable setter to set default nodelay
40    pub const fn with_nodelay(mut self, nodelay: bool) -> Self {
41        self.nodelay = Some(nodelay);
42        self
43    }
44
45    /// chainable setter for ip ttl
46    pub const fn with_ttl(mut self, ttl: u32) -> Self {
47        self.ttl = Some(ttl);
48        self
49    }
50
51    /// chainable setter for linger
52    pub const fn with_linger(mut self, linger: Option<Duration>) -> Self {
53        self.linger = Some(linger);
54        self
55    }
56}
57
58impl Connector for ClientConfig {
59    type Runtime = TokioRuntime;
60    type Transport = TokioTransport<Compat<TcpStream>>;
61    type Udp = crate::TokioUdpSocket;
62
63    async fn connect(&self, url: &Url) -> Result<Self::Transport> {
64        if url.scheme() != "http" {
65            return Err(Error::new(
66                ErrorKind::InvalidInput,
67                format!("unknown scheme {}", url.scheme()),
68            ));
69        }
70
71        let host = url
72            .host()
73            .ok_or_else(|| Error::new(ErrorKind::InvalidInput, format!("{url} missing host")))?;
74
75        let port = url
76            .port_or_known_default()
77            // this should be ok because we already checked that the scheme is http, which has a
78            // default port
79            .ok_or_else(|| Error::new(ErrorKind::InvalidInput, format!("{url} missing port")))?;
80
81        let mut tcp = match host {
82            Host::Domain(domain) => Self::Transport::connect((domain, port)).await?,
83            Host::Ipv4(ip) => Self::Transport::connect((ip, port)).await?,
84            Host::Ipv6(ip) => Self::Transport::connect((ip, port)).await?,
85        };
86
87        if let Some(nodelay) = self.nodelay {
88            tcp.set_nodelay(nodelay)?;
89        }
90
91        if let Some(ttl) = self.ttl {
92            tcp.set_ip_ttl(ttl)?;
93        }
94
95        if let Some(dur) = self.linger {
96            tcp.set_linger(dur)?;
97        }
98
99        Ok(tcp)
100    }
101
102    fn runtime(&self) -> Self::Runtime {
103        TokioRuntime::default()
104    }
105
106    async fn resolve(&self, host: &str, port: u16) -> Result<Vec<SocketAddr>> {
107        tokio::net::lookup_host((host, port))
108            .await
109            .map(Iterator::collect)
110    }
111}