use std::{ error::Error, collections::*, convert::Infallible, iter::FromIterator, net::SocketAddr, sync::{ Arc, }, time::Duration, }; use dashmap::DashMap; use futures::{ FutureExt, stream::StreamExt, }; use handlebars::Handlebars; use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; use hyper::service::{make_service_fn, service_fn}; use serde::{ Deserialize, Serialize, }; use tokio::{ spawn, sync::{ Mutex, mpsc, oneshot, RwLock, watch, }, time::delay_for, }; use tracing::{ debug, error, info, trace, instrument, }; use crate::{ http_serde, prefix_match, }; /* Here's what we need to handle: When a request comes in: - Park the client in response_rendezvous - Look up the server ID in request_rendezvous - If a server is parked, unpark it and send the request - Otherwise, queue the request When a server comes to listen: - Look up the server ID in request_rendezvous - Either return all pending requests, or park the server When a server comes to respond: - Look up the parked client in response_rendezvous - Unpark the client and begin streaming So we need these lookups to be fast: - Server IDs, where (1 server) or (0 or many clients) can be parked - Request IDs, where 1 client is parked */ #[derive (Debug)] enum RelayError { RelayShuttingDown, } enum RequestRendezvous { ParkedClients (Vec ), ParkedServer (oneshot::Sender >), } type ResponseRendezvous = oneshot::Sender >; // Stuff we need to load from the config file and use to // set up the HTTP server #[derive (Default, Deserialize)] pub struct ConfigFile { pub port: Option , pub server_tripcodes: HashMap , } // Stuff we actually need at runtime struct Config { server_tripcodes: HashMap , } impl From <&ConfigFile> for Config { fn from (f: &ConfigFile) -> Self { let server_tripcodes = HashMap::from_iter (f.server_tripcodes.iter () .map (|(k, v)| { use std::convert::TryInto; let bytes: Vec = base64::decode (v).unwrap (); let bytes: [u8; 32] = (&bytes [..]).try_into ().unwrap (); let v = blake3::Hash::from (bytes); debug! ("Tripcode {} => {}", k, v.to_hex ()); (k.clone (), v) })); Self { server_tripcodes, } } } use chrono::{ DateTime, SecondsFormat, Utc }; #[derive (Clone)] pub struct ServerStatus { last_seen: DateTime , } impl Default for ServerStatus { fn default () -> Self { Self { last_seen: Utc::now (), } } } pub struct RelayState { config: RwLock , handlebars: Arc >, // Key: Server ID request_rendezvous: Mutex >, server_status: Mutex >, // Key: Request ID response_rendezvous: RwLock >, shutdown_watch_tx: watch::Sender , shutdown_watch_rx: watch::Receiver , } impl From <&ConfigFile> for RelayState { fn from (config_file: &ConfigFile) -> Self { let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false); Self { config: Config::from (config_file).into (), handlebars: Arc::new (load_templates (&PathBuf::new ()).unwrap ()), request_rendezvous: Default::default (), server_status: Default::default (), response_rendezvous: Default::default (), shutdown_watch_tx, shutdown_watch_rx, } } } impl RelayState { pub async fn list_servers (&self) -> Vec { self.request_rendezvous.lock ().await.iter () .map (|(k, _)| (*k).clone ()) .collect () } } fn ok_reply > (b: B) -> Response { Response::builder ().status (StatusCode::OK).body (b.into ()).unwrap () } fn error_reply (status: StatusCode, b: &str) -> Response { Response::builder () .status (status) .header ("content-type", "text/plain") .body (format! ("{}\n", b).into ()).unwrap () } // Servers will come here and either handle queued requests from parked clients, // or park themselves until a request comes in. async fn handle_http_listen ( state: Arc , watcher_code: String, api_key: &[u8], ) -> Response { let trip_error = error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey"); let expected_tripcode = { let config = state.config.read ().await; match config.server_tripcodes.get (&watcher_code) { None => { error! ("Denied http_listen for non-existent server name {}", watcher_code); return trip_error; }, Some (x) => (*x).clone (), } }; let actual_tripcode = blake3::hash (api_key); if expected_tripcode != actual_tripcode { error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); return trip_error; } // End of early returns { let mut server_status = state.server_status.lock ().await; let mut status = server_status.entry (watcher_code.clone ()).or_insert_with (Default::default); status.last_seen = Utc::now (); } use RequestRendezvous::*; let (tx, rx) = oneshot::channel (); { let mut request_rendezvous = state.request_rendezvous.lock ().await; if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code) { if ! v.is_empty () { // 1 or more clients were parked - Make the server // handle them immediately debug! ("Sending {} parked requests to server {}", v.len (), watcher_code); return ok_reply (rmp_serde::to_vec (&v).unwrap ()); } } debug! ("Parking server {}", watcher_code); request_rendezvous.insert (watcher_code.clone (), ParkedServer (tx)); } // No clients were parked - make the server long-poll futures::select! { x = rx.fuse () => match x { Ok (Ok (one_req)) => { debug! ("Unparking server {}", watcher_code); ok_reply (rmp_serde::to_vec (&vec! [one_req]).unwrap ()) }, Ok (Err (RelayError::RelayShuttingDown)) => error_reply (StatusCode::SERVICE_UNAVAILABLE, "Server is shutting down, try again soon"), Err (_) => error_reply (StatusCode::INTERNAL_SERVER_ERROR, "Server error"), }, _ = delay_for (Duration::from_secs (30)).fuse () => { debug! ("Timed out http_listen for server {}", watcher_code); return error_reply (StatusCode::NO_CONTENT, "No requests now, long-poll again") } } } // Servers will come here to stream responses to clients async fn handle_http_response ( req: Request , state: Arc , req_id: String, ) -> Response { let (parts, mut body) = req.into_parts (); let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap (); // Intercept the body packets here so we can check when the stream // ends or errors out #[derive (Debug)] enum BodyFinishedReason { StreamFinished, ClientDisconnected, } use BodyFinishedReason::*; let (mut body_tx, body_rx) = mpsc::channel (2); let (body_finished_tx, body_finished_rx) = oneshot::channel (); let mut shutdown_watch_rx = state.shutdown_watch_rx.clone (); spawn (async move { if shutdown_watch_rx.recv ().await == Some (false) { loop { let item = body.next ().await; if let Some (item) = item { if let Ok (bytes) = &item { trace! ("Relaying {} bytes", bytes.len ()); } futures::select! { x = body_tx.send (item).fuse () => if let Err (_) = x { info! ("Body closed while relaying. (Client hung up?)"); body_finished_tx.send (ClientDisconnected).unwrap (); break; }, _ = shutdown_watch_rx.recv ().fuse () => { debug! ("Closing stream: relay is shutting down"); break; }, } } else { debug! ("Finished relaying bytes"); body_finished_tx.send (StreamFinished).unwrap (); break; } } } else { debug! ("Can't relay bytes, relay is shutting down"); } }); let body = Body::wrap_stream (body_rx); let tx = { let response_rendezvous = state.response_rendezvous.read ().await; match response_rendezvous.remove (&req_id) { None => { error! ("Server tried to respond to non-existent request"); return error_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous"); }, Some ((_, x)) => x, } }; // UKAUFFY4 (Send half) if tx.send (Ok ((resp_parts, body))).is_err () { let msg = "Failed to connect to client"; error! (msg); return error_reply (StatusCode::BAD_GATEWAY, msg); } debug! ("Connected server to client for streaming."); match body_finished_rx.await { Ok (StreamFinished) => { error_reply (StatusCode::OK, "StreamFinished") }, Ok (ClientDisconnected) => { error_reply (StatusCode::OK, "ClientDisconnected") }, Err (e) => { debug! ("body_finished_rx {}", e); error_reply (StatusCode::OK, "body_finished_rx Err") }, } } // Clients will come here to start requests, and always park for at least // a short amount of time. async fn handle_http_request ( req: http::request::Parts, uri: String, state: Arc , watcher_code: String ) -> Response { { let config = state.config.read ().await; if ! config.server_tripcodes.contains_key (&watcher_code) { return error_reply (StatusCode::NOT_FOUND, "Unknown server"); } } let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { Ok (x) => x, _ => return error_reply (StatusCode::BAD_REQUEST, "Bad request"), }; let (tx, rx) = oneshot::channel (); let req_id = ulid::Ulid::new ().to_string (); { let response_rendezvous = state.response_rendezvous.read ().await; response_rendezvous.insert (req_id.clone (), tx); } trace! ("Created request {}", req_id); { let mut request_rendezvous = state.request_rendezvous.lock ().await; let wrapped = http_serde::WrappedRequest { id: req_id.clone (), req, }; use RequestRendezvous::*; let new_rendezvous = match request_rendezvous.remove (&watcher_code) { Some (ParkedClients (mut v)) => { debug! ("Parking request {} ({} already queued)", req_id, v.len ()); v.push (wrapped); ParkedClients (v) }, Some (ParkedServer (s)) => { // If sending to the server fails, queue it match s.send (Ok (wrapped)) { Ok (()) => { // TODO: This can actually still fail, if the server // disconnects right as we're sending this. // Then what? debug! ( "Sending request {} directly to server {}", req_id, watcher_code, ); ParkedClients (vec! []) }, Err (Ok (wrapped)) => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, Err (_) => unreachable! (), } }, None => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, }; request_rendezvous.insert (watcher_code, new_rendezvous); } let timeout = tokio::time::delay_for (std::time::Duration::from_secs (30)); let received = tokio::select! { val = rx => val, () = timeout => { debug! ("Timed out request {}", req_id); return error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded") }, }; // UKAUFFY4 (Receive half) match received { Ok (Ok ((parts, body))) => { let mut resp = Response::builder () .status (hyper::StatusCode::from (parts.status_code)); for (k, v) in parts.headers.into_iter () { resp = resp.header (&k, v); } debug! ("Unparked request {}", req_id); resp.body (body) .unwrap () }, Ok (Err (RelayError::RelayShuttingDown)) => { error_reply (StatusCode::GATEWAY_TIMEOUT, "Relay shutting down") }, Err (_) => { debug! ("Responder sender dropped for request {}", req_id); error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out") }, } } #[derive (Debug, PartialEq)] enum LastSeen { Negative, Connected, Description (String), } // Mnemonic is "now - last_seen" fn pretty_print_last_seen ( now: DateTime , last_seen: DateTime ) -> LastSeen { use LastSeen::*; let dur = now.signed_duration_since (last_seen); if dur < chrono::Duration::zero () { return Negative; } if dur.num_minutes () < 1 { return Connected; } if dur.num_hours () < 1 { return Description (format! ("{} m ago", dur.num_minutes ())); } if dur.num_days () < 1 { return Description (format! ("{} h ago", dur.num_hours ())); } Description (last_seen.to_rfc3339_opts (SecondsFormat::Secs, true)) } async fn handle_server_list ( state: Arc ) -> Response { use std::borrow::Cow; #[derive (Serialize)] struct ServerEntry <'a> { path: String, name: String, last_seen: Cow <'a, str>, } #[derive (Serialize)] struct ServerListPage <'a> { servers: Vec >, } let servers = { let guard = state.server_status.lock ().await; (*guard).clone () }; let now = Utc::now (); let mut servers: Vec <_> = servers.into_iter () .map (|(name, server)| { let display_name = percent_encoding::percent_decode_str (&name).decode_utf8 ().unwrap_or_else (|_| "Server name isn't UTF-8".into ()).to_string (); use LastSeen::*; let last_seen = match pretty_print_last_seen (now, server.last_seen) { Negative => "Error (negative time)".into (), Connected => "Connected".into (), Description (s) => s.into (), }; ServerEntry { name: display_name, path: name, last_seen: last_seen, } }) .collect (); servers.sort_by (|a, b| a.name.cmp (&b.name)); let page = ServerListPage { servers, }; let s = state.handlebars.render ("relay_server_list", &page).unwrap (); ok_reply (s) } #[instrument (level = "trace", skip (req, state))] async fn handle_all (req: Request , state: Arc ) -> Result , Infallible> { let path = req.uri ().path (); //println! ("{}", path); debug! ("Request path: {}", path); let api_key = req.headers ().get ("X-ApiKey"); if req.method () == Method::POST { // This is stuff the server can use. Clients can't // POST right now return Ok (if let Some (request_code) = prefix_match ("/7ZSFUKGV/http_response/", path) { let request_code = request_code.into (); handle_http_response (req, state, request_code).await } else { error_reply (StatusCode::BAD_REQUEST, "Can't POST this") }); } Ok (if let Some (listen_code) = prefix_match ("/7ZSFUKGV/http_listen/", path) { let api_key = match api_key { None => return Ok (error_reply (StatusCode::UNAUTHORIZED, "Can't register as server without an API key")), Some (x) => x, }; handle_http_listen (state, listen_code.into (), api_key.as_bytes ()).await } else if let Some (rest) = prefix_match ("/frontend/servers/", path) { if rest == "" { handle_server_list (state).await } else if let Some (idx) = rest.find ('/') { let listen_code = String::from (&rest [0..idx]); let path = String::from (&rest [idx..]); let (parts, _) = req.into_parts (); handle_http_request (parts, path, state, listen_code).await } else { error_reply (StatusCode::BAD_REQUEST, "Bad URI format") } } else if path == "/" { let s = state.handlebars.render ("relay_root", &()).unwrap (); ok_reply (s) } else if path == "/frontend/relay_up_check" { error_reply (StatusCode::OK, "Relay is up") } else { error_reply (StatusCode::OK, "Hi") }) } use std::path::{Path, PathBuf}; pub fn load_templates (asset_root: &Path) -> Result , Box > { let mut handlebars = Handlebars::new (); handlebars.set_strict_mode (true); let asset_root = asset_root.join ("handlebars/relay"); for (k, v) in vec! [ ("relay_server_list", "relay_server_list.html"), ("relay_root", "relay_root.html"), ].into_iter () { handlebars.register_template_file (k, &asset_root.join (v))?; } Ok (handlebars) } async fn reload_config ( state: &Arc , config_reload_path: &Path ) -> Option <()> { use tokio::prelude::*; let mut f = tokio::fs::File::open (config_reload_path).await.ok ()?; let mut buffer = vec! [0u8; 4096]; let bytes_read = f.read (&mut buffer).await.ok ()?; buffer.truncate (bytes_read); let config_s = String::from_utf8 (buffer).ok ()?; let new_config: ConfigFile = toml::from_str (&config_s).ok ()?; let new_config = Config::from (&new_config); let mut config = state.config.write ().await; (*config) = new_config; debug! ("Loaded {} server tripcodes", config.server_tripcodes.len ()); Some (()) } pub async fn run_relay ( state: Arc , shutdown_oneshot: oneshot::Receiver <()>, config_reload_path: Option ) -> Result <(), Box > { let addr = SocketAddr::from (( [0, 0, 0, 0], 4000, )); if let Some (config_reload_path) = config_reload_path { let state_2 = state.clone (); tokio::spawn (async move { let mut reload_interval = tokio::time::interval (Duration::from_secs (60)); loop { reload_interval.tick ().await; reload_config (&state_2, &config_reload_path).await; } }); } let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { Ok::<_, Infallible> (service_fn (move |req| { let state = state.clone (); handle_all (req, state) })) } }); let server = Server::bind (&addr) .serve (make_svc); server.with_graceful_shutdown (async { shutdown_oneshot.await.ok (); state.shutdown_watch_tx.broadcast (true).unwrap (); use RelayError::*; let mut response_rendezvous = state.response_rendezvous.write ().await; let mut swapped = DashMap::default (); std::mem::swap (&mut swapped, &mut response_rendezvous); for (_, sender) in swapped.into_iter () { sender.send (Err (RelayShuttingDown)).ok (); } let mut request_rendezvous = state.request_rendezvous.lock ().await; for (_, x) in request_rendezvous.drain () { use RequestRendezvous::*; match x { ParkedClients (_) => (), ParkedServer (sender) => drop (sender.send (Err (RelayShuttingDown))), } } debug! ("Performed all cleanup"); }).await?; Ok (()) } #[cfg (test)] mod tests { use super::*; #[test] fn test_pretty_print_last_seen () { use LastSeen::*; let last_seen = DateTime::parse_from_rfc3339 ("2019-05-29T00:00:00+00:00").unwrap ().with_timezone (&Utc); for (input, expected) in vec! [ ("2019-05-28T23:59:59+00:00", Negative), ("2019-05-29T00:00:00+00:00", Connected), ("2019-05-29T00:00:59+00:00", Connected), ("2019-05-29T00:01:30+00:00", Description ("1 m ago".into ())), ("2019-05-29T00:59:30+00:00", Description ("59 m ago".into ())), ("2019-05-29T01:00:30+00:00", Description ("1 h ago".into ())), ("2019-05-29T10:00:00+00:00", Description ("10 h ago".into ())), ("2019-05-30T00:00:00+00:00", Description ("2019-05-29T00:00:00Z".into ())), ("2019-05-30T10:00:00+00:00", Description ("2019-05-29T00:00:00Z".into ())), ("2019-05-31T00:00:00+00:00", Description ("2019-05-29T00:00:00Z".into ())), ].into_iter () { let now = DateTime::parse_from_rfc3339 (input).unwrap ().with_timezone (&Utc); let actual = pretty_print_last_seen (now, last_seen); assert_eq! (actual, expected); } } }