use std::{ error::Error, collections::*, convert::Infallible, iter::FromIterator, net::SocketAddr, sync::{ Arc, }, time::Duration, }; use dashmap::DashMap; use futures::{ FutureExt, stream::StreamExt, }; use handlebars::Handlebars; use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; use hyper::service::{make_service_fn, service_fn}; use serde::{ Deserialize, Serialize, }; use tokio::{ spawn, sync::{ Mutex, mpsc, oneshot, RwLock, watch, }, time::delay_for, }; use tracing::{ debug, error, info, trace, instrument, }; use crate::{ http_serde, prefix_match, }; /* Here's what we need to handle: When a request comes in: - Park the client in response_rendezvous - Look up the server ID in request_rendezvous - If a server is parked, unpark it and send the request - Otherwise, queue the request When a server comes to listen: - Look up the server ID in request_rendezvous - Either return all pending requests, or park the server When a server comes to respond: - Look up the parked client in response_rendezvous - Unpark the client and begin streaming So we need these lookups to be fast: - Server IDs, where (1 server) or (0 or many clients) can be parked - Request IDs, where 1 client is parked */ #[derive (Debug)] enum RelayError { RelayShuttingDown, } enum RequestRendezvous { ParkedClients (Vec ), ParkedServer (oneshot::Sender >), } type ResponseRendezvous = oneshot::Sender >; // Stuff we need to load from the config file and use to // set up the HTTP server #[derive (Default, Deserialize)] pub struct ConfigFile { pub port: Option , pub server_tripcodes: HashMap , } // Stuff we actually need at runtime struct Config { server_tripcodes: HashMap , } impl From <&ConfigFile> for Config { fn from (f: &ConfigFile) -> Self { let server_tripcodes = HashMap::from_iter (f.server_tripcodes.iter () .map (|(k, v)| { use std::convert::TryInto; let bytes: Vec = base64::decode (v).unwrap (); let bytes: [u8; 32] = (&bytes [..]).try_into ().unwrap (); let v = blake3::Hash::from (bytes); debug! ("Tripcode {} => {}", k, v.to_hex ()); (k.clone (), v) })); Self { server_tripcodes, } } } pub struct RelayState { config: Config, handlebars: Arc >, // Key: Server ID request_rendezvous: Mutex >, // Key: Request ID response_rendezvous: RwLock >, shutdown_watch_tx: watch::Sender , shutdown_watch_rx: watch::Receiver , } impl From <&ConfigFile> for RelayState { fn from (config_file: &ConfigFile) -> Self { let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false); Self { config: Config::from (config_file), handlebars: Arc::new (load_templates ().unwrap ()), request_rendezvous: Default::default (), response_rendezvous: Default::default (), shutdown_watch_tx, shutdown_watch_rx, } } } impl RelayState { pub async fn list_servers (&self) -> Vec { self.request_rendezvous.lock ().await.iter () .map (|(k, _)| (*k).clone ()) .collect () } } fn ok_reply > (b: B) -> Response { Response::builder ().status (StatusCode::OK).body (b.into ()).unwrap () } fn error_reply (status: StatusCode, b: &str) -> Response { Response::builder () .status (status) .header ("content-type", "text/plain") .body (format! ("{}\n", b).into ()).unwrap () } // Servers will come here and either handle queued requests from parked clients, // or park themselves until a request comes in. async fn handle_http_listen ( state: Arc , watcher_code: String, api_key: &[u8], ) -> Response { let trip_error = error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey"); let expected_tripcode = match state.config.server_tripcodes.get (&watcher_code) { None => { error! ("Denied http_listen for non-existent server name {}", watcher_code); return trip_error; }, Some (x) => x, }; let actual_tripcode = blake3::hash (api_key); if expected_tripcode != &actual_tripcode { error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); return trip_error; } use RequestRendezvous::*; let (tx, rx) = oneshot::channel (); { let mut request_rendezvous = state.request_rendezvous.lock ().await; if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code) { if ! v.is_empty () { // 1 or more clients were parked - Make the server // handle them immediately debug! ("Sending {} parked requests to server {}", v.len (), watcher_code); return ok_reply (rmp_serde::to_vec (&v).unwrap ()); } } debug! ("Parking server {}", watcher_code); request_rendezvous.insert (watcher_code.clone (), ParkedServer (tx)); } // No clients were parked - make the server long-poll futures::select! { x = rx.fuse () => match x { Ok (Ok (one_req)) => { debug! ("Unparking server {}", watcher_code); ok_reply (rmp_serde::to_vec (&vec! [one_req]).unwrap ()) }, Ok (Err (RelayError::RelayShuttingDown)) => error_reply (StatusCode::SERVICE_UNAVAILABLE, "Server is shutting down, try again soon"), Err (_) => error_reply (StatusCode::INTERNAL_SERVER_ERROR, "Server error"), }, _ = delay_for (Duration::from_secs (30)).fuse () => { debug! ("Timed out http_listen for server {}", watcher_code); return error_reply (StatusCode::NO_CONTENT, "No requests now, long-poll again") } } } // Servers will come here to stream responses to clients async fn handle_http_response ( req: Request , state: Arc , req_id: String, ) -> Response { let (parts, mut body) = req.into_parts (); let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap (); // Intercept the body packets here so we can check when the stream // ends or errors out #[derive (Debug)] enum BodyFinishedReason { StreamFinished, ClientDisconnected, } use BodyFinishedReason::*; let (mut body_tx, body_rx) = mpsc::channel (2); let (body_finished_tx, body_finished_rx) = oneshot::channel (); let mut shutdown_watch_rx = state.shutdown_watch_rx.clone (); spawn (async move { if shutdown_watch_rx.recv ().await == Some (false) { loop { let item = body.next ().await; if let Some (item) = item { if let Ok (bytes) = &item { trace! ("Relaying {} bytes", bytes.len ()); } futures::select! { x = body_tx.send (item).fuse () => if let Err (_) = x { info! ("Body closed while relaying. (Client hung up?)"); body_finished_tx.send (ClientDisconnected).unwrap (); break; }, _ = shutdown_watch_rx.recv ().fuse () => { debug! ("Closing stream: relay is shutting down"); break; }, } } else { debug! ("Finished relaying bytes"); body_finished_tx.send (StreamFinished).unwrap (); break; } } } else { debug! ("Can't relay bytes, relay is shutting down"); } }); let body = Body::wrap_stream (body_rx); let tx = { let response_rendezvous = state.response_rendezvous.read ().await; match response_rendezvous.remove (&req_id) { None => { error! ("Server tried to respond to non-existent request"); return error_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous"); }, Some ((_, x)) => x, } }; // UKAUFFY4 (Send half) if tx.send (Ok ((resp_parts, body))).is_err () { let msg = "Failed to connect to client"; error! (msg); return error_reply (StatusCode::BAD_GATEWAY, msg); } debug! ("Connected server to client for streaming."); match body_finished_rx.await { Ok (StreamFinished) => { error_reply (StatusCode::OK, "StreamFinished") }, Ok (ClientDisconnected) => { error_reply (StatusCode::OK, "ClientDisconnected") }, Err (e) => { debug! ("body_finished_rx {}", e); error_reply (StatusCode::OK, "body_finished_rx Err") }, } } // Clients will come here to start requests, and always park for at least // a short amount of time. async fn handle_http_request ( req: http::request::Parts, uri: String, state: Arc , watcher_code: String ) -> Response { if ! state.config.server_tripcodes.contains_key (&watcher_code) { return error_reply (StatusCode::NOT_FOUND, "Unknown server"); } let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { Ok (x) => x, _ => return error_reply (StatusCode::BAD_REQUEST, "Bad request"), }; let (tx, rx) = oneshot::channel (); let req_id = ulid::Ulid::new ().to_string (); { let response_rendezvous = state.response_rendezvous.read ().await; response_rendezvous.insert (req_id.clone (), tx); } trace! ("Created request {}", req_id); { let mut request_rendezvous = state.request_rendezvous.lock ().await; let wrapped = http_serde::WrappedRequest { id: req_id.clone (), req, }; use RequestRendezvous::*; let new_rendezvous = match request_rendezvous.remove (&watcher_code) { Some (ParkedClients (mut v)) => { debug! ("Parking request {} ({} already queued)", req_id, v.len ()); v.push (wrapped); ParkedClients (v) }, Some (ParkedServer (s)) => { // If sending to the server fails, queue it match s.send (Ok (wrapped)) { Ok (()) => { // TODO: This can actually still fail, if the server // disconnects right as we're sending this. // Then what? debug! ( "Sending request {} directly to server {}", req_id, watcher_code, ); ParkedClients (vec! []) }, Err (Ok (wrapped)) => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, Err (_) => unreachable! (), } }, None => { debug! ("Parking request {}", req_id); ParkedClients (vec! [wrapped]) }, }; request_rendezvous.insert (watcher_code, new_rendezvous); } let timeout = tokio::time::delay_for (std::time::Duration::from_secs (30)); let received = tokio::select! { val = rx => val, () = timeout => { debug! ("Timed out request {}", req_id); return error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded") }, }; // UKAUFFY4 (Receive half) match received { Ok (Ok ((parts, body))) => { let mut resp = Response::builder () .status (hyper::StatusCode::from (parts.status_code)); for (k, v) in parts.headers.into_iter () { resp = resp.header (&k, v); } debug! ("Unparked request {}", req_id); resp.body (body) .unwrap () }, Ok (Err (RelayError::RelayShuttingDown)) => { error_reply (StatusCode::GATEWAY_TIMEOUT, "Relay shutting down") }, Err (_) => { debug! ("Responder sender dropped for request {}", req_id); error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out") }, } } #[instrument (level = "trace", skip (req, state))] async fn handle_all (req: Request , state: Arc ) -> Result , Infallible> { let path = req.uri ().path (); //println! ("{}", path); debug! ("Request path: {}", path); let api_key = req.headers ().get ("X-ApiKey"); if req.method () == Method::POST { // This is stuff the server can use. Clients can't // POST right now return Ok (if let Some (request_code) = prefix_match ("/7ZSFUKGV/http_response/", path) { let request_code = request_code.into (); handle_http_response (req, state, request_code).await } else { error_reply (StatusCode::BAD_REQUEST, "Can't POST this") }); } Ok (if let Some (listen_code) = prefix_match ("/7ZSFUKGV/http_listen/", path) { let api_key = match api_key { None => return Ok (error_reply (StatusCode::UNAUTHORIZED, "Can't register as server without an API key")), Some (x) => x, }; handle_http_listen (state, listen_code.into (), api_key.as_bytes ()).await } else if let Some (rest) = prefix_match ("/frontend/servers/", path) { if rest == "" { use std::borrow::Cow; #[derive (Serialize)] struct ServerEntry <'a> { path: &'a str, name: Cow <'a, str>, } #[derive (Serialize)] struct ServerListPage <'a> { servers: Vec >, } let names = state.list_servers ().await; //println! ("Found {} servers", names.len ()); let page = ServerListPage { servers: names.iter () .map (|name| ServerEntry { name: percent_encoding::percent_decode_str (name).decode_utf8 ().unwrap_or_else (|_| "Server name isn't UTF-8".into ()), path: &name, }) .collect (), }; let s = state.handlebars.render ("relay_server_list", &page).unwrap (); ok_reply (s) } else if let Some (idx) = rest.find ('/') { let listen_code = String::from (&rest [0..idx]); let path = String::from (&rest [idx..]); let (parts, _) = req.into_parts (); handle_http_request (parts, path, state, listen_code).await } else { error_reply (StatusCode::BAD_REQUEST, "Bad URI format") } } else if path == "/" { let s = state.handlebars.render ("relay_root", &()).unwrap (); ok_reply (s) } else if path == "/frontend/relay_up_check" { error_reply (StatusCode::OK, "Relay is up") } else { error_reply (StatusCode::OK, "Hi") }) } pub fn load_templates () -> Result , Box > { let mut handlebars = Handlebars::new (); handlebars.set_strict_mode (true); for (k, v) in vec! [ ("relay_server_list", "relay_server_list.html"), ("relay_root", "relay_root.html"), ].into_iter () { handlebars.register_template_file (k, format! ("ptth_handlebars/{}", v))?; } Ok (handlebars) } pub async fn run_relay ( state: Arc , shutdown_oneshot: oneshot::Receiver <()> ) -> Result <(), Box > { let addr = SocketAddr::from (( [0, 0, 0, 0], 4000, )); { let mut tripcode_set = HashSet::new (); for (_, v) in state.config.server_tripcodes.iter () { if ! tripcode_set.insert (v) { panic! ("Two servers have the same tripcode. That is not allowed."); } } } info! ("Loaded {} server tripcodes", state.config.server_tripcodes.len ()); let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { Ok::<_, Infallible> (service_fn (move |req| { let state = state.clone (); handle_all (req, state) })) } }); let server = Server::bind (&addr) .serve (make_svc); server.with_graceful_shutdown (async { shutdown_oneshot.await.ok (); state.shutdown_watch_tx.broadcast (true).unwrap (); use RelayError::*; let mut response_rendezvous = state.response_rendezvous.write ().await; let mut swapped = DashMap::default (); std::mem::swap (&mut swapped, &mut response_rendezvous); for (_, sender) in swapped.into_iter () { sender.send (Err (RelayShuttingDown)).ok (); } let mut request_rendezvous = state.request_rendezvous.lock ().await; for (_, x) in request_rendezvous.drain () { use RequestRendezvous::*; match x { ParkedClients (_) => (), ParkedServer (sender) => drop (sender.send (Err (RelayShuttingDown))), } } debug! ("Performed all cleanup"); }).await?; Ok (()) } #[cfg (test)] mod tests { }