use std::{ error::Error, collections::*, convert::Infallible, iter::FromIterator, net::SocketAddr, sync::{ Arc }, }; use dashmap::DashMap; use futures::stream::StreamExt; use handlebars::Handlebars; use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; use hyper::service::{make_service_fn, service_fn}; use serde::{ Deserialize, Serialize, }; use tokio::{ spawn, sync::{ Mutex, mpsc, oneshot, }, }; use tracing::{debug, error, info, trace, warn}; use crate::{ http_serde, prefix_match, }; /* Here's what we need to handle: When a request comes in: - Park the client in response_rendezvous - Look up the server ID in request_rendezvous - If a server is parked, unpark it and send the request - Otherwise, queue the request When a server comes to listen: - Look up the server ID in request_rendezvous - Either return all pending requests, or park the server When a server comes to respond: - Look up the parked client in response_rendezvous - Unpark the client and begin streaming So we need these lookups to be fast: - Server IDs, where (1 server) or (0 or many clients) can be parked - Request IDs, where 1 client is parked */ enum RequestRendezvous { ParkedClients (Vec ), ParkedServer (oneshot::Sender ), } type ResponseRendezvous = oneshot::Sender <(http_serde::ResponseParts, Body)>; // Stuff we need to load from the config file and use to // set up the HTTP server #[derive (Default, Deserialize)] pub struct ConfigFile { pub port: Option , pub server_tripcodes: HashMap , } // Stuff we actually need at runtime struct Config { server_tripcodes: HashMap , } impl From <&ConfigFile> for Config { fn from (f: &ConfigFile) -> Self { let trips = HashMap::from_iter (f.server_tripcodes.iter () .map (|(k, v)| { use std::convert::TryInto; let bytes: Vec = base64::decode (v).unwrap (); let bytes: [u8; 32] = (&bytes [..]).try_into ().unwrap (); let v = blake3::Hash::from (bytes); (k.clone (), v) })); Self { server_tripcodes: trips, } } } pub struct RelayState { config: Config, handlebars: Arc >, // Key: Server ID request_rendezvous: Mutex >, // Key: Request ID response_rendezvous: DashMap , } impl Default for RelayState { fn default () -> Self { Self { config: Config::from (&ConfigFile::default ()), handlebars: Arc::new (load_templates ().unwrap ()), request_rendezvous: Default::default (), response_rendezvous: Default::default (), } } } impl From <&ConfigFile> for RelayState { fn from (config_file: &ConfigFile) -> Self { Self { config: Config::from (config_file), handlebars: Arc::new (load_templates ().unwrap ()), request_rendezvous: Default::default (), response_rendezvous: Default::default (), } } } impl RelayState { pub async fn list_servers (&self) -> Vec { self.request_rendezvous.lock ().await.iter () .map (|(k, _)| (*k).clone ()) .collect () } } fn status_reply > (status: StatusCode, b: B) -> Response { Response::builder ().status (status).body (b.into ()).unwrap () } // Servers will come here and either handle queued requests from parked clients, // or park themselves until a request comes in. async fn handle_http_listen ( state: Arc , watcher_code: String, api_key: &[u8], ) -> Response { let trip_error = status_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey"); let expected_tripcode = match state.config.server_tripcodes.get (&watcher_code) { None => { error! ("Denied http_listen for non-existent server name {}", watcher_code); return trip_error; }, Some (x) => x, }; let actual_tripcode = blake3::hash (api_key); if expected_tripcode != &actual_tripcode { error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); return trip_error; } use RequestRendezvous::*; let (tx, rx) = oneshot::channel (); { let mut request_rendezvous = state.request_rendezvous.lock ().await; if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code) { // 1 or more clients were parked - Make the server // handle them immediately return status_reply (StatusCode::OK, rmp_serde::to_vec (&v).unwrap ()); } request_rendezvous.insert (watcher_code, ParkedServer (tx)); } // No clients were parked - make the server long-poll let one_req = match rx.await { Ok (r) => r, Err (_) => return status_reply (StatusCode::SERVICE_UNAVAILABLE, "Server is shutting down, try again soon"), }; status_reply (StatusCode::OK, rmp_serde::to_vec (&vec! [one_req]).unwrap ()) } // Servers will come here to stream responses to clients async fn handle_http_response ( req: Request , state: Arc , req_id: String, ) -> Response { let (parts, mut body) = req.into_parts (); let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap (); // Intercept the body packets here so we can check when the stream // ends or errors out let (mut body_tx, body_rx) = mpsc::channel (2); spawn (async move { loop { let item = body.next ().await; if let Some (item) = item { if let Ok (bytes) = &item { trace! ("Relaying {} bytes", bytes.len ()); } if body_tx.send (item).await.is_err () { error! ("Error relaying bytes"); break; } } else { debug! ("Finished relaying bytes"); break; } } }); let body = Body::wrap_stream (body_rx); match state.response_rendezvous.remove (&req_id) { Some ((_, tx)) => { // UKAUFFY4 (Send half) match tx.send ((resp_parts, body)) { Ok (()) => { debug! ("Responding to server"); status_reply (StatusCode::OK, "http_response completed.") }, _ => { let msg = "Failed to connect to client"; error! (msg); status_reply (StatusCode::BAD_GATEWAY, msg) }, } }, None => { error! ("Server tried to respond to non-existent request"); status_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous") }, } } // Clients will come here to start requests, and always park for at least // a short amount of time. async fn handle_http_request ( req: http::request::Parts, uri: String, state: Arc , watcher_code: String ) -> Response { if ! state.config.server_tripcodes.contains_key (&watcher_code) { return status_reply (StatusCode::NOT_FOUND, "Unknown server"); } let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { Ok (x) => x, _ => return status_reply (StatusCode::BAD_REQUEST, "Bad request"), }; let (tx, rx) = oneshot::channel (); let id = ulid::Ulid::new ().to_string (); state.response_rendezvous.insert (id.clone (), tx); { let mut request_rendezvous = state.request_rendezvous.lock ().await; let wrapped = http_serde::WrappedRequest { id, req, }; use RequestRendezvous::*; let new_rendezvous = match request_rendezvous.remove (&watcher_code) { Some (ParkedClients (mut v)) => { v.push (wrapped); ParkedClients (v) }, Some (ParkedServer (s)) => { // If sending to the server fails, queue it match s.send (wrapped) { Ok (()) => ParkedClients (vec! []), Err (wrapped) => ParkedClients (vec! [wrapped]), } }, None => ParkedClients (vec! [wrapped]), }; request_rendezvous.insert (watcher_code, new_rendezvous); } let timeout = tokio::time::delay_for (std::time::Duration::from_secs (30)); let received = tokio::select! { val = rx => val, () = timeout => { return status_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded") }, }; // UKAUFFY4 (Receive half) match received { Ok ((parts, body)) => { let mut resp = Response::builder () .status (hyper::StatusCode::from (parts.status_code)); for (k, v) in parts.headers.into_iter () { resp = resp.header (&k, v); } resp.body (body) .unwrap () }, _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out"), } } async fn handle_all (req: Request , state: Arc ) -> Result , Infallible> { let path = req.uri ().path (); //println! ("{}", path); let api_key = req.headers ().get ("X-ApiKey"); if req.method () == Method::POST { // This is stuff the server can use. Clients can't // POST right now return Ok (if let Some (request_code) = prefix_match (path, "/7ZSFUKGV/http_response/") { let request_code = request_code.into (); handle_http_response (req, state, request_code).await } else { status_reply (StatusCode::BAD_REQUEST, "Can't POST this\n") }); } Ok (if let Some (listen_code) = prefix_match (path, "/7ZSFUKGV/http_listen/") { let api_key = match api_key { None => return Ok (status_reply (StatusCode::UNAUTHORIZED, "Can't register as server without an API key")), Some (x) => x, }; handle_http_listen (state, listen_code.into (), api_key.as_bytes ()).await } else if let Some (rest) = prefix_match (path, "/frontend/servers/") { if rest == "" { #[derive (Serialize)] struct ServerEntry <'a> { path: &'a str, name: &'a str, } #[derive (Serialize)] struct ServerListPage <'a> { servers: Vec >, } let names = state.list_servers ().await; //println! ("Found {} servers", names.len ()); let page = ServerListPage { servers: names.iter () .map (|name| ServerEntry { name: &name, path: &name, }) .collect (), }; let s = state.handlebars.render ("relay_server_list", &page).unwrap (); status_reply (StatusCode::OK, s) } else if let Some (idx) = rest.find ('/') { let listen_code = String::from (&rest [0..idx]); let path = String::from (&rest [idx..]); let (parts, _) = req.into_parts (); handle_http_request (parts, path, state, listen_code).await } else { status_reply (StatusCode::BAD_REQUEST, "Bad URI format") } } else if path == "/frontend/relay_up_check" { status_reply (StatusCode::OK, "Relay is up\n") } else { status_reply (StatusCode::OK, "Hi\n") }) } pub fn load_templates () -> Result , Box > { let mut handlebars = Handlebars::new (); handlebars.set_strict_mode (true); for (k, v) in vec! [ ("relay_server_list", "relay_server_list.html"), ].into_iter () { handlebars.register_template_file (k, format! ("ptth_handlebars/{}", v))?; } Ok (handlebars) } pub async fn run_relay ( state: Arc , shutdown_oneshot: oneshot::Receiver <()> ) -> Result <(), Box > { let addr = SocketAddr::from (( [0, 0, 0, 0], 4000, )); { let mut tripcode_set = HashSet::new (); for (_, v) in state.config.server_tripcodes.iter () { if ! tripcode_set.insert (v) { panic! ("Two servers have the same tripcode. That is not allowed."); } } } info! ("Loaded {} server tripcodes", state.config.server_tripcodes.len ()); let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { Ok::<_, Infallible> (service_fn (move |req| { let state = state.clone (); handle_all (req, state) })) } }); let server = Server::bind (&addr) .serve (make_svc); server.with_graceful_shutdown (async { shutdown_oneshot.await.ok (); state.response_rendezvous.clear (); let mut request_rendezvoux = state.request_rendezvous.lock ().await; request_rendezvoux.clear (); info! ("Received graceful shutdown"); }).await?; info! ("Exiting"); Ok (()) } #[cfg (test)] mod tests { }