use std::{ convert::Infallible, error::Error, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration, }; use futures::FutureExt; use hyper::{ Body, Request, Response, Server, service::{ make_service_fn, service_fn, }, StatusCode, }; use serde::Deserialize; use tokio::{ sync::{ oneshot, watch, }, time::delay_for, }; use tracing::{ debug, error, info, trace, warn, }; use ptth::{ http_serde::RequestParts, prefix_match, server::file_server, }; #[derive (Default)] pub struct Config { pub file_server_root: Option , } struct ServerState <'a> { config: Config, handlebars: handlebars::Handlebars <'a>, shutdown_watch_rx: watch::Receiver , } fn status_reply > (status: StatusCode, b: B) -> Response { Response::builder ().status (status).body (b.into ()).unwrap () } async fn handle_all (req: Request , state: Arc >) -> Result , String> { let path = req.uri ().path (); //println! ("{}", path); if let Some (path) = prefix_match (path, "/files") { let path = path.into (); let (parts, _) = req.into_parts (); let ptth_req = match RequestParts::from_hyper (parts.method, path, parts.headers) { Ok (x) => x, _ => return Ok (status_reply (StatusCode::BAD_REQUEST, "Bad request")), }; let default_root = PathBuf::from ("./"); let file_server_root: &std::path::Path = state.config.file_server_root .as_ref () .unwrap_or (&default_root); let mut shutdown_watch_rx = state.shutdown_watch_rx.clone (); if shutdown_watch_rx.recv ().await != Some (false) { error! ("Can't serve, I'm shutting down"); panic! ("Can't serve, I'm shutting down"); } let ptth_resp = file_server::serve_all ( &state.handlebars, file_server_root, ptth_req.method, &ptth_req.uri, &ptth_req.headers, Some (shutdown_watch_rx) ).await; let mut resp = Response::builder () .status (StatusCode::from (ptth_resp.parts.status_code)); use std::str::FromStr; for (k, v) in ptth_resp.parts.headers.into_iter () { resp = resp.header (hyper::header::HeaderName::from_str (&k).unwrap (), v); } let body = ptth_resp.body .map (Body::wrap_stream) .unwrap_or_else (Body::empty) ; let resp = resp.body (body).unwrap (); Ok (resp) } else { Ok (status_reply (StatusCode::NOT_FOUND, "404 Not Found\n")) } } #[derive (Deserialize)] pub struct ConfigFile { pub file_server_root: Option , } #[tokio::main] async fn main () -> Result <(), Box > { tracing_subscriber::fmt::init (); let config_file: ConfigFile = ptth::load_toml::load ("config/ptth_server.toml"); info! ("file_server_root: {:?}", config_file.file_server_root); let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); let handlebars = file_server::load_templates ()?; let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false); let state = Arc::new (ServerState { config: Config { file_server_root: config_file.file_server_root, }, handlebars, shutdown_watch_rx, }); let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { Ok::<_, String> (service_fn (move |req| { let state = state.clone (); handle_all (req, state) })) } }); let shutdown_oneshot = ptth::graceful_shutdown::init (); let (force_shutdown_tx, force_shutdown_rx) = oneshot::channel (); let server = Server::bind (&addr) .serve (make_svc) .with_graceful_shutdown (async move { shutdown_oneshot.await.ok (); info! ("Received graceful shutdown"); shutdown_watch_tx.broadcast (true).unwrap (); force_shutdown_tx.send (()).unwrap (); }); let force_shutdown_fut = async move { force_shutdown_rx.await.unwrap (); delay_for (Duration::from_secs (5)).await; error! ("Forcing shutdown"); }; futures::select! { x = server.fuse () => x?, _ = force_shutdown_fut.fuse () => (), }; Ok (()) }