From e7edf842828a02b8d4940069f9cdb37ef4699890 Mon Sep 17 00:00:00 2001 From: _ <_@_> Date: Sun, 1 Nov 2020 20:07:46 -0600 Subject: [PATCH] :bug: Fix rendezvous problem. Now clients can queue up for a server, which fixes a few things: - A server can receive multiple requests at once, reducing roundtrip count in theory - Clients can wait up to 30 seconds on the relay before the server is ready for them - If the server has just left to service a request, the client will queue instead of seeing the server as absent and giving up --- Cargo.toml | 1 + src/lib.rs | 24 +++- src/relay/mod.rs | 288 ++++++++++++++++++++++++---------------------- src/server/mod.rs | 46 +++++--- todo.md | 1 - 5 files changed, 200 insertions(+), 160 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 767a950..2f8286c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ license = "AGPL-3.0" base64 = "0.12.3" blake3 = "0.3.7" +dashmap = "3.11.10" futures = "0.3.7" handlebars = "3.5.1" http = "0.2.1" diff --git a/src/lib.rs b/src/lib.rs index 8b05ca2..c154d07 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,10 @@ pub mod server; #[cfg (test)] mod tests { + use std::{ + sync::Arc, + }; + use tokio::{ runtime::Runtime, spawn, @@ -35,12 +39,16 @@ mod tests { rt.block_on (async { let relay_url = "http://127.0.0.1:4000"; - spawn (async { - relay::main ().await.unwrap (); + let relay_state = Arc::new (relay::RelayState::default ()); + + let relay_state_2 = relay_state.clone (); + spawn (async move { + relay::run_relay (relay_state_2).await.unwrap (); }); - let relay_url_2 = relay_url.into (); + assert! (relay_state.list_servers ().await.is_empty ()); + let relay_url_2 = relay_url.into (); let server_name = "alien_wildlands"; let server_name_2 = server_name.into (); @@ -54,9 +62,15 @@ mod tests { server::main (opt).await.unwrap (); }); - tokio::time::delay_for (std::time::Duration::from_secs (1)).await; + tokio::time::delay_for (std::time::Duration::from_millis (500)).await; - let client = Client::new (); + assert_eq! (relay_state.list_servers ().await, vec! [ + server_name.to_string (), + ]); + + let client = Client::builder () + .timeout (std::time::Duration::from_secs (2)) + .build ().unwrap (); let resp = client.get (&format! ("{}/relay_up_check", relay_url)) .send ().await.unwrap ().bytes ().await.unwrap (); diff --git a/src/relay/mod.rs b/src/relay/mod.rs index 1de526b..d3dbb05 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -2,14 +2,15 @@ pub mod watcher; use std::{ error::Error, + collections::*, convert::Infallible, net::SocketAddr, sync::{ Arc }, - time::{Duration}, }; +use dashmap::DashMap; use futures::channel::oneshot; use handlebars::Handlebars; use hyper::{ @@ -24,27 +25,74 @@ use hyper::service::{make_service_fn, service_fn}; use serde::Serialize; use tokio::{ sync::Mutex, - time::delay_for, }; use crate::{ http_serde, }; -use watcher::*; -#[derive (Default)] -struct ServerState { +/* + +Here's what we need to handle: + +When a request comes in: + +- Park the client in response_rendezvous +- Look up the server ID in request_rendezvous +- If a server is parked, unpark it and send the request +- Otherwise, queue the request + +When a server comes to listen: + +- Look up the server ID in request_rendezvous +- Either return all pending requests, or park the server + +When a server comes to respond: + +- Look up the parked client in response_rendezvous +- Unpark the client and begin streaming + +So we need these lookups to be fast: + +- Server IDs, where (1 server) or (0 or many clients) +can be parked +- Request IDs, where 1 client is parked + +*/ + +enum RequestRendezvous { + ParkedClients (Vec ), + ParkedServer (oneshot::Sender ), +} + +type ResponseRendezvous = oneshot::Sender <(http_serde::ResponseParts, Body)>; + +pub struct RelayState { handlebars: Arc >, - // Holds clients that are waiting for a response to come - // back from a server. + // Key: Server ID + request_rendezvous: Mutex >, - client_watchers: Arc >>, - - // Holds servers that are waiting for a request to come in - // from a client. - - server_watchers: Arc >>, + // Key: Request ID + response_rendezvous: DashMap , +} + +impl Default for RelayState { + fn default () -> Self { + Self { + handlebars: Arc::new (load_templates ().unwrap ()), + request_rendezvous: Default::default (), + response_rendezvous: Default::default (), + } + } +} + +impl RelayState { + pub async fn list_servers (&self) -> Vec { + self.request_rendezvous.lock ().await.iter () + .map (|(k, _)| (*k).clone ()) + .collect () + } } fn status_reply > (status: StatusCode, b: B) @@ -53,113 +101,122 @@ fn status_reply > (status: StatusCode, b: B) Response::builder ().status (status).body (b.into ()).unwrap () } -async fn handle_http_listen (state: Arc , watcher_code: String) +async fn handle_http_listen (state: Arc , watcher_code: String) -> Response { - //println! ("Step 1"); - match Watchers::long_poll (state.server_watchers.clone (), watcher_code).await { - Some (parts) => { - //println! ("Step 3"); - status_reply (StatusCode::OK, rmp_serde::to_vec (&parts).unwrap ()) - }, - _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "no\n"), + use RequestRendezvous::*; + + let (tx, rx) = oneshot::channel (); + + { + let mut request_rendezvous = state.request_rendezvous.lock ().await; + + if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code) + { + return status_reply (StatusCode::OK, rmp_serde::to_vec (&v).unwrap ()); + } + + request_rendezvous.insert (watcher_code, ParkedServer (tx)); } + + let one_req = vec! [ + rx.await.unwrap (), + ]; + + return status_reply (StatusCode::OK, rmp_serde::to_vec (&one_req).unwrap ()); } async fn handle_http_response ( req: Request , - state: Arc , + state: Arc , req_id: String, ) -> Response { - //println! ("Step 6"); let (parts, body) = req.into_parts (); let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap (); - { - let mut watchers = state.client_watchers.lock ().await; - - //println! ("Step 7"); - if ! watchers.wake_one ((resp_parts, body), &req_id) - { - println! ("Step 8 (bad thing)"); - status_reply (StatusCode::BAD_REQUEST, "A bad thing happened.\n") - } - else { - //println! ("Step 8"); - status_reply (StatusCode::OK, "ok\n") - } + match state.response_rendezvous.remove (&req_id) { + Some ((_, tx)) => { + match tx.send ((resp_parts, body)) { + Ok (()) => status_reply (StatusCode::OK, "Connected to remote client...\n"), + _ => status_reply (StatusCode::BAD_GATEWAY, "Failed to connect to client"), + } + + }, + None => status_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous"), } } async fn handle_http_request ( req: http::request::Parts, uri: String, - state: Arc , + state: Arc , watcher_code: String ) -> Response { - let parts = { - let id = ulid::Ulid::new ().to_string (); - let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { - Ok (x) => x, - _ => return status_reply (StatusCode::BAD_REQUEST, "Bad request"), - }; - - http_serde::WrappedRequest { - id, - req, - } + let id = ulid::Ulid::new ().to_string (); + let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { + Ok (x) => x, + _ => return status_reply (StatusCode::BAD_REQUEST, "Bad request"), }; - //println! ("Step 2 {}", parts.id); + let (tx, rx) = oneshot::channel (); - let (s, r) = oneshot::channel (); - let timeout = Duration::from_secs (5); + state.response_rendezvous.insert (id.clone (), tx); - let id_2 = parts.id.clone (); { - let mut that = state.client_watchers.lock ().await; - that.add_watcher_with_id (s, id_2) + let mut request_rendezvous = state.request_rendezvous.lock ().await; + + let wrapped = http_serde::WrappedRequest { + id, + req, + }; + + use RequestRendezvous::*; + + let new_rendezvous = match request_rendezvous.remove (&watcher_code) { + Some (ParkedClients (mut v)) => { + v.push (wrapped); + ParkedClients (v) + }, + Some (ParkedServer (s)) => { + // If sending to the server fails, queue it + + match s.send (wrapped) { + Ok (()) => ParkedClients (vec! []), + Err (wrapped) => ParkedClients (vec! [wrapped]), + } + }, + None => ParkedClients (vec! [wrapped]), + }; + + request_rendezvous.insert (watcher_code, new_rendezvous); } - let req_id = parts.id.clone (); + let timeout = tokio::time::delay_for (std::time::Duration::from_secs (30)); - tokio::spawn (async move { - { - let mut watchers = state.server_watchers.lock ().await; - - //println! ("Step 3"); - if ! watchers.wake_one (parts, &watcher_code) { - watchers.remove_watcher (&req_id); - } - } - - delay_for (timeout).await; - { - let mut that = state.client_watchers.lock ().await; - that.remove_watcher (&req_id); - } - }); + let received = tokio::select! { + val = rx => val, + () = timeout => { + return status_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded") + }, + }; - match r.await { - Ok ((resp_parts, body)) => { - //println! ("Step 7"); - + match received { + Ok ((parts, body)) => { let mut resp = Response::builder () - .status (hyper::StatusCode::from (resp_parts.status_code)); + .status (hyper::StatusCode::from (parts.status_code)); - for (k, v) in resp_parts.headers.into_iter () { + for (k, v) in parts.headers.into_iter () { resp = resp.header (&k, v); } - resp - .body (body) + resp.body (body) .unwrap () }, - _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "server didn't reply in time or somethin'"), + _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out"), } } @@ -173,7 +230,7 @@ fn prefix_match <'a> (hay: &'a str, needle: &str) -> Option <&'a str> } } -async fn handle_all (req: Request , state: Arc ) +async fn handle_all (req: Request , state: Arc ) -> Result , Infallible> { let path = req.uri ().path (); @@ -208,11 +265,7 @@ async fn handle_all (req: Request , state: Arc ) servers: Vec >, } - let names: Vec <_> = { - state.server_watchers.lock ().await.senders.iter () - .map (|(k, _)| (*k).clone ()) - .collect () - }; + let names = state.list_servers ().await; //println! ("Found {} servers", names.len ()); @@ -262,17 +315,10 @@ pub fn load_templates () Ok (handlebars) } -pub async fn main () -> Result <(), Box > { +pub async fn run_relay (state: Arc ) -> Result <(), Box > +{ let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); - let state = ServerState { - handlebars: Arc::new (load_templates ()?), - server_watchers: Default::default (), - client_watchers: Default::default (), - }; - - let state = Arc::new (state); - let make_svc = make_service_fn (|_conn| { let state = state.clone (); @@ -292,47 +338,19 @@ pub async fn main () -> Result <(), Box > { Ok (()) } +pub async fn main () -> Result <(), Box > { + let state = RelayState { + handlebars: Arc::new (load_templates ()?), + request_rendezvous: Default::default (), + response_rendezvous: Default::default (), + }; + + let state = Arc::new (state); + + run_relay (state).await +} + #[cfg (test)] mod tests { - // Toy model of a relay for a single server - // with one consumer thread. - // To scale this up, we can just put a bunch into a - // concurrent hash map inside of mutexes or something - struct RelayStateMachine { - - } - - enum RequestStateMachine { - WaitForServerAccept, // Client has connected - WaitForServerResponse, // Server has accepted request - } - - /* - - Here's what we need to handle: - - When a request comes in: - - - Look up the server - - If the server is parked, unpark it - - Park the client - - When a server comes to listen: - - - Look up the server - - Either return all pending requests, or park the server - - When a server comes to respond: - - - Look up the parked client - - Begin a stream, unparking the client - - So we need these lookups to be fast: - - - Server IDs, where 0 or 1 servers and 0 or many clients - can be parked - - Request IDs, where 1 client is parked - - */ } diff --git a/src/server/mod.rs b/src/server/mod.rs index f6a0c6d..c87bd79 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -21,7 +21,7 @@ pub mod file_server; async fn handle_req_resp <'a> ( opt: &'a Opt, handlebars: Arc >, - client: &'a Client, + client: Arc , req_resp: reqwest::Response ) { //println! ("Step 1"); @@ -32,30 +32,39 @@ async fn handle_req_resp <'a> ( } let body = req_resp.bytes ().await.unwrap (); - let wrapped_req: http_serde::WrappedRequest = match rmp_serde::from_read_ref (&body) + let wrapped_reqs: Vec = match rmp_serde::from_read_ref (&body) { Ok (x) => x, _ => return, }; - let (req_id, parts) = (wrapped_req.id, wrapped_req.req); - - let response = file_server::serve_all (handlebars, &opt.file_server_root, parts).await; - - let mut resp_req = client - .post (&format! ("{}/7ZSFUKGV_http_response/{}", opt.relay_url, req_id)) - .header (crate::PTTH_MAGIC_HEADER, base64::encode (rmp_serde::to_vec (&response.parts).unwrap ())); - - if let Some (body) = response.body { - resp_req = resp_req.body (reqwest::Body::wrap_stream (body)); - } - - //println! ("Step 6"); - if let Err (e) = resp_req.send ().await { - println! ("Err: {:?}", e); + for wrapped_req in wrapped_reqs.into_iter () { + let handlebars = handlebars.clone (); + let opt = opt.clone (); + let client = client.clone (); + + tokio::spawn (async move { + let (req_id, parts) = (wrapped_req.id, wrapped_req.req); + + let response = file_server::serve_all (handlebars, &opt.file_server_root, parts).await; + + let mut resp_req = client + .post (&format! ("{}/7ZSFUKGV_http_response/{}", opt.relay_url, req_id)) + .header (crate::PTTH_MAGIC_HEADER, base64::encode (rmp_serde::to_vec (&response.parts).unwrap ())); + + if let Some (body) = response.body { + resp_req = resp_req.body (reqwest::Body::wrap_stream (body)); + } + + //println! ("Step 6"); + if let Err (e) = resp_req.send ().await { + println! ("Err: {:?}", e); + } + }); } } +#[derive (Clone)] pub struct Opt { pub relay_url: String, pub server_name: String, @@ -65,7 +74,6 @@ pub struct Opt { pub async fn main (opt: Opt) -> Result <(), Box > { let client = Arc::new (Client::new ()); let opt = Arc::new (opt); - let handlebars = Arc::new (file_server::load_templates ()?); let mut backoff_delay = 0; @@ -97,7 +105,7 @@ pub async fn main (opt: Opt) -> Result <(), Box > { let handlebars = handlebars.clone (); tokio::spawn (async move { - handle_req_resp (&opt, handlebars, &client, req_resp).await; + handle_req_resp (&opt, handlebars, client, req_resp).await; }); } } diff --git a/todo.md b/todo.md index b356638..f651d58 100644 --- a/todo.md +++ b/todo.md @@ -1,4 +1,3 @@ -- Fix possible timing gap when refreshing http_listen - Set up tokens or privkeys or tripcodes or something so clients can't trivially impersonate servers