diff --git a/src/bin/ptth_file_server.rs b/src/bin/ptth_file_server.rs index 5636d06..aec5855 100644 --- a/src/bin/ptth_file_server.rs +++ b/src/bin/ptth_file_server.rs @@ -1,11 +1,13 @@ use std::{ convert::Infallible, error::Error, + net::SocketAddr, path::PathBuf, sync::Arc, - net::SocketAddr, + time::Duration, }; +use futures::FutureExt; use hyper::{ Body, Request, @@ -18,8 +20,15 @@ use hyper::{ StatusCode, }; use serde::Deserialize; +use tokio::{ + sync::{ + oneshot, + watch, + }, + time::delay_for, +}; use tracing::{ - debug, info, trace, warn, + debug, error, info, trace, warn, }; use ptth::{ @@ -36,6 +45,8 @@ pub struct Config { struct ServerState <'a> { config: Config, handlebars: handlebars::Handlebars <'a>, + + shutdown_watch_rx: watch::Receiver , } fn status_reply > (status: StatusCode, b: B) @@ -45,7 +56,7 @@ fn status_reply > (status: StatusCode, b: B) } async fn handle_all (req: Request , state: Arc >) --> Result , Infallible> +-> Result , String> { let path = req.uri ().path (); //println! ("{}", path); @@ -65,12 +76,19 @@ async fn handle_all (req: Request , state: Arc >) .as_ref () .unwrap_or (&default_root); + let mut shutdown_watch_rx = state.shutdown_watch_rx.clone (); + if shutdown_watch_rx.recv ().await != Some (false) { + error! ("Can't serve, I'm shutting down"); + panic! ("Can't serve, I'm shutting down"); + } + let ptth_resp = file_server::serve_all ( &state.handlebars, file_server_root, ptth_req.method, &ptth_req.uri, - &ptth_req.headers + &ptth_req.headers, + Some (shutdown_watch_rx) ).await; let mut resp = Response::builder () @@ -111,18 +129,22 @@ async fn main () -> Result <(), Box > { let handlebars = file_server::load_templates ()?; + let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false); + let state = Arc::new (ServerState { - handlebars, config: Config { file_server_root: config_file.file_server_root, }, + handlebars, + + shutdown_watch_rx, }); let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { - Ok::<_, Infallible> (service_fn (move |req| { + Ok::<_, String> (service_fn (move |req| { let state = state.clone (); handle_all (req, state) @@ -130,9 +152,30 @@ async fn main () -> Result <(), Box > { } }); - let server = Server::bind (&addr).serve (make_svc); + let shutdown_oneshot = ptth::graceful_shutdown::init (); + let (force_shutdown_tx, force_shutdown_rx) = oneshot::channel (); - server.await?; + let server = Server::bind (&addr) + .serve (make_svc) + .with_graceful_shutdown (async move { + shutdown_oneshot.await.ok (); + info! ("Received graceful shutdown"); + + shutdown_watch_tx.broadcast (true).unwrap (); + force_shutdown_tx.send (()).unwrap (); + }); + + let force_shutdown_fut = async move { + force_shutdown_rx.await.unwrap (); + delay_for (Duration::from_secs (5)).await; + + error! ("Forcing shutdown"); + }; + + futures::select! { + x = server.fuse () => x?, + _ = force_shutdown_fut.fuse () => (), + }; Ok (()) } diff --git a/src/server/file_server.rs b/src/server/file_server.rs index 11e4cdd..4155ebe 100644 --- a/src/server/file_server.rs +++ b/src/server/file_server.rs @@ -9,6 +9,7 @@ use std::{ path::{Path, PathBuf}, }; +use futures::FutureExt; use handlebars::Handlebars; use tokio::{ fs::{ @@ -20,6 +21,7 @@ use tokio::{ sync::mpsc::{ channel, }, + sync::watch, }; use tracing::{ debug, error, info, trace, warn, @@ -139,14 +141,15 @@ async fn serve_dir ( resp } -#[instrument (level = "debug", skip (f))] +#[instrument (level = "debug", skip (f, cancel_rx))] async fn serve_file ( mut f: File, should_send_body: bool, range_start: Option , - range_end: Option + range_end: Option , + mut cancel_rx: Option > ) -> http_serde::Response { - let (tx, rx) = channel (2); + let (tx, rx) = channel (1); let body = if should_send_body { Some (rx) } @@ -169,6 +172,7 @@ async fn serve_file ( if should_send_body { tokio::spawn (async move { + { //println! ("Opening file {:?}", path); let mut tx = tx; @@ -187,7 +191,20 @@ async fn serve_file ( break; } - if tx.send (Ok::<_, Infallible> (buffer)).await.is_err () { + let send_fut = tx.send (Ok::<_, Infallible> (buffer)); + + let send_result = match &mut cancel_rx { + Some (cancel_rx) => futures::select! { + x = send_fut.fuse () => x, + _ = cancel_rx.recv ().fuse () => { + error! ("Cancelled"); + break; + }, + }, + None => send_fut.await, + }; + + if send_result.is_err () { warn! ("Cancelling file stream (Sent {} out of {} bytes)", bytes_sent, end - start); break; } @@ -203,6 +220,8 @@ async fn serve_file ( //delay_for (Duration::from_millis (50)).await; } + } + debug! ("Exited stream scope"); }); } @@ -249,6 +268,7 @@ pub async fn serve_all ( method: http_serde::Method, uri: &str, headers: &HashMap >, + cancel_rx: Option > ) -> http_serde::Response { @@ -297,7 +317,8 @@ pub async fn serve_all ( file, should_send_body, range_start, - range_end + range_end, + cancel_rx ).await } else { diff --git a/src/server/mod.rs b/src/server/mod.rs index 4c74381..86a1bc3 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -76,7 +76,8 @@ async fn handle_req_resp <'a> ( file_server_root, parts.method, uri, - &parts.headers + &parts.headers, + None ).await } else {