diff --git a/Cargo.toml b/Cargo.toml index 04faa6b..6972530 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,5 +29,8 @@ rmp-serde = "0.14.4" serde = {version = "1.0.117", features = ["derive"]} structopt = "0.3.20" tokio = { version = "0.2.22", features = ["full"] } +tracing = "0.1.21" +tracing-futures = "0.2.4" +tracing-subscriber = "0.2.15" toml = "0.5.7" ulid = "0.4.1" diff --git a/src/bin/ptth_relay.rs b/src/bin/ptth_relay.rs index 4c57502..4f5ad9a 100644 --- a/src/bin/ptth_relay.rs +++ b/src/bin/ptth_relay.rs @@ -16,13 +16,13 @@ async fn main () -> Result <(), Box > { let config_file = { let config_file_path = "config/ptth_relay.toml"; - let mut f = File::open (config_file_path).expect (&format! ("Can't open {:?}", config_file_path)); + let mut f = File::open (config_file_path).unwrap_or_else (|_| panic! ("Can't open {:?}", config_file_path)); let mut buffer = vec! [0u8; 4096]; - let bytes_read = f.read (&mut buffer).expect (&format! ("Can't read {:?}", config_file_path)); + let bytes_read = f.read (&mut buffer).unwrap_or_else (|_| panic! ("Can't read {:?}", config_file_path)); buffer.truncate (bytes_read); - let config_s = String::from_utf8 (buffer).expect (&format! ("Can't parse {:?} as UTF-8", config_file_path)); - toml::from_str (&config_s).expect (&format! ("Can't parse {:?} as TOML", config_file_path)) + let config_s = String::from_utf8 (buffer).unwrap_or_else (|_| panic! ("Can't parse {:?} as UTF-8", config_file_path)); + toml::from_str (&config_s).unwrap_or_else (|_| panic! ("Can't parse {:?} as TOML", config_file_path)) }; eprintln! ("ptth_relay Git version: {:?}", ptth::git_version::GIT_VERSION); diff --git a/src/bin/ptth_server.rs b/src/bin/ptth_server.rs index 1577ae0..1ec88a0 100644 --- a/src/bin/ptth_server.rs +++ b/src/bin/ptth_server.rs @@ -18,14 +18,14 @@ async fn main () -> Result <(), Box > { let config_file = { let config_file_path = "config/ptth_server.toml"; - let mut f = std::fs::File::open (config_file_path).expect (&format! ("Can't open {:?}", config_file_path)); + let mut f = std::fs::File::open (config_file_path).unwrap_or_else (|_| panic! ("Can't open {:?}", config_file_path)); let mut buffer = vec! [0u8; 4096]; - let bytes_read = f.read (&mut buffer).expect (&format! ("Can't read {:?}", config_file_path)); + let bytes_read = f.read (&mut buffer).unwrap_or_else (|_| panic! ("Can't read {:?}", config_file_path)); buffer.truncate (bytes_read); - let config_s = String::from_utf8 (buffer).expect (&format! ("Can't parse {:?} as UTF-8", config_file_path)); - toml::from_str (&config_s).expect (&format! ("Can't parse {:?} as TOML", config_file_path)) + let config_s = String::from_utf8 (buffer).unwrap_or_else (|_| panic! ("Can't parse {:?} as UTF-8", config_file_path)); + toml::from_str (&config_s).unwrap_or_else (|_| panic! ("Can't parse {:?} as TOML", config_file_path)) }; - ptth::server::main (config_file).await + ptth::server::main (config_file, None).await } diff --git a/src/lib.rs b/src/lib.rs index d610009..a4ceb98 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,12 +41,21 @@ pub fn password_is_bad (mut password: String) -> bool { #[cfg (test)] mod tests { use std::{ - sync::Arc, + sync::{ + Arc, + atomic::{ + AtomicBool, + Ordering, + }, + }, + time::Duration, }; use tokio::{ runtime::Runtime, spawn, + sync::oneshot, + time::delay_for, }; use super::{ @@ -58,7 +67,7 @@ mod tests { fn check_bad_passwords () { use crate::password_is_bad; - for pw in vec! [ + for pw in &[ "password", "pAsSwOrD", "secret", @@ -80,6 +89,9 @@ mod tests { fn end_to_end () { use maplit::*; use reqwest::Client; + use tracing::{info}; + + tracing_subscriber::fmt::init (); let mut rt = Runtime::new ().unwrap (); @@ -99,8 +111,9 @@ mod tests { let relay_state = Arc::new (relay::RelayState::from (&config_file)); let relay_state_2 = relay_state.clone (); - spawn (async move { - relay::run_relay (relay_state_2, None).await.unwrap (); + let (stop_relay_tx, stop_relay_rx) = oneshot::channel (); + let task_relay = spawn (async move { + relay::run_relay (relay_state_2, Some (stop_relay_rx)).await.unwrap (); }); assert! (relay_state.list_servers ().await.is_empty ()); @@ -113,18 +126,22 @@ mod tests { relay_url: "http://127.0.0.1:4000/7ZSFUKGV".into (), file_server_root: None, }; - spawn (async move { - server::main (config_file).await.unwrap (); - }); + let stop_server_atomic = Arc::new (AtomicBool::from (false)); + let task_server = { + let stop_server_atomic = stop_server_atomic.clone (); + spawn (async move { + server::main (config_file, Some (stop_server_atomic)).await.unwrap (); + }) + }; - tokio::time::delay_for (std::time::Duration::from_millis (500)).await; + delay_for (Duration::from_millis (500)).await; assert_eq! (relay_state.list_servers ().await, vec! [ server_name.to_string (), ]); let client = Client::builder () - .timeout (std::time::Duration::from_secs (2)) + .timeout (Duration::from_secs (2)) .build ().unwrap (); let resp = client.get (&format! ("{}/frontend/relay_up_check", relay_url)) @@ -156,6 +173,19 @@ mod tests { .send ().await.unwrap (); assert_eq! (resp.status (), reqwest::StatusCode::NOT_FOUND); + + info! ("Shutting down end-to-end test"); + + stop_server_atomic.store (true, Ordering::Relaxed); + stop_relay_tx.send (()).unwrap (); + + info! ("Sent stop messages"); + + task_relay.await.unwrap (); + info! ("Relay stopped"); + + task_server.await.unwrap (); + info! ("Server stopped"); }); } } diff --git a/src/relay/mod.rs b/src/relay/mod.rs index 30844b3..18a7127 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -10,7 +10,6 @@ use std::{ }; use dashmap::DashMap; -use futures::channel::oneshot; use handlebars::Handlebars; use hyper::{ Body, @@ -26,7 +25,10 @@ use serde::{ Serialize, }; use tokio::{ - sync::Mutex, + sync::{ + Mutex, + oneshot, + }, }; use crate::{ @@ -183,17 +185,23 @@ async fn handle_http_listen ( if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code) { + // 1 or more clients were parked - Make the server + // handle them immediately + return status_reply (StatusCode::OK, rmp_serde::to_vec (&v).unwrap ()); } request_rendezvous.insert (watcher_code, ParkedServer (tx)); } - let one_req = vec! [ - rx.await.unwrap (), - ]; + // No clients were parked - make the server long-poll - return status_reply (StatusCode::OK, rmp_serde::to_vec (&one_req).unwrap ()); + let one_req = match rx.await { + Ok (r) => r, + Err (_) => return status_reply (StatusCode::SERVICE_UNAVAILABLE, "Server is shutting down, try again soon"), + }; + + status_reply (StatusCode::OK, rmp_serde::to_vec (&vec! [one_req]).unwrap ()) } async fn handle_http_response ( @@ -385,11 +393,13 @@ pub fn load_templates () Ok (handlebars) } +use tracing::info; + pub async fn run_relay ( state: Arc , - shutdown_oneshot: Option > + shutdown_oneshot: Option > ) - -> Result <(), Box > +-> Result <(), Box > { let addr = SocketAddr::from (( [0, 0, 0, 0], @@ -406,7 +416,7 @@ pub async fn run_relay ( } } - eprintln! ("Loaded {} server tripcodes", state.config.server_tripcodes.len ()); + info! ("Loaded {} server tripcodes", state.config.server_tripcodes.len ()); let make_svc = make_service_fn (|_conn| { let state = state.clone (); @@ -424,16 +434,52 @@ pub async fn run_relay ( .serve (make_svc); match shutdown_oneshot { - Some (rx) => server.with_graceful_shutdown (async { - rx.await.ok (); - }).await?, + Some (rx) => { + info! ("Configured for graceful shutdown"); + server.with_graceful_shutdown (async { + rx.await.ok (); + + state.response_rendezvous.clear (); + + let mut request_rendezvoux = state.request_rendezvous.lock ().await; + request_rendezvoux.clear (); + + info! ("Received graceful shutdown"); + }).await? + }, None => server.await?, }; + info! ("Exiting"); Ok (()) } #[cfg (test)] mod tests { + use std::time::Duration; + use tokio::{ + runtime::Runtime, + spawn, + sync::oneshot, + time::delay_for, + }; + + #[test] + fn so_crazy_it_might_work () { + let mut rt = Runtime::new ().unwrap (); + + rt.block_on (async { + let (tx, rx) = oneshot::channel (); + + let task_1 = spawn (async move { + delay_for (Duration::from_secs (1)).await; + tx.send (()).unwrap (); + }); + + rx.await.unwrap (); + + task_1.await.unwrap (); + }); + } } diff --git a/src/server/mod.rs b/src/server/mod.rs index c5986ed..6de6158 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -117,7 +117,17 @@ pub struct Config { pub file_server_root: Option , } -pub async fn main (config_file: ConfigFile) +use std::sync::atomic::{ + AtomicBool, + Ordering, +}; + +use tracing::info; + +pub async fn main ( + config_file: ConfigFile, + shutdown_atomic: Option > +) -> Result <(), Box > { use std::convert::TryInto; @@ -135,6 +145,7 @@ pub async fn main (config_file: ConfigFile) let client = Client::builder () .default_headers (headers) + .timeout (Duration::from_secs (30)) .build ().unwrap (); let handlebars = file_server::load_templates ()?; @@ -150,10 +161,24 @@ pub async fn main (config_file: ConfigFile) let mut backoff_delay = 0; loop { + if let Some (a) = &shutdown_atomic { + if a.load (Ordering::Relaxed) { + break; + } + } + if backoff_delay > 0 { delay_for (Duration::from_millis (backoff_delay)).await; } + if let Some (a) = &shutdown_atomic { + if a.load (Ordering::Relaxed) { + break; + } + } + + info! ("http_listen"); + let req_req = state.client.get (&format! ("{}/http_listen/{}", state.config.relay_url, config_file.name)); let err_backoff_delay = std::cmp::min (30_000, backoff_delay * 2 + 500); @@ -186,4 +211,8 @@ pub async fn main (config_file: ConfigFile) handle_req_resp (state, req_resp).await; }); } + + info! ("Exiting"); + + Ok (()) }