1use crate::{
2 Acceptor, ArcHandler, QuicConfig, RuntimeTrait, Server, ServerHandle,
3 running_config::RunningConfig,
4};
5use async_cell::sync::AsyncCell;
6use futures_lite::StreamExt;
7use std::{cell::OnceCell, net::SocketAddr, pin::pin, sync::Arc};
8use trillium::{Handler, Headers, HttpConfig, Info, KnownHeaderName, Swansong, TypeSet};
9use trillium_http::HttpContext;
10use url::Url;
11
12#[derive(Debug)]
54pub struct Config<ServerType: Server, AcceptorType, QuicType: QuicConfig<ServerType> = ()> {
55 pub(crate) acceptor: AcceptorType,
56 pub(crate) quic: QuicType,
57 pub(crate) binding: Option<ServerType>,
58 pub(crate) host: Option<String>,
59 pub(crate) context_cell: Arc<AsyncCell<Arc<HttpContext>>>,
60 pub(crate) max_connections: Option<usize>,
61 pub(crate) nodelay: bool,
62 pub(crate) port: Option<u16>,
63 pub(crate) register_signals: bool,
64 pub(crate) runtime: ServerType::Runtime,
65 pub(crate) context: HttpContext,
66}
67
68impl<ServerType, AcceptorType, QuicType> Config<ServerType, AcceptorType, QuicType>
69where
70 ServerType: Server,
71 AcceptorType: Acceptor<ServerType::Transport>,
72 QuicType: QuicConfig<ServerType>,
73{
74 pub fn run(self, handler: impl Handler) {
81 self.runtime.clone().block_on(self.run_async(handler));
82 }
83
84 pub async fn run_async(self, mut handler: impl Handler) {
90 #[cfg_attr(not(unix), allow(unused_mut))]
91 let Self {
92 runtime,
93 acceptor,
94 quic,
95 mut max_connections,
96 nodelay,
97 binding,
98 host,
99 port,
100 register_signals,
101 context,
102 context_cell,
103 } = self;
104
105 #[cfg(unix)]
106 if max_connections.is_none() {
107 max_connections = rlimit::getrlimit(rlimit::Resource::NOFILE)
108 .ok()
109 .and_then(|(soft, _hard)| soft.try_into().ok())
110 .map(|limit: usize| ((limit as f32) * 0.75) as usize);
111 };
112
113 log::debug!("using max connections of {:?}", max_connections);
114
115 let host = host
116 .or_else(|| std::env::var("HOST").ok())
117 .unwrap_or_else(|| "localhost".into());
118 let port = port
119 .or_else(|| {
120 std::env::var("PORT")
121 .ok()
122 .map(|x| x.parse().expect("PORT must be an unsigned integer"))
123 })
124 .unwrap_or(8080);
125
126 let listener = binding
127 .inspect(|_| log::debug!("taking prebound listener"))
128 .unwrap_or_else(|| ServerType::from_host_and_port(&host, port));
129
130 let swansong = context.swansong().clone();
131
132 let mut info = Info::from(context)
133 .with_shared_state(runtime.clone().into())
134 .with_shared_state(runtime.clone());
135
136 info.shared_state_entry::<Headers>()
137 .or_default()
138 .try_insert(KnownHeaderName::Server, trillium::headers::server_header());
139
140 listener.init(&mut info);
141
142 let quic_binding = if let Some(socket_addr) = info.tcp_socket_addr().copied() {
143 let quic_binding = quic
144 .bind(socket_addr, runtime.clone(), &mut info)
145 .map(|r| r.expect("failed to bind QUIC endpoint"));
146
147 if quic_binding.is_some() {
148 info.shared_state_entry::<Headers>()
149 .or_default()
150 .try_insert_with(KnownHeaderName::AltSvc, || -> &'static str {
151 format!("h3=\":{}\"", socket_addr.port()).leak()
152 });
153 }
154
155 quic_binding
156 } else {
157 None
158 };
159
160 insert_url(info.as_mut(), acceptor.is_secure());
161
162 handler.init(&mut info).await;
163
164 let context = Arc::new(HttpContext::from(info));
165
166 context_cell.set(context.clone());
167
168 if register_signals {
169 let runtime = runtime.clone();
170 runtime.clone().spawn(async move {
171 let mut signals = pin!(runtime.hook_signals([2, 3, 15]));
172 while signals.next().await.is_some() {
173 let guard_count = swansong.guard_count();
174 if swansong.state().is_shutting_down() {
175 eprintln!(
176 "\nSecond interrupt, shutting down harshly (dropping {guard_count} \
177 guards)"
178 );
179 std::process::exit(1);
180 } else {
181 println!(
182 "\nShutting down gracefully. Waiting for {guard_count} shutdown \
183 guards to drop.\nControl-c again to force."
184 );
185 swansong.shut_down();
186 }
187 }
188 });
189 }
190
191 let handler = ArcHandler::new(handler);
192
193 if let Some(quic_binding) = quic_binding {
194 let context = context.clone();
195 let handler = handler.clone();
196 runtime.clone().spawn(crate::h3::run_h3(
197 quic_binding,
198 context,
199 handler,
200 runtime.clone(),
201 ));
202 }
203
204 let running_config = Arc::new(RunningConfig {
205 acceptor,
206 max_connections,
207 context,
208 runtime,
209 nodelay,
210 });
211
212 running_config.run_async(listener, handler).await;
213 }
214
215 pub fn spawn(self, handler: impl Handler) -> ServerHandle {
222 let server_handle = self.handle();
223 self.runtime.clone().spawn(self.run_async(handler));
224 server_handle
225 }
226
227 pub fn handle(&self) -> ServerHandle {
230 ServerHandle {
231 swansong: self.context.swansong().clone(),
232 context: self.context_cell.clone(),
233 received_context: OnceCell::new(),
234 runtime: self.runtime().into(),
235 }
236 }
237
238 pub fn with_port(mut self, port: u16) -> Self {
241 if self.has_binding() {
242 log::warn!(
243 "constructing a config with both a port and a pre-bound listener will ignore the \
244 port"
245 );
246 }
247 self.port = Some(port);
248 self
249 }
250
251 pub fn with_host(mut self, host: &str) -> Self {
255 if self.has_binding() {
256 log::warn!(
257 "constructing a config with both a host and a pre-bound listener will ignore the \
258 host"
259 );
260 }
261 self.host = Some(host.into());
262 self
263 }
264
265 pub fn without_signals(mut self) -> Self {
270 self.register_signals = false;
271 self
272 }
273
274 pub fn with_nodelay(mut self) -> Self {
278 self.nodelay = true;
279 self
280 }
281
282 pub fn with_socketaddr(self, socketaddr: SocketAddr) -> Self {
286 self.with_host(&socketaddr.ip().to_string())
287 .with_port(socketaddr.port())
288 }
289
290 pub fn with_acceptor<A: Acceptor<ServerType::Transport>>(
292 self,
293 acceptor: A,
294 ) -> Config<ServerType, A, QuicType> {
295 Config {
296 acceptor,
297 quic: self.quic,
298 host: self.host,
299 port: self.port,
300 nodelay: self.nodelay,
301 register_signals: self.register_signals,
302 max_connections: self.max_connections,
303 context_cell: self.context_cell,
304 context: self.context,
305 binding: self.binding,
306 runtime: self.runtime,
307 }
308 }
309
310 pub fn with_quic<Q: QuicConfig<ServerType>>(
312 self,
313 quic: Q,
314 ) -> Config<ServerType, AcceptorType, Q> {
315 Config {
316 acceptor: self.acceptor,
317 quic,
318 host: self.host,
319 port: self.port,
320 nodelay: self.nodelay,
321 register_signals: self.register_signals,
322 max_connections: self.max_connections,
323 context_cell: self.context_cell,
324 context: self.context,
325 binding: self.binding,
326 runtime: self.runtime,
327 }
328 }
329
330 pub fn with_swansong(mut self, swansong: Swansong) -> Self {
332 self.context.set_swansong(swansong);
333 self
334 }
335
336 pub fn with_max_connections(mut self, max_connections: Option<usize>) -> Self {
340 self.max_connections = max_connections;
341 self
342 }
343
344 pub fn with_http_config(mut self, config: HttpConfig) -> Self {
348 *self.context.config_mut() = config;
349 self
350 }
351
352 pub fn with_prebound_server(mut self, server: impl Into<ServerType>) -> Self {
364 if self.host.is_some() {
365 log::warn!(
366 "constructing a config with both a host and a pre-bound listener will ignore the \
367 host"
368 );
369 }
370
371 if self.port.is_some() {
372 log::warn!(
373 "constructing a config with both a port and a pre-bound listener will ignore the \
374 port"
375 );
376 }
377
378 self.binding = Some(server.into());
379 self
380 }
381
382 fn has_binding(&self) -> bool {
383 self.binding.is_some()
384 }
385
386 pub fn runtime(&self) -> ServerType::Runtime {
388 self.runtime.clone()
389 }
390
391 pub fn port(&self) -> Option<u16> {
393 self.port
394 }
395
396 pub fn host(&self) -> Option<&str> {
398 self.host.as_deref()
399 }
400
401 pub fn with_shared_state<T: Send + Sync + 'static>(mut self, state: T) -> Self {
413 self.context.shared_state_mut().insert(state);
414 self
415 }
416
417 pub fn set_shared_state<T: Send + Sync + 'static>(&mut self, state: T) -> &mut Self {
429 self.context.shared_state_mut().insert(state);
430 self
431 }
432}
433
434impl<ServerType: Server> Config<ServerType, ()> {
435 pub fn new() -> Self {
437 Self::default()
438 }
439}
440
441impl<ServerType: Server> Default for Config<ServerType, ()> {
442 fn default() -> Self {
443 Self {
444 acceptor: (),
445 quic: (),
446 port: None,
447 host: None,
448 nodelay: false,
449 register_signals: cfg!(unix),
450 max_connections: None,
451 context_cell: AsyncCell::shared(),
452 binding: None,
453 runtime: ServerType::runtime(),
454 context: Default::default(),
455 }
456 }
457}
458
459fn insert_url(state: &mut TypeSet, secure: bool) -> Option<()> {
460 let socket_addr = state.get::<SocketAddr>().copied()?;
461 let vacant_entry = state.entry::<Url>().into_vacant()?;
462
463 let host = if socket_addr.ip().is_loopback() {
464 "localhost".to_string()
465 } else {
466 socket_addr.ip().to_string()
467 };
468
469 let url = match (secure, socket_addr.port()) {
470 (true, 443) => format!("https://{host}"),
471 (false, 80) => format!("http://{host}"),
472 (true, port) => format!("https://{host}:{port}/"),
473 (false, port) => format!("http://{host}:{port}/"),
474 };
475
476 let url = Url::parse(&url).ok()?;
477
478 vacant_entry.insert(url);
479 Some(())
480}