1use crate::{
2 Acceptor, ArcHandler, QuicConfig, RuntimeTrait, Server, ServerHandle,
3 running_config::RunningConfig,
4};
5use async_cell::sync::AsyncCell;
6use futures_lite::StreamExt;
7use std::{cell::OnceCell, net::SocketAddr, pin::pin, sync::Arc};
8use trillium::{Handler, Headers, HttpConfig, Info, KnownHeaderName, Swansong, TypeSet};
9use trillium_http::HttpContext;
10use url::Url;
11
12#[derive(Debug)]
54pub struct Config<ServerType: Server, AcceptorType, QuicType: QuicConfig<ServerType> = ()> {
55 pub(crate) acceptor: AcceptorType,
56 pub(crate) quic: QuicType,
57 pub(crate) binding: Option<ServerType>,
58 pub(crate) host: Option<String>,
59 pub(crate) context_cell: Arc<AsyncCell<Arc<HttpContext>>>,
60 pub(crate) max_connections: Option<usize>,
61 pub(crate) nodelay: bool,
62 pub(crate) port: Option<u16>,
63 pub(crate) register_signals: bool,
64 pub(crate) runtime: ServerType::Runtime,
65 pub(crate) context: HttpContext,
66}
67
68impl<ServerType, AcceptorType, QuicType> Config<ServerType, AcceptorType, QuicType>
69where
70 ServerType: Server,
71 AcceptorType: Acceptor<ServerType::Transport>,
72 QuicType: QuicConfig<ServerType>,
73{
74 pub fn run(self, handler: impl Handler) {
81 self.runtime.clone().block_on(self.run_async(handler));
82 }
83
84 pub async fn run_async(self, mut handler: impl Handler) {
90 let Self {
91 runtime,
92 acceptor,
93 quic,
94 mut max_connections,
95 nodelay,
96 binding,
97 host,
98 port,
99 register_signals,
100 context,
101 context_cell,
102 } = self;
103
104 #[cfg(unix)]
105 if max_connections.is_none() {
106 max_connections = rlimit::getrlimit(rlimit::Resource::NOFILE)
107 .ok()
108 .and_then(|(soft, _hard)| soft.try_into().ok())
109 .map(|limit: usize| ((limit as f32) * 0.75) as usize);
110 };
111
112 log::debug!("using max connections of {:?}", max_connections);
113
114 let host = host
115 .or_else(|| std::env::var("HOST").ok())
116 .unwrap_or_else(|| "localhost".into());
117 let port = port
118 .or_else(|| {
119 std::env::var("PORT")
120 .ok()
121 .map(|x| x.parse().expect("PORT must be an unsigned integer"))
122 })
123 .unwrap_or(8080);
124
125 let listener = binding
126 .inspect(|_| log::debug!("taking prebound listener"))
127 .unwrap_or_else(|| ServerType::from_host_and_port(&host, port));
128
129 let swansong = context.swansong().clone();
130
131 let mut info = Info::from(context)
132 .with_shared_state(runtime.clone().into())
133 .with_shared_state(runtime.clone());
134
135 info.shared_state_entry::<Headers>()
136 .or_default()
137 .try_insert(KnownHeaderName::Server, trillium::headers::server_header());
138
139 listener.init(&mut info);
140
141 let quic_binding = if let Some(socket_addr) = info.tcp_socket_addr().copied() {
142 let quic_binding = quic
143 .bind(socket_addr, runtime.clone(), &mut info)
144 .map(|r| r.expect("failed to bind QUIC endpoint"));
145
146 if quic_binding.is_some() {
147 info.shared_state_entry::<Headers>()
148 .or_default()
149 .try_insert_with(KnownHeaderName::AltSvc, || -> &'static str {
150 format!("h3=\":{}\"", socket_addr.port()).leak()
151 });
152 }
153
154 quic_binding
155 } else {
156 None
157 };
158
159 insert_url(info.as_mut(), acceptor.is_secure());
160
161 handler.init(&mut info).await;
162
163 let context = Arc::new(HttpContext::from(info));
164
165 context_cell.set(context.clone());
166
167 if register_signals {
168 let runtime = runtime.clone();
169 runtime.clone().spawn(async move {
170 let mut signals = pin!(runtime.hook_signals([2, 3, 15]));
171 while signals.next().await.is_some() {
172 let guard_count = swansong.guard_count();
173 if swansong.state().is_shutting_down() {
174 eprintln!(
175 "\nSecond interrupt, shutting down harshly (dropping {guard_count} \
176 guards)"
177 );
178 std::process::exit(1);
179 } else {
180 println!(
181 "\nShutting down gracefully. Waiting for {guard_count} shutdown \
182 guards to drop.\nControl-c again to force."
183 );
184 swansong.shut_down();
185 }
186 }
187 });
188 }
189
190 let handler = ArcHandler::new(handler);
191
192 if let Some(quic_binding) = quic_binding {
193 let context = context.clone();
194 let handler = handler.clone();
195 runtime.clone().spawn(crate::h3::run_h3(
196 quic_binding,
197 context,
198 handler,
199 runtime.clone(),
200 ));
201 }
202
203 let running_config = Arc::new(RunningConfig {
204 acceptor,
205 max_connections,
206 context,
207 runtime,
208 nodelay,
209 });
210
211 running_config.run_async(listener, handler).await;
212 }
213
214 pub fn spawn(self, handler: impl Handler) -> ServerHandle {
221 let server_handle = self.handle();
222 self.runtime.clone().spawn(self.run_async(handler));
223 server_handle
224 }
225
226 pub fn handle(&self) -> ServerHandle {
229 ServerHandle {
230 swansong: self.context.swansong().clone(),
231 context: self.context_cell.clone(),
232 received_context: OnceCell::new(),
233 runtime: self.runtime().into(),
234 }
235 }
236
237 pub fn with_port(mut self, port: u16) -> Self {
240 if self.has_binding() {
241 log::warn!(
242 "constructing a config with both a port and a pre-bound listener will ignore the \
243 port"
244 );
245 }
246 self.port = Some(port);
247 self
248 }
249
250 pub fn with_host(mut self, host: &str) -> Self {
254 if self.has_binding() {
255 log::warn!(
256 "constructing a config with both a host and a pre-bound listener will ignore the \
257 host"
258 );
259 }
260 self.host = Some(host.into());
261 self
262 }
263
264 pub fn without_signals(mut self) -> Self {
269 self.register_signals = false;
270 self
271 }
272
273 pub fn with_nodelay(mut self) -> Self {
277 self.nodelay = true;
278 self
279 }
280
281 pub fn with_socketaddr(self, socketaddr: SocketAddr) -> Self {
285 self.with_host(&socketaddr.ip().to_string())
286 .with_port(socketaddr.port())
287 }
288
289 pub fn with_acceptor<A: Acceptor<ServerType::Transport>>(
291 self,
292 acceptor: A,
293 ) -> Config<ServerType, A, QuicType> {
294 Config {
295 acceptor,
296 quic: self.quic,
297 host: self.host,
298 port: self.port,
299 nodelay: self.nodelay,
300 register_signals: self.register_signals,
301 max_connections: self.max_connections,
302 context_cell: self.context_cell,
303 context: self.context,
304 binding: self.binding,
305 runtime: self.runtime,
306 }
307 }
308
309 pub fn with_quic<Q: QuicConfig<ServerType>>(
311 self,
312 quic: Q,
313 ) -> Config<ServerType, AcceptorType, Q> {
314 Config {
315 acceptor: self.acceptor,
316 quic,
317 host: self.host,
318 port: self.port,
319 nodelay: self.nodelay,
320 register_signals: self.register_signals,
321 max_connections: self.max_connections,
322 context_cell: self.context_cell,
323 context: self.context,
324 binding: self.binding,
325 runtime: self.runtime,
326 }
327 }
328
329 pub fn with_swansong(mut self, swansong: Swansong) -> Self {
331 self.context.set_swansong(swansong);
332 self
333 }
334
335 pub fn with_max_connections(mut self, max_connections: Option<usize>) -> Self {
339 self.max_connections = max_connections;
340 self
341 }
342
343 pub fn with_http_config(mut self, config: HttpConfig) -> Self {
347 *self.context.config_mut() = config;
348 self
349 }
350
351 pub fn with_prebound_server(mut self, server: impl Into<ServerType>) -> Self {
363 if self.host.is_some() {
364 log::warn!(
365 "constructing a config with both a host and a pre-bound listener will ignore the \
366 host"
367 );
368 }
369
370 if self.port.is_some() {
371 log::warn!(
372 "constructing a config with both a port and a pre-bound listener will ignore the \
373 port"
374 );
375 }
376
377 self.binding = Some(server.into());
378 self
379 }
380
381 fn has_binding(&self) -> bool {
382 self.binding.is_some()
383 }
384
385 pub fn runtime(&self) -> ServerType::Runtime {
387 self.runtime.clone()
388 }
389
390 pub fn port(&self) -> Option<u16> {
392 self.port
393 }
394
395 pub fn host(&self) -> Option<&str> {
397 self.host.as_deref()
398 }
399
400 pub fn with_shared_state<T: Send + Sync + 'static>(mut self, state: T) -> Self {
412 self.context.shared_state_mut().insert(state);
413 self
414 }
415
416 pub fn set_shared_state<T: Send + Sync + 'static>(&mut self, state: T) -> &mut Self {
428 self.context.shared_state_mut().insert(state);
429 self
430 }
431}
432
433impl<ServerType: Server> Config<ServerType, ()> {
434 pub fn new() -> Self {
436 Self::default()
437 }
438}
439
440impl<ServerType: Server> Default for Config<ServerType, ()> {
441 fn default() -> Self {
442 Self {
443 acceptor: (),
444 quic: (),
445 port: None,
446 host: None,
447 nodelay: false,
448 register_signals: cfg!(unix),
449 max_connections: None,
450 context_cell: AsyncCell::shared(),
451 binding: None,
452 runtime: ServerType::runtime(),
453 context: Default::default(),
454 }
455 }
456}
457
458fn insert_url(state: &mut TypeSet, secure: bool) -> Option<()> {
459 let socket_addr = state.get::<SocketAddr>().copied()?;
460 let vacant_entry = state.entry::<Url>().into_vacant()?;
461
462 let host = if socket_addr.ip().is_loopback() {
463 "localhost".to_string()
464 } else {
465 socket_addr.ip().to_string()
466 };
467
468 let url = match (secure, socket_addr.port()) {
469 (true, 443) => format!("https://{host}"),
470 (false, 80) => format!("http://{host}"),
471 (true, port) => format!("https://{host}:{port}/"),
472 (false, port) => format!("http://{host}:{port}/"),
473 };
474
475 let url = Url::parse(&url).ok()?;
476
477 vacant_entry.insert(url);
478 Some(())
479}