🐛 Fix rendezvous problem.

Now clients can queue up for a server, which fixes a few things:

- A server can receive multiple requests at once, reducing roundtrip
count in theory
- Clients can wait up to 30 seconds on the relay before the server
is ready for them
- If the server has just left to service a request, the client will
queue instead of seeing the server as absent and giving up
main
_ 2020-11-01 20:07:46 -06:00
parent 067e240ff4
commit e7edf84282
5 changed files with 200 additions and 160 deletions

View File

@ -12,6 +12,7 @@ license = "AGPL-3.0"
base64 = "0.12.3"
blake3 = "0.3.7"
dashmap = "3.11.10"
futures = "0.3.7"
handlebars = "3.5.1"
http = "0.2.1"

View File

@ -15,6 +15,10 @@ pub mod server;
#[cfg (test)]
mod tests {
use std::{
sync::Arc,
};
use tokio::{
runtime::Runtime,
spawn,
@ -35,12 +39,16 @@ mod tests {
rt.block_on (async {
let relay_url = "http://127.0.0.1:4000";
spawn (async {
relay::main ().await.unwrap ();
let relay_state = Arc::new (relay::RelayState::default ());
let relay_state_2 = relay_state.clone ();
spawn (async move {
relay::run_relay (relay_state_2).await.unwrap ();
});
let relay_url_2 = relay_url.into ();
assert! (relay_state.list_servers ().await.is_empty ());
let relay_url_2 = relay_url.into ();
let server_name = "alien_wildlands";
let server_name_2 = server_name.into ();
@ -54,9 +62,15 @@ mod tests {
server::main (opt).await.unwrap ();
});
tokio::time::delay_for (std::time::Duration::from_secs (1)).await;
tokio::time::delay_for (std::time::Duration::from_millis (500)).await;
let client = Client::new ();
assert_eq! (relay_state.list_servers ().await, vec! [
server_name.to_string (),
]);
let client = Client::builder ()
.timeout (std::time::Duration::from_secs (2))
.build ().unwrap ();
let resp = client.get (&format! ("{}/relay_up_check", relay_url))
.send ().await.unwrap ().bytes ().await.unwrap ();

View File

@ -2,14 +2,15 @@ pub mod watcher;
use std::{
error::Error,
collections::*,
convert::Infallible,
net::SocketAddr,
sync::{
Arc
},
time::{Duration},
};
use dashmap::DashMap;
use futures::channel::oneshot;
use handlebars::Handlebars;
use hyper::{
@ -24,27 +25,74 @@ use hyper::service::{make_service_fn, service_fn};
use serde::Serialize;
use tokio::{
sync::Mutex,
time::delay_for,
};
use crate::{
http_serde,
};
use watcher::*;
#[derive (Default)]
struct ServerState {
/*
Here's what we need to handle:
When a request comes in:
- Park the client in response_rendezvous
- Look up the server ID in request_rendezvous
- If a server is parked, unpark it and send the request
- Otherwise, queue the request
When a server comes to listen:
- Look up the server ID in request_rendezvous
- Either return all pending requests, or park the server
When a server comes to respond:
- Look up the parked client in response_rendezvous
- Unpark the client and begin streaming
So we need these lookups to be fast:
- Server IDs, where (1 server) or (0 or many clients)
can be parked
- Request IDs, where 1 client is parked
*/
enum RequestRendezvous {
ParkedClients (Vec <http_serde::WrappedRequest>),
ParkedServer (oneshot::Sender <http_serde::WrappedRequest>),
}
type ResponseRendezvous = oneshot::Sender <(http_serde::ResponseParts, Body)>;
pub struct RelayState {
handlebars: Arc <Handlebars <'static>>,
// Holds clients that are waiting for a response to come
// back from a server.
// Key: Server ID
request_rendezvous: Mutex <HashMap <String, RequestRendezvous>>,
client_watchers: Arc <Mutex <Watchers <(http_serde::ResponseParts, Body)>>>,
// Holds servers that are waiting for a request to come in
// from a client.
server_watchers: Arc <Mutex <Watchers <http_serde::WrappedRequest>>>,
// Key: Request ID
response_rendezvous: DashMap <String, ResponseRendezvous>,
}
impl Default for RelayState {
fn default () -> Self {
Self {
handlebars: Arc::new (load_templates ().unwrap ()),
request_rendezvous: Default::default (),
response_rendezvous: Default::default (),
}
}
}
impl RelayState {
pub async fn list_servers (&self) -> Vec <String> {
self.request_rendezvous.lock ().await.iter ()
.map (|(k, _)| (*k).clone ())
.collect ()
}
}
fn status_reply <B: Into <Body>> (status: StatusCode, b: B)
@ -53,113 +101,122 @@ fn status_reply <B: Into <Body>> (status: StatusCode, b: B)
Response::builder ().status (status).body (b.into ()).unwrap ()
}
async fn handle_http_listen (state: Arc <ServerState>, watcher_code: String)
async fn handle_http_listen (state: Arc <RelayState>, watcher_code: String)
-> Response <Body>
{
//println! ("Step 1");
match Watchers::long_poll (state.server_watchers.clone (), watcher_code).await {
Some (parts) => {
//println! ("Step 3");
status_reply (StatusCode::OK, rmp_serde::to_vec (&parts).unwrap ())
},
_ => status_reply (StatusCode::GATEWAY_TIMEOUT, "no\n"),
use RequestRendezvous::*;
let (tx, rx) = oneshot::channel ();
{
let mut request_rendezvous = state.request_rendezvous.lock ().await;
if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code)
{
return status_reply (StatusCode::OK, rmp_serde::to_vec (&v).unwrap ());
}
request_rendezvous.insert (watcher_code, ParkedServer (tx));
}
let one_req = vec! [
rx.await.unwrap (),
];
return status_reply (StatusCode::OK, rmp_serde::to_vec (&one_req).unwrap ());
}
async fn handle_http_response (
req: Request <Body>,
state: Arc <ServerState>,
state: Arc <RelayState>,
req_id: String,
)
-> Response <Body>
{
//println! ("Step 6");
let (parts, body) = req.into_parts ();
let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap ();
{
let mut watchers = state.client_watchers.lock ().await;
//println! ("Step 7");
if ! watchers.wake_one ((resp_parts, body), &req_id)
{
println! ("Step 8 (bad thing)");
status_reply (StatusCode::BAD_REQUEST, "A bad thing happened.\n")
}
else {
//println! ("Step 8");
status_reply (StatusCode::OK, "ok\n")
}
match state.response_rendezvous.remove (&req_id) {
Some ((_, tx)) => {
match tx.send ((resp_parts, body)) {
Ok (()) => status_reply (StatusCode::OK, "Connected to remote client...\n"),
_ => status_reply (StatusCode::BAD_GATEWAY, "Failed to connect to client"),
}
},
None => status_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous"),
}
}
async fn handle_http_request (
req: http::request::Parts,
uri: String,
state: Arc <ServerState>,
state: Arc <RelayState>,
watcher_code: String
)
-> Response <Body>
{
let parts = {
let id = ulid::Ulid::new ().to_string ();
let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) {
Ok (x) => x,
_ => return status_reply (StatusCode::BAD_REQUEST, "Bad request"),
};
http_serde::WrappedRequest {
id,
req,
}
let id = ulid::Ulid::new ().to_string ();
let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) {
Ok (x) => x,
_ => return status_reply (StatusCode::BAD_REQUEST, "Bad request"),
};
//println! ("Step 2 {}", parts.id);
let (tx, rx) = oneshot::channel ();
let (s, r) = oneshot::channel ();
let timeout = Duration::from_secs (5);
state.response_rendezvous.insert (id.clone (), tx);
let id_2 = parts.id.clone ();
{
let mut that = state.client_watchers.lock ().await;
that.add_watcher_with_id (s, id_2)
let mut request_rendezvous = state.request_rendezvous.lock ().await;
let wrapped = http_serde::WrappedRequest {
id,
req,
};
use RequestRendezvous::*;
let new_rendezvous = match request_rendezvous.remove (&watcher_code) {
Some (ParkedClients (mut v)) => {
v.push (wrapped);
ParkedClients (v)
},
Some (ParkedServer (s)) => {
// If sending to the server fails, queue it
match s.send (wrapped) {
Ok (()) => ParkedClients (vec! []),
Err (wrapped) => ParkedClients (vec! [wrapped]),
}
},
None => ParkedClients (vec! [wrapped]),
};
request_rendezvous.insert (watcher_code, new_rendezvous);
}
let req_id = parts.id.clone ();
let timeout = tokio::time::delay_for (std::time::Duration::from_secs (30));
tokio::spawn (async move {
{
let mut watchers = state.server_watchers.lock ().await;
//println! ("Step 3");
if ! watchers.wake_one (parts, &watcher_code) {
watchers.remove_watcher (&req_id);
}
}
delay_for (timeout).await;
{
let mut that = state.client_watchers.lock ().await;
that.remove_watcher (&req_id);
}
});
let received = tokio::select! {
val = rx => val,
() = timeout => {
return status_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded")
},
};
match r.await {
Ok ((resp_parts, body)) => {
//println! ("Step 7");
match received {
Ok ((parts, body)) => {
let mut resp = Response::builder ()
.status (hyper::StatusCode::from (resp_parts.status_code));
.status (hyper::StatusCode::from (parts.status_code));
for (k, v) in resp_parts.headers.into_iter () {
for (k, v) in parts.headers.into_iter () {
resp = resp.header (&k, v);
}
resp
.body (body)
resp.body (body)
.unwrap ()
},
_ => status_reply (StatusCode::GATEWAY_TIMEOUT, "server didn't reply in time or somethin'"),
_ => status_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out"),
}
}
@ -173,7 +230,7 @@ fn prefix_match <'a> (hay: &'a str, needle: &str) -> Option <&'a str>
}
}
async fn handle_all (req: Request <Body>, state: Arc <ServerState>)
async fn handle_all (req: Request <Body>, state: Arc <RelayState>)
-> Result <Response <Body>, Infallible>
{
let path = req.uri ().path ();
@ -208,11 +265,7 @@ async fn handle_all (req: Request <Body>, state: Arc <ServerState>)
servers: Vec <ServerEntry <'a>>,
}
let names: Vec <_> = {
state.server_watchers.lock ().await.senders.iter ()
.map (|(k, _)| (*k).clone ())
.collect ()
};
let names = state.list_servers ().await;
//println! ("Found {} servers", names.len ());
@ -262,17 +315,10 @@ pub fn load_templates ()
Ok (handlebars)
}
pub async fn main () -> Result <(), Box <dyn Error>> {
pub async fn run_relay (state: Arc <RelayState>) -> Result <(), Box <dyn Error>>
{
let addr = SocketAddr::from(([0, 0, 0, 0], 4000));
let state = ServerState {
handlebars: Arc::new (load_templates ()?),
server_watchers: Default::default (),
client_watchers: Default::default (),
};
let state = Arc::new (state);
let make_svc = make_service_fn (|_conn| {
let state = state.clone ();
@ -292,47 +338,19 @@ pub async fn main () -> Result <(), Box <dyn Error>> {
Ok (())
}
pub async fn main () -> Result <(), Box <dyn Error>> {
let state = RelayState {
handlebars: Arc::new (load_templates ()?),
request_rendezvous: Default::default (),
response_rendezvous: Default::default (),
};
let state = Arc::new (state);
run_relay (state).await
}
#[cfg (test)]
mod tests {
// Toy model of a relay for a single server
// with one consumer thread.
// To scale this up, we can just put a bunch into a
// concurrent hash map inside of mutexes or something
struct RelayStateMachine {
}
enum RequestStateMachine {
WaitForServerAccept, // Client has connected
WaitForServerResponse, // Server has accepted request
}
/*
Here's what we need to handle:
When a request comes in:
- Look up the server
- If the server is parked, unpark it
- Park the client
When a server comes to listen:
- Look up the server
- Either return all pending requests, or park the server
When a server comes to respond:
- Look up the parked client
- Begin a stream, unparking the client
So we need these lookups to be fast:
- Server IDs, where 0 or 1 servers and 0 or many clients
can be parked
- Request IDs, where 1 client is parked
*/
}

View File

@ -21,7 +21,7 @@ pub mod file_server;
async fn handle_req_resp <'a> (
opt: &'a Opt,
handlebars: Arc <Handlebars <'static>>,
client: &'a Client,
client: Arc <Client>,
req_resp: reqwest::Response
) {
//println! ("Step 1");
@ -32,30 +32,39 @@ async fn handle_req_resp <'a> (
}
let body = req_resp.bytes ().await.unwrap ();
let wrapped_req: http_serde::WrappedRequest = match rmp_serde::from_read_ref (&body)
let wrapped_reqs: Vec <http_serde::WrappedRequest> = match rmp_serde::from_read_ref (&body)
{
Ok (x) => x,
_ => return,
};
let (req_id, parts) = (wrapped_req.id, wrapped_req.req);
let response = file_server::serve_all (handlebars, &opt.file_server_root, parts).await;
let mut resp_req = client
.post (&format! ("{}/7ZSFUKGV_http_response/{}", opt.relay_url, req_id))
.header (crate::PTTH_MAGIC_HEADER, base64::encode (rmp_serde::to_vec (&response.parts).unwrap ()));
if let Some (body) = response.body {
resp_req = resp_req.body (reqwest::Body::wrap_stream (body));
}
//println! ("Step 6");
if let Err (e) = resp_req.send ().await {
println! ("Err: {:?}", e);
for wrapped_req in wrapped_reqs.into_iter () {
let handlebars = handlebars.clone ();
let opt = opt.clone ();
let client = client.clone ();
tokio::spawn (async move {
let (req_id, parts) = (wrapped_req.id, wrapped_req.req);
let response = file_server::serve_all (handlebars, &opt.file_server_root, parts).await;
let mut resp_req = client
.post (&format! ("{}/7ZSFUKGV_http_response/{}", opt.relay_url, req_id))
.header (crate::PTTH_MAGIC_HEADER, base64::encode (rmp_serde::to_vec (&response.parts).unwrap ()));
if let Some (body) = response.body {
resp_req = resp_req.body (reqwest::Body::wrap_stream (body));
}
//println! ("Step 6");
if let Err (e) = resp_req.send ().await {
println! ("Err: {:?}", e);
}
});
}
}
#[derive (Clone)]
pub struct Opt {
pub relay_url: String,
pub server_name: String,
@ -65,7 +74,6 @@ pub struct Opt {
pub async fn main (opt: Opt) -> Result <(), Box <dyn Error>> {
let client = Arc::new (Client::new ());
let opt = Arc::new (opt);
let handlebars = Arc::new (file_server::load_templates ()?);
let mut backoff_delay = 0;
@ -97,7 +105,7 @@ pub async fn main (opt: Opt) -> Result <(), Box <dyn Error>> {
let handlebars = handlebars.clone ();
tokio::spawn (async move {
handle_req_resp (&opt, handlebars, &client, req_resp).await;
handle_req_resp (&opt, handlebars, client, req_resp).await;
});
}
}

View File

@ -1,4 +1,3 @@
- Fix possible timing gap when refreshing http_listen
- Set up tokens or privkeys or tripcodes or something so
clients can't trivially impersonate servers