#![warn (clippy::pedantic)] // I don't see the point of writing the type twice if I'm initializing a struct // and the type is already in the struct definition. #![allow (clippy::default_trait_access)] // I'm not sure if I like this one #![allow (clippy::enum_glob_use)] // I don't see the point in documenting the errors outside of where the // error type is defined. #![allow (clippy::missing_errors_doc)] // False positive on futures::select! macro #![allow (clippy::mut_mut)] use std::{ borrow::Cow, convert::Infallible, net::SocketAddr, path::{Path, PathBuf}, sync::Arc, time::Duration, }; use chrono::{ DateTime, SecondsFormat, Utc }; use dashmap::DashMap; use futures_util::StreamExt; use handlebars::Handlebars; use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; use hyper::service::{make_service_fn, service_fn}; use serde::{ Serialize, }; use tokio::{ sync::{ oneshot, }, }; use tokio_stream::wrappers::ReceiverStream; use ptth_core::{ http_serde, prefix_match, prelude::*, }; pub mod config; pub mod errors; pub mod git_version; pub mod key_validity; mod relay_state; mod scraper_api; mod server_endpoint; pub use config::Config; pub use errors::*; pub use relay_state::RelayState; use relay_state::*; fn ok_reply > (b: B) -> Result , http::Error> { Response::builder ().status (StatusCode::OK).body (b.into ()) } fn error_reply (status: StatusCode, b: &str) -> Result , http::Error> { Response::builder () .status (status) .header ("content-type", "text/plain") .body (format! ("{}\n", b).into ()) } // Clients will come here to start requests, and always park for at least // a short amount of time. async fn handle_http_request ( req: http::request::Parts, uri: String, state: Arc , watcher_code: String ) -> Result , http::Error> { { let config = state.config.read ().await; if ! config.servers.contains_key (&watcher_code) { return error_reply (StatusCode::NOT_FOUND, "Unknown server"); } } let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { Ok (x) => x, Err (_) => return error_reply (StatusCode::BAD_REQUEST, "Bad request"), }; let (tx, rx) = oneshot::channel (); let req_id = ulid::Ulid::new ().to_string (); { let response_rendezvous = state.response_rendezvous.read ().await; response_rendezvous.insert (req_id.clone (), tx); } trace! ("Created request {}", req_id); { use RequestRendezvous::*; let mut request_rendezvous = state.request_rendezvous.lock ().await; let wrapped = http_serde::WrappedRequest { id: req_id.clone (), req, }; let new_rendezvous = match request_rendezvous.remove (&watcher_code) { Some (ParkedClients (mut v)) => { debug! ("Parking request {} ({} already queued)", req_id, v.len ()); v.push (wrapped); ParkedClients (v) }, Some (ParkedServer (s)) => { // If sending to the server fails, queue it match s.send (Ok (wrapped)) { Ok (()) => { // TODO: This can actually still fail, if the server // disconnects right as we're sending this. // Then what? trace! ( "Sending request {} directly to server {}", req_id, watcher_code, ); ParkedClients (vec! []) }, Err (Ok (wrapped)) => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, Err (_) => unreachable! (), } }, None => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, }; request_rendezvous.insert (watcher_code, new_rendezvous); } let timeout = tokio::time::sleep (std::time::Duration::from_secs (30)); let received = tokio::select! { val = rx => val, () = timeout => { debug! ("Timed out request {}", req_id); return error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded") }, }; // UKAUFFY4 (Receive half) match received { Ok (Ok ((parts, body))) => { let mut resp = Response::builder () .status (hyper::StatusCode::from (parts.status_code)); for (k, v) in parts.headers { resp = resp.header (&k, v); } debug! ("Unparked request {}", req_id); resp.body (body) }, Ok (Err (ShuttingDownError::ShuttingDown)) => { error_reply (StatusCode::GATEWAY_TIMEOUT, "Relay shutting down") }, Err (_) => { debug! ("Responder sender dropped for request {}", req_id); error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out") }, } } #[derive (Debug, PartialEq)] enum LastSeen { Negative, Connected, Description (String), } // Mnemonic is "now - last_seen" fn pretty_print_last_seen ( now: DateTime , last_seen: DateTime ) -> LastSeen { use LastSeen::*; let dur = now.signed_duration_since (last_seen); if dur < chrono::Duration::zero () { return Negative; } if dur.num_minutes () < 1 { return Connected; } if dur.num_hours () < 1 { return Description (format! ("{} m ago", dur.num_minutes ())); } if dur.num_days () < 1 { return Description (format! ("{} h ago", dur.num_hours ())); } Description (last_seen.to_rfc3339_opts (SecondsFormat::Secs, true)) } #[derive (Serialize)] struct ServerEntry <'a> { name: String, display_name: String, last_seen: Cow <'a, str>, } #[derive (Serialize)] struct ServerListPage <'a> { dev_mode: bool, git_version: Option , servers: Vec >, } #[derive (Serialize)] struct UnregisteredServer { name: String, tripcode: String, last_seen: String, } #[derive (Serialize)] struct UnregisteredServerListPage { unregistered_servers: Vec , } async fn handle_server_list_internal (state: &Arc ) -> ServerListPage <'static> { use LastSeen::*; let dev_mode = { let guard = state.config.read ().await; guard.iso.dev_mode.is_some () }; let git_version = git_version::read_git_version ().await; let server_list = scraper_api::v1_server_list (&state).await; let now = Utc::now (); let servers = server_list.servers.into_iter () .map (|x| { let last_seen = match x.last_seen { None => "Never".into (), Some (x) => match pretty_print_last_seen (now, x) { Negative => "Error (negative time)".into (), Connected => "Connected".into (), Description (s) => s.into (), }, }; ServerEntry { name: x.name, display_name: x.display_name, last_seen, } }) .collect (); ServerListPage { dev_mode, git_version, servers, } } async fn handle_unregistered_servers_internal (state: &Arc ) -> UnregisteredServerListPage { UnregisteredServerListPage { unregistered_servers: vec! [], } } async fn handle_server_list ( state: Arc , handlebars: Arc > ) -> Result , RequestError> { let page = handle_server_list_internal (&state).await; let s = handlebars.render ("server_list", &page)?; Ok (ok_reply (s)?) } async fn handle_unregistered_servers ( state: Arc , handlebars: Arc > ) -> Result , RequestError> { let page = handle_unregistered_servers_internal (&state).await; let s = handlebars.render ("unregistered_servers", &page)?; Ok (ok_reply (s)?) } async fn handle_endless_sink (req: Request ) -> Result , http::Error> { let (_parts, mut body) = req.into_parts (); let mut bytes_received = 0; loop { let item = body.next ().await; if let Some (item) = item { if let Ok (bytes) = &item { bytes_received += bytes.len (); } } else { debug! ("Finished sinking debug bytes"); break; } } Ok (ok_reply (format! ("Sank {} bytes\n", bytes_received))?) } async fn handle_endless_source (gib: usize, throttle: Option ) -> Result , http::Error> { use tokio::sync::mpsc; let block_bytes = 64 * 1024; let num_blocks = (1024 * 1024 * 1024 / block_bytes) * gib; let (tx, rx) = mpsc::channel (1); tokio::spawn (async move { let random_block = { use rand::RngCore; let mut rng = rand::thread_rng (); let mut block = vec! [0u8; 64 * 1024]; rng.fill_bytes (&mut block); block }; let mut interval = tokio::time::interval (Duration::from_millis (1000)); let mut blocks_sent = 0; while blocks_sent < num_blocks { if throttle.is_some () { interval.tick ().await; } for _ in 0..throttle.unwrap_or (1) { let item = Ok::<_, Infallible> (random_block.clone ()); if let Err (_) = tx.send (item).await { debug! ("Endless source dropped"); return; } blocks_sent += 1; } } debug! ("Endless source ended"); }); Response::builder () .status (StatusCode::OK) .header ("content-type", "application/octet-stream") .body (Body::wrap_stream (ReceiverStream::new (rx))) } #[instrument (level = "trace", skip (req, state, handlebars))] async fn handle_all ( req: Request , state: Arc , handlebars: Arc > ) -> Result , RequestError> { let path = req.uri ().path ().to_string (); //println! ("{}", path); debug! ("Request path: {}", path); if req.method () == Method::POST { // This is stuff the server can use. Clients can't // POST right now return if let Some (request_code) = prefix_match ("/7ZSFUKGV/http_response/", &path) { let request_code = request_code.into (); Ok (server_endpoint::handle_response (req, state, request_code).await?) } else if path == "/frontend/debug/endless_sink" { Ok (handle_endless_sink (req).await?) } else { error! ("Can't POST {}", path); Ok (error_reply (StatusCode::BAD_REQUEST, "Can't POST this")?) }; } if let Some (listen_code) = prefix_match ("/7ZSFUKGV/http_listen/", &path) { let api_key = req.headers ().get ("X-ApiKey"); let api_key = match api_key { None => return Ok (error_reply (StatusCode::FORBIDDEN, "Can't run server without an API key")?), Some (x) => x, }; server_endpoint::handle_listen (state, listen_code.into (), api_key.as_bytes ()).await } else if let Some (rest) = prefix_match ("/frontend/servers/", &path) { // DRY T4H76LB3 if rest == "" { Ok (handle_server_list (state, handlebars).await?) } else if let Some (idx) = rest.find ('/') { let listen_code = String::from (&rest [0..idx]); let path = String::from (&rest [idx..]); let (parts, _) = req.into_parts (); Ok (handle_http_request (parts, path, state, listen_code).await?) } else { Ok (error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?) } } else if path == "/frontend/unregistered_servers" { Ok (handle_unregistered_servers (state, handlebars).await?) } else if let Some (rest) = prefix_match ("/frontend/debug/", &path) { if rest == "" { let s = handlebars.render ("debug", &())?; Ok (ok_reply (s)?) } else if rest == "endless_source" { Ok (handle_endless_source (1, None).await?) } else if rest == "endless_source_throttled" { Ok (handle_endless_source (1, Some (1024 / 64)).await?) } else if rest == "endless_sink" { Ok (error_reply (StatusCode::METHOD_NOT_ALLOWED, "Don't GET this URL, POST to it.")?) } else { Ok (error_reply (StatusCode::NOT_FOUND, "Can't route URL")?) } } else if path == "/" { let s = handlebars.render ("root", &())?; Ok (ok_reply (s)?) } else if path == "/frontend/relay_up_check" { Ok (error_reply (StatusCode::OK, "Relay is up")?) } else if path == "/frontend/test_mysterious_error" { Err (RequestError::Mysterious) } else if let Some (rest) = prefix_match ("/scraper/", &path) { scraper_api::handle (req, state, rest).await } else { Ok (error_reply (StatusCode::OK, "Can't route URL")?) } } pub fn load_templates (asset_root: &Path) -> Result , RelayError> { let mut handlebars = Handlebars::new (); handlebars.set_strict_mode (true); let asset_root = asset_root.join ("handlebars/relay"); for (k, v) in &[ ("debug", "debug.hbs"), ("root", "root.hbs"), ("server_list", "server_list.hbs"), ("unregistered_servers", "unregistered_servers.hbs"), ] { handlebars.register_template_file (k, &asset_root.join (v))?; } Ok (handlebars) } async fn reload_config ( state: &Arc , config_reload_path: &Path ) -> Result <(), ConfigError> { let new_config = Config::from_file (config_reload_path).await?; let mut config = state.config.write ().await; trace! ("Reloading config"); if config.servers.len () != new_config.servers.len () { debug! ("Loaded {} server configs", config.servers.len ()); } if config.iso.enable_scraper_api != new_config.iso.enable_scraper_api { debug! ("enable_scraper_api: {}", config.iso.enable_scraper_api); } (*config) = new_config; if config.iso.dev_mode.is_some () { error! ("Dev mode is enabled! This might turn off some security features. If you see this in production, escalate it to someone!"); } Ok (()) } pub async fn run_relay ( state: Arc , handlebars: Arc >, shutdown_oneshot: oneshot::Receiver <()>, config_reload_path: Option ) -> Result <(), RelayError> { if let Some (config_reload_path) = config_reload_path { let state_2 = state.clone (); tokio::spawn (async move { let mut reload_interval = tokio::time::interval (Duration::from_secs (60)); loop { reload_interval.tick ().await; reload_config (&state_2, &config_reload_path).await.ok (); } }); } let make_svc = make_service_fn (|_conn| { let state = state.clone (); let handlebars = handlebars.clone (); async { Ok::<_, Infallible> (service_fn (move |req| { let state = state.clone (); let handlebars = handlebars.clone (); async { Ok::<_, Infallible> (handle_all (req, state, handlebars).await.unwrap_or_else (|e| { error! ("{}", e); error_reply (StatusCode::INTERNAL_SERVER_ERROR, "Error in relay").unwrap () })) } })) } }); let addr = SocketAddr::from (( [0, 0, 0, 0], state.config.read ().await.port.unwrap_or (4000), )); let server = Server::bind (&addr) .serve (make_svc); server.with_graceful_shutdown (async { use ShuttingDownError::ShuttingDown; shutdown_oneshot.await.ok (); state.shutdown_watch_tx.send (true).expect ("Can't broadcast graceful shutdown"); let mut response_rendezvous = state.response_rendezvous.write ().await; let mut swapped = DashMap::default (); std::mem::swap (&mut swapped, &mut response_rendezvous); for (_, sender) in swapped { sender.send (Err (ShuttingDown)).ok (); } let mut request_rendezvous = state.request_rendezvous.lock ().await; for (_, x) in request_rendezvous.drain () { use RequestRendezvous::*; match x { ParkedClients (_) => (), ParkedServer (sender) => drop (sender.send (Err (ShuttingDown))), } } debug! ("Performed all cleanup"); }).await?; Ok (()) } #[cfg (test)] mod tests;