#![warn (clippy::pedantic)] // I don't see the point of writing the type twice if I'm initializing a struct // and the type is already in the struct definition. #![allow (clippy::default_trait_access)] // I'm not sure if I like this one #![allow (clippy::enum_glob_use)] // I don't see the point in documenting the errors outside of where the // error type is defined. #![allow (clippy::missing_errors_doc)] // False positive on futures::select! macro #![allow (clippy::mut_mut)] use std::{ borrow::Cow, collections::HashMap, convert::TryFrom, iter::FromIterator, net::SocketAddr, path::{Path, PathBuf}, sync::Arc, time::Duration, }; use chrono::{ DateTime, SecondsFormat, Utc }; use dashmap::DashMap; use handlebars::Handlebars; use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; use hyper::service::{make_service_fn, service_fn}; use serde::{ Serialize, }; use tokio::{ sync::{ Mutex, oneshot, RwLock, watch, }, }; use ptth_core::{ http_serde, prefix_match, prelude::*, }; pub mod config; pub mod errors; pub mod git_version; pub mod key_validity; mod server_endpoint; pub use config::Config; pub use errors::*; /* Here's what we need to handle: When a request comes in: - Park the client in response_rendezvous - Look up the server ID in request_rendezvous - If a server is parked, unpark it and send the request - Otherwise, queue the request When a server comes to listen: - Look up the server ID in request_rendezvous - Either return all pending requests, or park the server When a server comes to respond: - Look up the parked client in response_rendezvous - Unpark the client and begin streaming So we need these lookups to be fast: - Server IDs, where (1 server) or (0 or many clients) can be parked - Request IDs, where 1 client is parked */ enum RequestRendezvous { ParkedClients (Vec ), ParkedServer (oneshot::Sender >), } type ResponseRendezvous = oneshot::Sender >; #[derive (Clone)] pub struct ServerStatus { last_seen: DateTime , } impl Default for ServerStatus { fn default () -> Self { Self { last_seen: Utc::now (), } } } pub struct RelayState { config: RwLock , handlebars: Arc >, // Key: Server ID request_rendezvous: Mutex >, server_status: Mutex >, // Key: Request ID response_rendezvous: RwLock >, shutdown_watch_tx: watch::Sender , shutdown_watch_rx: watch::Receiver , } impl TryFrom for RelayState { type Error = RelayError; fn try_from (config: Config) -> Result { let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false); Ok (Self { config: config.into (), handlebars: Arc::new (load_templates (&PathBuf::new ())?), request_rendezvous: Default::default (), server_status: Default::default (), response_rendezvous: Default::default (), shutdown_watch_tx, shutdown_watch_rx, }) } } impl RelayState { pub async fn list_servers (&self) -> Vec { self.request_rendezvous.lock ().await.iter () .map (|(k, _)| (*k).clone ()) .collect () } } fn ok_reply > (b: B) -> Result , http::Error> { Response::builder ().status (StatusCode::OK).body (b.into ()) } fn error_reply (status: StatusCode, b: &str) -> Result , http::Error> { Response::builder () .status (status) .header ("content-type", "text/plain") .body (format! ("{}\n", b).into ()) } // Clients will come here to start requests, and always park for at least // a short amount of time. async fn handle_http_request ( req: http::request::Parts, uri: String, state: Arc , watcher_code: String ) -> Result , http::Error> { { let config = state.config.read ().await; if ! config.servers.contains_key (&watcher_code) { return error_reply (StatusCode::NOT_FOUND, "Unknown server"); } } let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { Ok (x) => x, Err (_) => return error_reply (StatusCode::BAD_REQUEST, "Bad request"), }; let (tx, rx) = oneshot::channel (); let req_id = ulid::Ulid::new ().to_string (); { let response_rendezvous = state.response_rendezvous.read ().await; response_rendezvous.insert (req_id.clone (), tx); } trace! ("Created request {}", req_id); { use RequestRendezvous::*; let mut request_rendezvous = state.request_rendezvous.lock ().await; let wrapped = http_serde::WrappedRequest { id: req_id.clone (), req, }; let new_rendezvous = match request_rendezvous.remove (&watcher_code) { Some (ParkedClients (mut v)) => { debug! ("Parking request {} ({} already queued)", req_id, v.len ()); v.push (wrapped); ParkedClients (v) }, Some (ParkedServer (s)) => { // If sending to the server fails, queue it match s.send (Ok (wrapped)) { Ok (()) => { // TODO: This can actually still fail, if the server // disconnects right as we're sending this. // Then what? debug! ( "Sending request {} directly to server {}", req_id, watcher_code, ); ParkedClients (vec! []) }, Err (Ok (wrapped)) => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, Err (_) => unreachable! (), } }, None => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, }; request_rendezvous.insert (watcher_code, new_rendezvous); } let timeout = tokio::time::delay_for (std::time::Duration::from_secs (30)); let received = tokio::select! { val = rx => val, () = timeout => { debug! ("Timed out request {}", req_id); return error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded") }, }; // UKAUFFY4 (Receive half) match received { Ok (Ok ((parts, body))) => { let mut resp = Response::builder () .status (hyper::StatusCode::from (parts.status_code)); for (k, v) in parts.headers { resp = resp.header (&k, v); } debug! ("Unparked request {}", req_id); resp.body (body) }, Ok (Err (ShuttingDownError::ShuttingDown)) => { error_reply (StatusCode::GATEWAY_TIMEOUT, "Relay shutting down") }, Err (_) => { debug! ("Responder sender dropped for request {}", req_id); error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out") }, } } #[derive (Debug, PartialEq)] enum LastSeen { Negative, Connected, Description (String), } // Mnemonic is "now - last_seen" fn pretty_print_last_seen ( now: DateTime , last_seen: DateTime ) -> LastSeen { use LastSeen::*; let dur = now.signed_duration_since (last_seen); if dur < chrono::Duration::zero () { return Negative; } if dur.num_minutes () < 1 { return Connected; } if dur.num_hours () < 1 { return Description (format! ("{} m ago", dur.num_minutes ())); } if dur.num_days () < 1 { return Description (format! ("{} h ago", dur.num_hours ())); } Description (last_seen.to_rfc3339_opts (SecondsFormat::Secs, true)) } #[derive (Serialize)] struct ServerEntry <'a> { id: String, display_name: String, last_seen: Cow <'a, str>, } #[derive (Serialize)] struct ServerListPage <'a> { dev_mode: bool, git_version: Option , servers: Vec >, } async fn handle_server_list_internal (state: &Arc ) -> ServerListPage <'static> { let dev_mode; let display_names: HashMap = { let guard = state.config.read ().await; dev_mode = guard.iso.dev_mode.is_some (); let servers = (*guard).servers.iter () .map (|(k, v)| { let display_name = v.display_name .clone () .unwrap_or_else (|| k.clone ()); (k.clone (), display_name) }); HashMap::from_iter (servers) }; let server_statuses = { let guard = state.server_status.lock ().await; (*guard).clone () }; let now = Utc::now (); let mut servers: Vec <_> = display_names.into_iter () .map (|(id, display_name)| { use LastSeen::*; let status = match server_statuses.get (&id) { None => return ServerEntry { display_name, id, last_seen: "Never".into (), }, Some (x) => x, }; let last_seen = match pretty_print_last_seen (now, status.last_seen) { Negative => "Error (negative time)".into (), Connected => "Connected".into (), Description (s) => s.into (), }; ServerEntry { display_name, id, last_seen, } }) .collect (); servers.sort_by (|a, b| a.display_name.cmp (&b.display_name)); ServerListPage { dev_mode, git_version: git_version::read_git_version ().await, servers, } } async fn handle_server_list ( state: Arc ) -> Result , RequestError> { let page = handle_server_list_internal (&state).await; let s = state.handlebars.render ("relay_server_list", &page)?; Ok (ok_reply (s)?) } #[instrument (level = "trace", skip (req, state))] async fn handle_scraper_api_v1 ( req: Request , state: Arc , path_rest: &str ) -> Result , RequestError> { use key_validity::KeyValidity; let api_key = req.headers ().get ("X-ApiKey"); let api_key = match api_key { None => return Ok (error_reply (StatusCode::FORBIDDEN, "Can't run scraper without an API key")?), Some (x) => x, }; let bad_key = || error_reply (StatusCode::FORBIDDEN, "403 Forbidden"); { let config = state.config.read ().await; let dev_mode = match &config.iso.dev_mode { None => return Ok (bad_key ()?), Some (x) => x, }; let expected_key = match &dev_mode.scraper_key { None => return Ok (bad_key ()?), Some (x) => x, }; let now = chrono::Utc::now (); match expected_key.is_valid (now, api_key.as_bytes ()) { KeyValidity::Valid => (), KeyValidity::WrongKey (bad_hash) => { error! ("Bad scraper key with hash {:?}", bad_hash); return Ok (bad_key ()?); } err => { error! ("Bad scraper key {:?}", err); return Ok (bad_key ()?); }, } } if path_rest == "test" { Ok (error_reply (StatusCode::OK, "You're valid!")?) } else { Ok (error_reply (StatusCode::NOT_FOUND, "Unknown API endpoint")?) } } #[instrument (level = "trace", skip (req, state))] async fn handle_scraper_api ( req: Request , state: Arc , path_rest: &str ) -> Result , RequestError> { { if ! state.config.read ().await.iso.enable_scraper_auth { return Ok (error_reply (StatusCode::FORBIDDEN, "Scraper API disabled")?); } } if let Some (rest) = prefix_match ("v1/", path_rest) { handle_scraper_api_v1 (req, state, rest).await } else if let Some (rest) = prefix_match ("api/", path_rest) { handle_scraper_api_v1 (req, state, rest).await } else { Ok (error_reply (StatusCode::NOT_FOUND, "Unknown scraper API version")?) } } #[instrument (level = "trace", skip (req, state))] async fn handle_all (req: Request , state: Arc ) -> Result , RequestError> { let path = req.uri ().path ().to_string (); //println! ("{}", path); debug! ("Request path: {}", path); if req.method () == Method::POST { // This is stuff the server can use. Clients can't // POST right now return if let Some (request_code) = prefix_match ("/7ZSFUKGV/http_response/", &path) { let request_code = request_code.into (); Ok (server_endpoint::handle_response (req, state, request_code).await?) } else { Ok (error_reply (StatusCode::BAD_REQUEST, "Can't POST this")?) }; } if let Some (listen_code) = prefix_match ("/7ZSFUKGV/http_listen/", &path) { let api_key = req.headers ().get ("X-ApiKey"); let api_key = match api_key { None => return Ok (error_reply (StatusCode::FORBIDDEN, "Can't run server without an API key")?), Some (x) => x, }; server_endpoint::handle_listen (state, listen_code.into (), api_key.as_bytes ()).await } else if let Some (rest) = prefix_match ("/frontend/servers/", &path) { if rest == "" { Ok (handle_server_list (state).await?) } else if let Some (idx) = rest.find ('/') { let listen_code = String::from (&rest [0..idx]); let path = String::from (&rest [idx..]); let (parts, _) = req.into_parts (); Ok (handle_http_request (parts, path, state, listen_code).await?) } else { Ok (error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?) } } else if path == "/" { let s = state.handlebars.render ("relay_root", &())?; Ok (ok_reply (s)?) } else if path == "/frontend/relay_up_check" { Ok (error_reply (StatusCode::OK, "Relay is up")?) } else if path == "/frontend/test_mysterious_error" { Err (RequestError::Mysterious) } else if let Some (rest) = prefix_match ("/scraper/", &path) { handle_scraper_api (req, state, rest).await } else { Ok (error_reply (StatusCode::OK, "Hi")?) } } pub fn load_templates (asset_root: &Path) -> Result , RelayError> { let mut handlebars = Handlebars::new (); handlebars.set_strict_mode (true); let asset_root = asset_root.join ("handlebars/relay"); for (k, v) in &[ ("relay_server_list", "relay_server_list.html"), ("relay_root", "relay_root.html"), ] { handlebars.register_template_file (k, &asset_root.join (v))?; } Ok (handlebars) } async fn reload_config ( state: &Arc , config_reload_path: &Path ) -> Result <(), ConfigError> { let new_config = Config::from_file (config_reload_path).await?; let mut config = state.config.write ().await; (*config) = new_config; debug! ("Loaded {} server configs", config.servers.len ()); debug! ("enable_scraper_auth: {}", config.iso.enable_scraper_auth); if config.iso.dev_mode.is_some () { error! ("Dev mode is enabled! This might turn off some security features. If you see this in production, escalate it to someone!"); } Ok (()) } pub async fn run_relay ( state: Arc , shutdown_oneshot: oneshot::Receiver <()>, config_reload_path: Option ) -> Result <(), RelayError> { if let Some (config_reload_path) = config_reload_path { let state_2 = state.clone (); tokio::spawn (async move { let mut reload_interval = tokio::time::interval (Duration::from_secs (60)); loop { reload_interval.tick ().await; reload_config (&state_2, &config_reload_path).await.ok (); } }); } let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { Ok::<_, RequestError> (service_fn (move |req| { let state = state.clone (); handle_all (req, state) })) } }); let addr = SocketAddr::from (( [0, 0, 0, 0], state.config.read ().await.port.unwrap_or (4000), )); let server = Server::bind (&addr) .serve (make_svc); server.with_graceful_shutdown (async { use ShuttingDownError::ShuttingDown; shutdown_oneshot.await.ok (); state.shutdown_watch_tx.broadcast (true).expect ("Can't broadcast graceful shutdown"); let mut response_rendezvous = state.response_rendezvous.write ().await; let mut swapped = DashMap::default (); std::mem::swap (&mut swapped, &mut response_rendezvous); for (_, sender) in swapped { sender.send (Err (ShuttingDown)).ok (); } let mut request_rendezvous = state.request_rendezvous.lock ().await; for (_, x) in request_rendezvous.drain () { use RequestRendezvous::*; match x { ParkedClients (_) => (), ParkedServer (sender) => drop (sender.send (Err (ShuttingDown))), } } debug! ("Performed all cleanup"); }).await?; Ok (()) } #[cfg (test)] mod tests;