//! # PTTH Relay //! //! The PTTH relay accepts incoming connections from PTTH servers, and //! acts as a reverse proxy, forwarding incoming requests from HTTP clients //! to PTTH servers. #![warn (clippy::pedantic)] // I don't see the point of writing the type twice if I'm initializing a struct // and the type is already in the struct definition. #![allow (clippy::default_trait_access)] // I'm not sure if I like this one #![allow (clippy::enum_glob_use)] // I don't see the point in documenting the errors outside of where the // error type is defined. #![allow (clippy::missing_errors_doc)] // False positive on futures::select! macro #![allow (clippy::mut_mut)] use std::{ borrow::Cow, convert::Infallible, net::SocketAddr, path::{Path, PathBuf}, sync::Arc, time::Duration, }; use anyhow::bail; use chrono::{ DateTime, SecondsFormat, Utc }; use dashmap::DashMap; use futures_util::StreamExt; use handlebars::Handlebars; use hyper::{ Body, Request, Response, Server, StatusCode, }; use hyper::service::{make_service_fn, service_fn}; use serde::{ Serialize, }; use tokio::{ sync::{ oneshot, }, }; use tokio_stream::wrappers::ReceiverStream; use ptth_core::{ http_serde, prelude::*, }; pub mod config; pub mod errors; pub mod key_validity; mod git_version; mod relay_state; mod routing; mod scraper_api; mod server_endpoint; pub use config::{ Config, machine_editable, }; pub use errors::*; pub use relay_state::Relay; use relay_state::{ AuditData, AuditEvent, }; use relay_state::{ RejectedServer, RequestRendezvous, }; fn ok_reply > (b: B) -> Result , http::Error> { Response::builder ().status (StatusCode::OK).body (b.into ()) } fn error_reply (status: StatusCode, b: &str) -> Result , http::Error> { Response::builder () .status (status) .header ("content-type", "text/plain") .body (format! ("{}\n", b).into ()) } fn get_user_name (req: &http::request::Parts) -> Option { req.headers.get ("X-Email").and_then (|x| Some (x.to_str ().ok ()?.to_string ())) } /// Clients will come here to start requests, and always park for at least /// a short amount of time. async fn handle_http_request ( req: http::request::Parts, uri: String, state: &Relay, server_name: &str ) -> Result , RequestError> { use RequestError::*; let req_id = rusty_ulid::generate_ulid_string (); debug! ("Created request {}", req_id); let req_method = req.method.clone (); if ! state.server_exists (server_name).await { return Err (UnknownServer); } let req = http_serde::RequestParts::from_hyper (req.method, uri.clone (), req.headers) .map_err (|_| BadRequest)?; let (tx, rx) = oneshot::channel (); let req_id = rusty_ulid::generate_ulid_string (); debug! ("Forwarding {}", req_id); { let response_rendezvous = state.response_rendezvous.read ().await; response_rendezvous.insert (req_id.clone (), tx); } state.park_client (server_name, req, &req_id).await; // UKAUFFY4 (Receive half) let received = match tokio::time::timeout (Duration::from_secs (30), rx).await { Err (_) => { debug! ("Timed out request {}", req_id); return Err (ServerNeverResponded); } Ok (x) => x, }; let received = match received { Err (_) => { debug! ("Responder sender dropped for request {}", req_id); return Err (ServerTimedOut); }, Ok (x) => x, }; let (parts, body) = match received { Err (ShuttingDownError::ShuttingDown) => { return Err (RelayShuttingDown); }, Ok (x) => x, }; let mut resp = Response::builder () .status (hyper::StatusCode::from (parts.status_code)); if req_method == hyper::Method::GET && parts.headers.get ("accept-ranges").is_some () { trace! ("Stream restart code could go here"); } for (k, v) in parts.headers { resp = resp.header (&k, v); } debug! ("Unparked request {}", req_id); Ok (resp.body (body)?) } #[derive (Debug, PartialEq)] enum LastSeen { Negative, Connected, Description (String), } fn pretty_print_utc ( now: DateTime , last_seen: DateTime ) -> String { let dur = now.signed_duration_since (last_seen); if dur < chrono::Duration::zero () { return last_seen.to_rfc3339_opts (SecondsFormat::Secs, true); } if dur.num_minutes () < 1 { return format! ("{} s ago", dur.num_seconds ()); } if dur.num_hours () < 1 { return format! ("{} m ago", dur.num_minutes ()); } last_seen.to_rfc3339_opts (SecondsFormat::Secs, true) } // Mnemonic is "now - last_seen" fn pretty_print_last_seen ( now: DateTime , last_seen: DateTime ) -> LastSeen { use LastSeen::*; let dur = now.signed_duration_since (last_seen); if dur < chrono::Duration::zero () { return Negative; } if dur.num_minutes () < 1 { return Connected; } if dur.num_hours () < 1 { return Description (format! ("{} m ago", dur.num_minutes ())); } if dur.num_days () < 1 { return Description (format! ("{} h ago", dur.num_hours ())); } Description (last_seen.to_rfc3339_opts (SecondsFormat::Secs, true)) } #[derive (Serialize)] struct ServerEntry <'a> { name: String, display_name: String, last_seen: Cow <'a, str>, } #[derive (Serialize)] struct ServerListPage <'a> { dev_mode: bool, git_version: Option , servers: Vec >, news_url: Option , connected_server_count: usize, registered_server_count: usize, date_rfc3339: String, } #[derive (Serialize)] struct UnregisteredServerListPage { unregistered_servers: Vec , } #[derive (Serialize)] struct UnregisteredServer { name: String, tripcode: String, last_seen: String, } #[derive (Serialize)] struct AuditLogPage { audit_log: Vec , } #[derive (Serialize)] struct AuditEntryPretty { utc_pretty: String, data_pretty: String, } async fn handle_server_list_internal (state: &Relay) -> ServerListPage <'static> { use LastSeen::*; let dev_mode; let news_url; { let guard = state.config.read ().await; dev_mode = guard.iso.dev_mode.is_some (); news_url = guard.news_url.clone (); } let git_version = git_version::read ().await; let server_list = scraper_api::v1_server_list (&state).await; let now = Utc::now (); let registered_server_count = server_list.servers.len (); let mut connected_server_count = 0; let servers = server_list.servers.into_iter () .map (|x| { let last_seen = match x.last_seen { None => "Never".into (), Some (x) => match pretty_print_last_seen (now, x) { Negative => "Error (negative time)".into (), Connected => { connected_server_count += 1; "Connected".into () }, Description (s) => s.into (), }, }; ServerEntry { name: x.name, display_name: x.display_name, last_seen, } }) .collect (); let date_rfc3339 = now.to_rfc3339_opts (SecondsFormat::Secs, true); ServerListPage { dev_mode, git_version, servers, news_url, connected_server_count, registered_server_count, date_rfc3339, } } async fn handle_unregistered_servers_internal (state: &Relay) -> UnregisteredServerListPage { use LastSeen::*; let now = Utc::now (); let mut server_list = state.unregistered_servers.to_vec ().await; { let me_config = state.me_config.read ().await; server_list = server_list.into_iter () .filter (|s| ! me_config.servers.contains_key (&s.name)) .collect (); } server_list.sort_by_key (|s| { (s.name.clone (), *s.tripcode.as_bytes (), now - s.seen) }); server_list.dedup_by_key (|s| { (s.name.clone (), s.tripcode) }); let unregistered_servers = server_list.into_iter () .map (|x| { let last_seen = match pretty_print_last_seen (now, x.seen) { Negative => "Error (negative time)".into (), Connected => "Recently".into (), Description (s) => s, }; let tripcode = base64::encode (x.tripcode.as_bytes ()); UnregisteredServer { name: x.name, tripcode, last_seen, } }).collect (); UnregisteredServerListPage { unregistered_servers, } } async fn handle_audit_log_internal (state: &Relay) -> AuditLogPage { let utc_now = Utc::now (); let audit_log = state.audit_log.to_vec ().await .iter ().rev ().map (|e| { AuditEntryPretty { utc_pretty: pretty_print_utc (utc_now, e.time_utc), data_pretty: format! ("{:?}", e.data), } }).collect (); AuditLogPage { audit_log, } } async fn handle_server_list ( state: &Relay, handlebars: Arc > ) -> Result , RequestError> { let page = handle_server_list_internal (&state).await; let s = handlebars.render ("server_list", &page)?; Ok (ok_reply (s)?) } async fn handle_unregistered_servers ( state: &Relay, handlebars: Arc > ) -> Result , RequestError> { let page = handle_unregistered_servers_internal (&state).await; let s = handlebars.render ("unregistered_servers", &page)?; Ok (ok_reply (s)?) } async fn handle_audit_log ( state: &Relay, handlebars: Arc > ) -> Result , RequestError> { { let cfg = state.config.read ().await; if cfg.hide_audit_log { return Ok (error_reply (StatusCode::FORBIDDEN, "Forbidden")?); } } let page = handle_audit_log_internal (state).await; let s = handlebars.render ("audit_log", &page)?; Ok (ok_reply (s)?) } async fn handle_endless_sink (req: Request ) -> Result , http::Error> { let (_parts, mut body) = req.into_parts (); let mut bytes_received = 0; loop { let item = body.next ().await; if let Some (item) = item { if let Ok (bytes) = &item { bytes_received += bytes.len (); } } else { debug! ("Finished sinking debug bytes"); break; } } Ok (ok_reply (format! ("Sank {} bytes\n", bytes_received))?) } async fn handle_endless_source (gib: usize, throttle: Option ) -> Result , http::Error> { use tokio::sync::mpsc; let block_bytes = 64 * 1024; let num_blocks = (1024 * 1024 * 1024 / block_bytes) * gib; let (tx, rx) = mpsc::channel (1); tokio::spawn (async move { let random_block = { use rand::RngCore; let mut rng = rand::thread_rng (); let mut block = vec! [0_u8; 64 * 1024]; rng.fill_bytes (&mut block); block }; let mut interval = tokio::time::interval (Duration::from_millis (1000)); interval.set_missed_tick_behavior (tokio::time::MissedTickBehavior::Skip); let mut blocks_sent = 0; while blocks_sent < num_blocks { if throttle.is_some () { interval.tick ().await; } for _ in 0..throttle.unwrap_or (1) { let item = Ok::<_, Infallible> (random_block.clone ()); if tx.send (item).await.is_err () { debug! ("Endless source dropped"); return; } blocks_sent += 1; } } debug! ("Endless source ended"); }); Response::builder () .status (StatusCode::OK) .header ("content-type", "application/octet-stream") .body (Body::wrap_stream (ReceiverStream::new (rx))) } async fn handle_gen_scraper_key (_state: &Relay) -> Result , http::Error> { let key = ptth_core::gen_key (); let body = format! ("Random key: {}\n", key); Response::builder () .status (StatusCode::OK) .header ("content-type", "text/plain") .body (Body::from (body)) } async fn handle_register_server (req: Request , state: &Relay) -> Result <(), anyhow::Error> { let (parts, body) = req.into_parts (); let user = get_user_name (&parts); let form_data = read_body_limited (body, 1_024).await?; let server: crate::config::file::Server = serde_urlencoded::from_bytes (&form_data)?; state.audit_log.push (AuditEvent::new (AuditData::RegisterServer { user, server: server.clone (), })).await; { let mut me_config = state.me_config.write ().await; me_config.servers.insert (server.name.clone (), server); me_config.save (Path::new ("data/ptth_relay_me_config.toml")).await?; } Ok (()) } async fn read_body_limited (mut body: Body, limit: usize) -> anyhow::Result > { let mut buffer = vec! []; while let Some (chunk) = body.next ().await { let chunk = chunk?; if buffer.len () + chunk.len () > limit { bail! ("Body was bigger than limit"); } buffer.extend_from_slice (&chunk); } Ok (buffer) } #[instrument (level = "trace", skip (req, state, handlebars))] async fn handle_all ( req: Request , state: Arc , handlebars: Arc > ) -> Result , RequestError> { use routing::Route::*; let state = &*state; // The path is cloned here, so it's okay to consume the request // later. let path = req.uri ().path ().to_string (); trace! ("Request path: {}", path); let route = routing::route_url (req.method (), &path); let route = match route { Ok (x) => x, Err (e) => { use routing::Error; let response = match e { Error::BadUriFormat => error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?, Error::MethodNotAllowed => error_reply (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed. Are you POST-ing to a GET-only url, or vice versa?")?, Error::NotFound => error_reply (StatusCode::OK, "URL routing failed")?, }; return Ok (response); }, }; let response = match route { ClientAuditLog => handle_audit_log (state, handlebars).await?, ClientRelayIsUp => error_reply (StatusCode::OK, "Relay is up")?, ClientServerGet { listen_code, path, } => { let (parts, _) = req.into_parts (); let user = get_user_name (&parts); state.audit_log.push (AuditEvent::new (AuditData::WebClientGet { user, server_name: listen_code.to_string (), uri: path.to_string (), })).await; handle_http_request (parts, path.to_string (), &state, listen_code).await? }, ClientServerList => handle_server_list (state, handlebars).await?, ClientUnregisteredServers => handle_unregistered_servers (state, handlebars).await?, Debug => { let s = handlebars.render ("debug", &())?; ok_reply (s)? }, DebugEndlessSink => handle_endless_sink (req).await?, DebugEndlessSource (throttle) => handle_endless_source (1, throttle).await?, DebugGenKey => handle_gen_scraper_key (state).await?, DebugMysteriousError => return Err (RequestError::Mysterious), RegisterServer => { match handle_register_server (req, state).await { Ok (_) => Response::builder () .status (StatusCode::SEE_OTHER) .header ("location", "unregistered_servers") .body (Body::from ("Success. Redirecting..."))?, Err (e) => error_reply (StatusCode::BAD_REQUEST, &format! ("{:?}", e))?, } } Root => { let s = handlebars.render ("root", &())?; ok_reply (s)? }, Scraper { rest, } => scraper_api::handle (req, state, rest).await?, ServerHttpListen { listen_code, } => { let api_key = req.headers ().get ("X-ApiKey"); let api_key = match api_key { None => return Ok (error_reply (StatusCode::FORBIDDEN, "Forbidden")?), Some (x) => x, }; match check_server_api_key (state, listen_code, api_key.as_bytes ()).await { Ok (_) => (), Err (_) => return Ok (error_reply (StatusCode::FORBIDDEN, "Forbidden")?) } server_endpoint::handle_listen (state, listen_code.into ()).await? }, ServerHttpResponse { request_code, } => { let request_code = request_code.into (); server_endpoint::handle_response (req, state, request_code).await? }, }; Ok (response) } async fn check_server_api_key (state: &Relay, name: &str, api_key: &[u8]) -> Result <(), anyhow::Error> { let actual_tripcode = key_validity::BlakeHashWrapper::from_key (api_key); let expected_human = { let config = state.config.read ().await; config.servers.get (name).map (|s| s.tripcode) }; let expected_machine = { let me_config = state.me_config.read ().await; me_config.servers.get (name).map (|s| s.tripcode) }; if expected_machine.is_none () && expected_human.is_none () { state.unregistered_servers.push (crate::RejectedServer { name: name.to_string (), tripcode: *actual_tripcode, seen: Utc::now (), }).await; bail! ("Denied API request for non-existent server name {}", name); } if Some (actual_tripcode) == expected_human { return Ok (()); } if Some (actual_tripcode) == expected_machine { return Ok (()); } bail! ("Denied API request for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); } fn load_templates (asset_root: &Path) -> Result , RelayError> { let mut handlebars = Handlebars::new (); handlebars.set_strict_mode (true); let asset_root = asset_root.join ("handlebars/relay"); for (k, v) in &[ ("audit_log", "audit_log.hbs"), ("debug", "debug.hbs"), ("root", "root.hbs"), ("server_list", "server_list.hbs"), ("unregistered_servers", "unregistered_servers.hbs"), ] { handlebars.register_template_file (k, &asset_root.join (v))?; } Ok (handlebars) } async fn reload_config ( state: &Arc , config_reload_path: &Path ) -> Result <(), ConfigError> { // Reload human-editable config let new_config = Config::from_file (config_reload_path).await?; // Reload machine-editable config, if possible // let me_config = machine_editable::Config::from_file (Path::new ("data/ptth_relay_me_config.toml")).await.ok (); let mut config = state.config.write ().await; trace! ("Reloading config"); if config.servers.len () != new_config.servers.len () { debug! ("Loaded {} server configs", new_config.servers.len ()); } if config.iso.enable_scraper_api != new_config.iso.enable_scraper_api { debug! ("enable_scraper_api: {}", new_config.iso.enable_scraper_api); } (*config) = new_config; if config.iso.dev_mode.is_some () { error! ("Dev mode is enabled! This might turn off some security features. If you see this in production, escalate it to someone!"); } Ok (()) } pub async fn run_relay ( state: Arc , asset_root: &Path, shutdown_oneshot: oneshot::Receiver <()>, config_reload_path: Option ) -> Result <(), RelayError> { let handlebars = Arc::new (load_templates (asset_root)?); if let Some (x) = git_version::read ().await { info! ("ptth_relay Git version: {:?}", x); } else { info! ("ptth_relay not built from Git"); } if let Some (config_reload_path) = config_reload_path { let state_2 = state.clone (); tokio::spawn (async move { let mut reload_interval = tokio::time::interval (Duration::from_secs (60)); reload_interval.set_missed_tick_behavior (tokio::time::MissedTickBehavior::Skip); loop { reload_interval.tick ().await; reload_config (&state_2, &config_reload_path).await.ok (); } }); } let make_svc = make_service_fn (|_conn| { let state = state.clone (); let handlebars = handlebars.clone (); async { Ok::<_, Infallible> (service_fn (move |req| { let state = state.clone (); let handlebars = handlebars.clone (); async { Ok::<_, Infallible> (handle_all (req, state, handlebars).await.unwrap_or_else (|e| { use RequestError::*; error! ("{}", e); let status_code = match &e { UnknownServer => StatusCode::NOT_FOUND, BadRequest => StatusCode::BAD_REQUEST, ServerNeverResponded | ServerTimedOut => StatusCode::GATEWAY_TIMEOUT, _ => StatusCode::INTERNAL_SERVER_ERROR, }; error_reply (status_code, "Error in relay").unwrap () })) } })) } }); let addr = { let guard = state.config.read ().await; SocketAddr::from (( guard.address, guard.port.unwrap_or (4000), )) }; let server = Server::bind (&addr) .serve (make_svc); state.audit_log.push (AuditEvent::new (AuditData::RelayStart)).await; trace! ("Serving relay on {:?}", addr); server.with_graceful_shutdown (async { use ShuttingDownError::ShuttingDown; shutdown_oneshot.await.ok (); state.shutdown_watch_tx.send (true).expect ("Can't broadcast graceful shutdown"); let mut response_rendezvous = state.response_rendezvous.write ().await; let mut swapped = DashMap::default (); std::mem::swap (&mut swapped, &mut response_rendezvous); for (_, sender) in swapped { sender.send (Err (ShuttingDown)).ok (); } let mut request_rendezvous = state.request_rendezvous.lock ().await; for (_, x) in request_rendezvous.drain () { use RequestRendezvous::*; match x { ParkedClients (_) => (), ParkedServer (sender) => drop (sender.send (Err (ShuttingDown))), } } debug! ("Performed all cleanup"); }).await?; Ok (()) } #[cfg (test)] mod tests;