diff --git a/README.md b/README.md index cdf44b4..f08902f 100644 --- a/README.md +++ b/README.md @@ -159,7 +159,6 @@ For now, either email me (if you know me personally) or make a pull request to a ## License PTTH is licensed under the -[GNU AGPLv3](https://www.gnu.org/licenses/agpl-3.0.html), -with an exception for my current employer. +[GNU AGPLv3](https://www.gnu.org/licenses/agpl-3.0.html) Copyright 2020 "Trish" diff --git a/src/bin/ptth_file_server.rs b/src/bin/ptth_file_server.rs index fb3fabf..5636d06 100644 --- a/src/bin/ptth_file_server.rs +++ b/src/bin/ptth_file_server.rs @@ -17,6 +17,10 @@ use hyper::{ }, StatusCode, }; +use serde::Deserialize; +use tracing::{ + debug, info, trace, warn, +}; use ptth::{ http_serde::RequestParts, @@ -24,8 +28,14 @@ use ptth::{ server::file_server, }; +#[derive (Default)] +pub struct Config { + pub file_server_root: Option , +} + struct ServerState <'a> { - handlebars: Arc >, + config: Config, + handlebars: handlebars::Handlebars <'a>, } fn status_reply > (status: StatusCode, b: B) @@ -41,8 +51,6 @@ async fn handle_all (req: Request , state: Arc >) //println! ("{}", path); if let Some (path) = prefix_match (path, "/files") { - let root = PathBuf::from ("./"); - let path = path.into (); let (parts, _) = req.into_parts (); @@ -52,7 +60,18 @@ async fn handle_all (req: Request , state: Arc >) _ => return Ok (status_reply (StatusCode::BAD_REQUEST, "Bad request")), }; - let ptth_resp = file_server::serve_all (&state.handlebars, &root, ptth_req.method, &ptth_req.uri, &ptth_req.headers).await; + let default_root = PathBuf::from ("./"); + let file_server_root: &std::path::Path = state.config.file_server_root + .as_ref () + .unwrap_or (&default_root); + + let ptth_resp = file_server::serve_all ( + &state.handlebars, + file_server_root, + ptth_req.method, + &ptth_req.uri, + &ptth_req.headers + ).await; let mut resp = Response::builder () .status (StatusCode::from (ptth_resp.parts.status_code)); @@ -77,14 +96,26 @@ async fn handle_all (req: Request , state: Arc >) } } +#[derive (Deserialize)] +pub struct ConfigFile { + pub file_server_root: Option , +} + #[tokio::main] async fn main () -> Result <(), Box > { + tracing_subscriber::fmt::init (); + let config_file: ConfigFile = ptth::load_toml::load ("config/ptth_server.toml"); + info! ("file_server_root: {:?}", config_file.file_server_root); + let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); - let handlebars = Arc::new (file_server::load_templates ()?); + let handlebars = file_server::load_templates ()?; let state = Arc::new (ServerState { handlebars, + config: Config { + file_server_root: config_file.file_server_root, + }, }); let make_svc = make_service_fn (|_conn| { diff --git a/src/bin/ptth_server.rs b/src/bin/ptth_server.rs index 790dc97..aad61c7 100644 --- a/src/bin/ptth_server.rs +++ b/src/bin/ptth_server.rs @@ -14,7 +14,6 @@ struct Opt { #[tokio::main] async fn main () -> Result <(), Box > { tracing_subscriber::fmt::init (); - let config_file = ptth::load_toml::load ("config/ptth_server.toml"); ptth::server::run_server ( diff --git a/src/relay/mod.rs b/src/relay/mod.rs index 6b2c1cd..add443b 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -36,6 +36,7 @@ use tokio::{ mpsc, oneshot, RwLock, + watch, }, time::delay_for, }; @@ -133,26 +134,22 @@ pub struct RelayState { // Key: Request ID response_rendezvous: RwLock >, -} - -impl Default for RelayState { - fn default () -> Self { - Self { - config: Config::from (&ConfigFile::default ()), - handlebars: Arc::new (load_templates ().unwrap ()), - request_rendezvous: Default::default (), - response_rendezvous: Default::default (), - } - } + + shutdown_watch_tx: watch::Sender , + shutdown_watch_rx: watch::Receiver , } impl From <&ConfigFile> for RelayState { fn from (config_file: &ConfigFile) -> Self { + let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false); + Self { config: Config::from (config_file), handlebars: Arc::new (load_templates ().unwrap ()), request_rendezvous: Default::default (), response_rendezvous: Default::default (), + shutdown_watch_tx, + shutdown_watch_rx, } } } @@ -270,27 +267,39 @@ async fn handle_http_response ( let (mut body_tx, body_rx) = mpsc::channel (2); let (body_finished_tx, body_finished_rx) = oneshot::channel (); + let mut shutdown_watch_rx = state.shutdown_watch_rx.clone (); spawn (async move { - loop { - let item = body.next ().await; - - if let Some (item) = item { - if let Ok (bytes) = &item { - trace! ("Relaying {} bytes", bytes.len ()); - } + if shutdown_watch_rx.recv ().await == Some (false) { + loop { + let item = body.next ().await; - if let Err (_e) = body_tx.send (item).await { - info! ("Body closed while relaying. (Client hung up?)"); - body_finished_tx.send (ClientDisconnected).unwrap (); + if let Some (item) = item { + if let Ok (bytes) = &item { + trace! ("Relaying {} bytes", bytes.len ()); + } + + futures::select! { + x = body_tx.send (item).fuse () => if let Err (_) = x { + info! ("Body closed while relaying. (Client hung up?)"); + body_finished_tx.send (ClientDisconnected).unwrap (); + break; + }, + _ = shutdown_watch_rx.recv ().fuse () => { + debug! ("Closing stream: relay is shutting down"); + break; + }, + } + } + else { + debug! ("Finished relaying bytes"); + body_finished_tx.send (StreamFinished).unwrap (); break; } } - else { - debug! ("Finished relaying bytes"); - body_finished_tx.send (StreamFinished).unwrap (); - break; - } + } + else { + debug! ("Can't relay bytes, relay is shutting down"); } }); @@ -315,13 +324,17 @@ async fn handle_http_response ( } debug! ("Connected server to client for streaming."); - match body_finished_rx.await.unwrap () { - StreamFinished => { + match body_finished_rx.await { + Ok (StreamFinished) => { error_reply (StatusCode::OK, "StreamFinished") }, - ClientDisconnected => { + Ok (ClientDisconnected) => { error_reply (StatusCode::OK, "ClientDisconnected") - } + }, + Err (e) => { + debug! ("body_finished_rx {}", e); + error_reply (StatusCode::OK, "body_finished_rx Err") + }, } } @@ -569,10 +582,14 @@ pub async fn run_relay ( let server = Server::bind (&addr) .serve (make_svc); + + server.with_graceful_shutdown (async { shutdown_oneshot.await.ok (); info! ("Received graceful shutdown"); + state.shutdown_watch_tx.broadcast (true).unwrap (); + use RelayError::*; let mut response_rendezvous = state.response_rendezvous.write ().await; @@ -594,6 +611,8 @@ pub async fn run_relay ( ParkedServer (sender) => drop (sender.send (Err (RelayShuttingDown))), } } + + info! ("Performed all cleanup"); }).await?; info! ("Exiting");