//! # PTTH Relay //! //! The PTTH relay accepts incoming connections from PTTH servers, and //! acts as a reverse proxy, forwarding incoming requests from HTTP clients //! to PTTH servers. #![warn (clippy::pedantic)] // I don't see the point of writing the type twice if I'm initializing a struct // and the type is already in the struct definition. #![allow (clippy::default_trait_access)] // I'm not sure if I like this one #![allow (clippy::enum_glob_use)] // I don't see the point in documenting the errors outside of where the // error type is defined. #![allow (clippy::missing_errors_doc)] // False positive on futures::select! macro #![allow (clippy::mut_mut)] use std::{ borrow::Cow, convert::Infallible, net::SocketAddr, path::{Path, PathBuf}, sync::Arc, time::Duration, }; use chrono::{ DateTime, SecondsFormat, Utc }; use dashmap::DashMap; use futures_util::StreamExt; use handlebars::Handlebars; use hyper::{ Body, Request, Response, Server, StatusCode, }; use hyper::service::{make_service_fn, service_fn}; use serde::{ Serialize, }; use tokio::{ sync::{ oneshot, }, }; use tokio_stream::wrappers::ReceiverStream; use ptth_core::{ http_serde, prelude::*, }; pub mod config; pub mod errors; pub mod key_validity; mod git_version; mod relay_state; mod routing; mod scraper_api; mod server_endpoint; pub use config::Config; pub use errors::*; pub use relay_state::Relay; use relay_state::{ RejectedServer, RequestRendezvous, }; fn ok_reply > (b: B) -> Result , http::Error> { Response::builder ().status (StatusCode::OK).body (b.into ()) } fn error_reply (status: StatusCode, b: &str) -> Result , http::Error> { Response::builder () .status (status) .header ("content-type", "text/plain") .body (format! ("{}\n", b).into ()) } // Clients will come here to start requests, and always park for at least // a short amount of time. async fn handle_http_request ( req: http::request::Parts, uri: String, state: Arc , server_name: &str ) -> Result , http::Error> { use crate::relay_state::{ AuditData, AuditEvent, }; { let config = state.config.read ().await; if ! config.servers.contains_key (server_name) { return error_reply (StatusCode::NOT_FOUND, "Unknown server"); } } let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { Ok (x) => x, Err (_) => return error_reply (StatusCode::BAD_REQUEST, "Bad request"), }; let (tx, rx) = oneshot::channel (); let req_id = rusty_ulid::generate_ulid_string (); state.audit_log.push (AuditEvent::new (AuditData::WebClientGet { req_id: req_id.clone (), server_name: server_name.to_string (), })).await; trace! ("Created request {}", req_id); { let response_rendezvous = state.response_rendezvous.read ().await; response_rendezvous.insert (req_id.clone (), tx); } { use RequestRendezvous::*; let mut request_rendezvous = state.request_rendezvous.lock ().await; let wrapped = http_serde::WrappedRequest { id: req_id.clone (), req, }; let new_rendezvous = match request_rendezvous.remove (server_name) { Some (ParkedClients (mut v)) => { debug! ("Parking request {} ({} already queued)", req_id, v.len ()); v.push (wrapped); ParkedClients (v) }, Some (ParkedServer (s)) => { // If sending to the server fails, queue it match s.send (Ok (wrapped)) { Ok (()) => { // TODO: This can actually still fail, if the server // disconnects right as we're sending this. // Then what? trace! ( "Sending request {} directly to server {}", req_id, server_name, ); ParkedClients (vec! []) }, Err (Ok (wrapped)) => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, Err (_) => unreachable! (), } }, None => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, }; request_rendezvous.insert (server_name.to_string (), new_rendezvous); } let timeout = tokio::time::sleep (std::time::Duration::from_secs (30)); let received = tokio::select! { val = rx => val, () = timeout => { debug! ("Timed out request {}", req_id); return error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded") }, }; // UKAUFFY4 (Receive half) match received { Ok (Ok ((parts, body))) => { let mut resp = Response::builder () .status (hyper::StatusCode::from (parts.status_code)); for (k, v) in parts.headers { resp = resp.header (&k, v); } debug! ("Unparked request {}", req_id); resp.body (body) }, Ok (Err (ShuttingDownError::ShuttingDown)) => { error_reply (StatusCode::GATEWAY_TIMEOUT, "Relay shutting down") }, Err (_) => { debug! ("Responder sender dropped for request {}", req_id); error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out") }, } } #[derive (Debug, PartialEq)] enum LastSeen { Negative, Connected, Description (String), } // Mnemonic is "now - last_seen" fn pretty_print_last_seen ( now: DateTime , last_seen: DateTime ) -> LastSeen { use LastSeen::*; let dur = now.signed_duration_since (last_seen); if dur < chrono::Duration::zero () { return Negative; } if dur.num_minutes () < 1 { return Connected; } if dur.num_hours () < 1 { return Description (format! ("{} m ago", dur.num_minutes ())); } if dur.num_days () < 1 { return Description (format! ("{} h ago", dur.num_hours ())); } Description (last_seen.to_rfc3339_opts (SecondsFormat::Secs, true)) } #[derive (Serialize)] struct ServerEntry <'a> { name: String, display_name: String, last_seen: Cow <'a, str>, } #[derive (Serialize)] struct ServerListPage <'a> { dev_mode: bool, git_version: Option , servers: Vec >, news_url: Option , } #[derive (Serialize)] struct UnregisteredServerListPage { unregistered_servers: Vec , } #[derive (Serialize)] struct UnregisteredServer { name: String, tripcode: String, last_seen: String, } #[derive (Serialize)] struct AuditLogPage { audit_log: Vec , } async fn handle_server_list_internal (state: &Arc ) -> ServerListPage <'static> { use LastSeen::*; let dev_mode; let news_url; { let guard = state.config.read ().await; dev_mode = guard.iso.dev_mode.is_some (); news_url = guard.news_url.clone (); } let git_version = git_version::read ().await; let server_list = scraper_api::v1_server_list (&state).await; let now = Utc::now (); let servers = server_list.servers.into_iter () .map (|x| { let last_seen = match x.last_seen { None => "Never".into (), Some (x) => match pretty_print_last_seen (now, x) { Negative => "Error (negative time)".into (), Connected => "Connected".into (), Description (s) => s.into (), }, }; ServerEntry { name: x.name, display_name: x.display_name, last_seen, } }) .collect (); ServerListPage { dev_mode, git_version, servers, news_url, } } async fn handle_unregistered_servers_internal (state: &Arc ) -> UnregisteredServerListPage { use LastSeen::*; let now = Utc::now (); let server_list = state.unregistered_servers.to_vec ().await; let unregistered_servers = server_list.into_iter () .map (|x| { let last_seen = match pretty_print_last_seen (now, x.seen) { Negative => "Error (negative time)".into (), Connected => "Recently".into (), Description (s) => s, }; UnregisteredServer { name: x.name, tripcode: base64::encode (x.tripcode.as_bytes ()), last_seen, } }).collect (); UnregisteredServerListPage { unregistered_servers, } } async fn handle_audit_log_internal (state: &Arc ) -> AuditLogPage { let audit_log = state.audit_log.to_vec ().await .iter ().rev ().map (|e| format! ("{:?}", e)).collect (); AuditLogPage { audit_log, } } async fn handle_server_list ( state: Arc , handlebars: Arc > ) -> Result , RequestError> { let page = handle_server_list_internal (&state).await; let s = handlebars.render ("server_list", &page)?; Ok (ok_reply (s)?) } async fn handle_unregistered_servers ( state: Arc , handlebars: Arc > ) -> Result , RequestError> { let page = handle_unregistered_servers_internal (&state).await; let s = handlebars.render ("unregistered_servers", &page)?; Ok (ok_reply (s)?) } async fn handle_audit_log ( state: Arc , handlebars: Arc > ) -> Result , RequestError> { let page = handle_audit_log_internal (&state).await; let s = handlebars.render ("audit_log", &page)?; Ok (ok_reply (s)?) } async fn handle_endless_sink (req: Request ) -> Result , http::Error> { let (_parts, mut body) = req.into_parts (); let mut bytes_received = 0; loop { let item = body.next ().await; if let Some (item) = item { if let Ok (bytes) = &item { bytes_received += bytes.len (); } } else { debug! ("Finished sinking debug bytes"); break; } } Ok (ok_reply (format! ("Sank {} bytes\n", bytes_received))?) } async fn handle_endless_source (gib: usize, throttle: Option ) -> Result , http::Error> { use tokio::sync::mpsc; let block_bytes = 64 * 1024; let num_blocks = (1024 * 1024 * 1024 / block_bytes) * gib; let (tx, rx) = mpsc::channel (1); tokio::spawn (async move { let random_block = { use rand::RngCore; let mut rng = rand::thread_rng (); let mut block = vec! [0_u8; 64 * 1024]; rng.fill_bytes (&mut block); block }; let mut interval = tokio::time::interval (Duration::from_millis (1000)); let mut blocks_sent = 0; while blocks_sent < num_blocks { if throttle.is_some () { interval.tick ().await; } for _ in 0..throttle.unwrap_or (1) { let item = Ok::<_, Infallible> (random_block.clone ()); if tx.send (item).await.is_err () { debug! ("Endless source dropped"); return; } blocks_sent += 1; } } debug! ("Endless source ended"); }); Response::builder () .status (StatusCode::OK) .header ("content-type", "application/octet-stream") .body (Body::wrap_stream (ReceiverStream::new (rx))) } async fn handle_gen_scraper_key (_state: Arc ) -> Result , http::Error> { let key = ptth_core::gen_key (); let body = format! ("Random key: {}\n", key); Response::builder () .status (StatusCode::OK) .header ("content-type", "text/plain") .body (Body::from (body)) } #[instrument (level = "trace", skip (req, state, handlebars))] async fn handle_all ( req: Request , state: Arc , handlebars: Arc > ) -> Result , RequestError> { use routing::Route::*; // The path is cloned here, so it's okay to consume the request // later. let path = req.uri ().path ().to_string (); trace! ("Request path: {}", path); let route = routing::route_url (req.method (), &path); let response = match route { ClientAuditLog => handle_audit_log (state, handlebars).await?, ClientRelayIsUp => error_reply (StatusCode::OK, "Relay is up")?, ClientServerGet { listen_code, path, } => { let (parts, _) = req.into_parts (); handle_http_request (parts, path.to_string (), state, listen_code).await? }, ClientServerList => handle_server_list (state, handlebars).await?, ClientUnregisteredServers => handle_unregistered_servers (state, handlebars).await?, Debug => { let s = handlebars.render ("debug", &())?; ok_reply (s)? }, DebugEndlessSink => handle_endless_sink (req).await?, DebugEndlessSource (throttle) => handle_endless_source (1, throttle).await?, DebugGenKey => handle_gen_scraper_key (state).await?, DebugMysteriousError => return Err (RequestError::Mysterious), ErrorBadUriFormat => error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?, ErrorCantPost => { error! ("Can't POST {}", path); error_reply (StatusCode::BAD_REQUEST, "Can't POST this")? }, ErrorMethodNotAllowed => error_reply (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed. Are you POST-ing to a GET-only url, or vice versa?")?, ErrorRoutingFailed => error_reply (StatusCode::OK, "URL routing failed")?, Root => { let s = handlebars.render ("root", &())?; ok_reply (s)? }, Scraper { rest, } => scraper_api::handle (req, state, rest).await?, ServerHttpListen { listen_code, } => { let api_key = req.headers ().get ("X-ApiKey"); let api_key = match api_key { None => return Ok (error_reply (StatusCode::FORBIDDEN, "Can't run server without an API key")?), Some (x) => x, }; server_endpoint::handle_listen (state, listen_code.into (), api_key.as_bytes ()).await? }, ServerHttpResponse { request_code, } => { let request_code = request_code.into (); server_endpoint::handle_response (req, state, request_code).await? }, }; Ok (response) } fn load_templates (asset_root: &Path) -> Result , RelayError> { let mut handlebars = Handlebars::new (); handlebars.set_strict_mode (true); let asset_root = asset_root.join ("handlebars/relay"); for (k, v) in &[ ("audit_log", "audit_log.hbs"), ("debug", "debug.hbs"), ("root", "root.hbs"), ("server_list", "server_list.hbs"), ("unregistered_servers", "unregistered_servers.hbs"), ] { handlebars.register_template_file (k, &asset_root.join (v))?; } Ok (handlebars) } async fn reload_config ( state: &Arc , config_reload_path: &Path ) -> Result <(), ConfigError> { let new_config = Config::from_file (config_reload_path).await?; let mut config = state.config.write ().await; trace! ("Reloading config"); if config.servers.len () != new_config.servers.len () { debug! ("Loaded {} server configs", config.servers.len ()); } if config.iso.enable_scraper_api != new_config.iso.enable_scraper_api { debug! ("enable_scraper_api: {}", config.iso.enable_scraper_api); } (*config) = new_config; if config.iso.dev_mode.is_some () { error! ("Dev mode is enabled! This might turn off some security features. If you see this in production, escalate it to someone!"); } Ok (()) } pub async fn run_relay ( state: Arc , asset_root: &Path, shutdown_oneshot: oneshot::Receiver <()>, config_reload_path: Option ) -> Result <(), RelayError> { use crate::relay_state::{ AuditData, AuditEvent, }; let handlebars = Arc::new (load_templates (asset_root)?); if let Some (x) = git_version::read ().await { info! ("ptth_relay Git version: {:?}", x); } else { info! ("ptth_relay not built from Git"); } if let Some (config_reload_path) = config_reload_path { let state_2 = state.clone (); tokio::spawn (async move { let mut reload_interval = tokio::time::interval (Duration::from_secs (60)); loop { reload_interval.tick ().await; reload_config (&state_2, &config_reload_path).await.ok (); } }); } let make_svc = make_service_fn (|_conn| { let state = state.clone (); let handlebars = handlebars.clone (); async { Ok::<_, Infallible> (service_fn (move |req| { let state = state.clone (); let handlebars = handlebars.clone (); async { Ok::<_, Infallible> (handle_all (req, state, handlebars).await.unwrap_or_else (|e| { error! ("{}", e); error_reply (StatusCode::INTERNAL_SERVER_ERROR, "Error in relay").unwrap () })) } })) } }); let addr = SocketAddr::from (( [0, 0, 0, 0], state.config.read ().await.port.unwrap_or (4000), )); let server = Server::bind (&addr) .serve (make_svc); state.audit_log.push (AuditEvent::new (AuditData::RelayStart)).await; server.with_graceful_shutdown (async { use ShuttingDownError::ShuttingDown; shutdown_oneshot.await.ok (); state.shutdown_watch_tx.send (true).expect ("Can't broadcast graceful shutdown"); let mut response_rendezvous = state.response_rendezvous.write ().await; let mut swapped = DashMap::default (); std::mem::swap (&mut swapped, &mut response_rendezvous); for (_, sender) in swapped { sender.send (Err (ShuttingDown)).ok (); } let mut request_rendezvous = state.request_rendezvous.lock ().await; for (_, x) in request_rendezvous.drain () { use RequestRendezvous::*; match x { ParkedClients (_) => (), ParkedServer (sender) => drop (sender.send (Err (ShuttingDown))), } } debug! ("Performed all cleanup"); }).await?; Ok (()) } #[cfg (test)] mod tests;