ptth/src/relay/mod.rs

631 lines
15 KiB
Rust

use std::{
error::Error,
collections::*,
convert::Infallible,
iter::FromIterator,
net::SocketAddr,
sync::{
Arc,
},
time::Duration,
};
use dashmap::DashMap;
use futures::{
FutureExt,
stream::StreamExt,
};
use handlebars::Handlebars;
use hyper::{
Body,
Method,
Request,
Response,
Server,
StatusCode,
};
use hyper::service::{make_service_fn, service_fn};
use serde::{
Deserialize,
Serialize,
};
use tokio::{
spawn,
sync::{
Mutex,
mpsc,
oneshot,
RwLock,
watch,
},
time::delay_for,
};
use tracing::{
debug, error, info, trace,
instrument,
};
use crate::{
http_serde,
prefix_match,
};
/*
Here's what we need to handle:
When a request comes in:
- Park the client in response_rendezvous
- Look up the server ID in request_rendezvous
- If a server is parked, unpark it and send the request
- Otherwise, queue the request
When a server comes to listen:
- Look up the server ID in request_rendezvous
- Either return all pending requests, or park the server
When a server comes to respond:
- Look up the parked client in response_rendezvous
- Unpark the client and begin streaming
So we need these lookups to be fast:
- Server IDs, where (1 server) or (0 or many clients)
can be parked
- Request IDs, where 1 client is parked
*/
/// Errors the relay passes through the rendezvous channels to parked
/// servers and clients.
#[derive (Debug)]
enum RelayError {
/// The relay is shutting down. `run_relay` sends this to every parked
/// party during graceful shutdown.
RelayShuttingDown,
}
/// Whatever is currently waiting in a server ID's slot of
/// `RelayState::request_rendezvous`.
enum RequestRendezvous {
/// Client requests queued for the server to pick up on its next
/// `http_listen` call. May be empty.
ParkedClients (Vec <http_serde::WrappedRequest>),
/// A long-polling server. Sending on this channel hands it one request,
/// or a relay error such as shutdown.
ParkedServer (oneshot::Sender <Result <http_serde::WrappedRequest, RelayError>>),
}
/// Channel half stored for each parked client request. The server's
/// response handler sends the response head and streaming body through it,
/// or the relay sends an error on shutdown.
type ResponseRendezvous = oneshot::Sender <Result <(http_serde::ResponseParts, Body), RelayError>>;
/// Stuff we need to load from the config file and use to
/// set up the HTTP server
#[derive (Default, Deserialize)]
pub struct ConfigFile {
/// Port for the HTTP server.
// NOTE(review): nothing visible in this module reads this field;
// `run_relay` binds a fixed port. Confirm whether it is used elsewhere.
pub port: Option <u16>,
/// Maps each server name to its tripcode: the base64 encoding of a
/// 32-byte BLAKE3 hash (see `Config::from`).
pub server_tripcodes: HashMap <String, String>,
}
/// Stuff we actually need at runtime
struct Config {
/// Maps each server name to the decoded BLAKE3 hash that the server's
/// API key must hash to in `handle_http_listen`.
server_tripcodes: HashMap <String, blake3::Hash>,
}
impl From <&ConfigFile> for Config {
    /// Decodes every base64 tripcode in the config file into a
    /// `blake3::Hash`, keyed by server name.
    ///
    /// # Panics
    ///
    /// Panics if a tripcode is not valid base64 or does not decode to
    /// exactly 32 bytes. The panic message names the offending server so
    /// that a broken config file can be fixed without guesswork.
    fn from (f: &ConfigFile) -> Self {
        use std::convert::TryInto;

        let server_tripcodes: HashMap <String, blake3::Hash> = f.server_tripcodes.iter ()
            .map (|(name, encoded)| {
                let decoded: Vec <u8> = base64::decode (encoded)
                    .unwrap_or_else (|e| panic! (
                        "Tripcode for server `{}` is not valid base64: {}",
                        name,
                        e,
                    ));

                let hash_bytes: [u8; 32] = decoded.as_slice ().try_into ()
                    .unwrap_or_else (|_| panic! (
                        "Tripcode for server `{}` must decode to 32 bytes, but decoded to {}",
                        name,
                        decoded.len (),
                    ));

                (name.clone (), blake3::Hash::from (hash_bytes))
            })
            .collect ();

        Self {
            server_tripcodes,
        }
    }
}
/// Shared state for the whole relay; handlers receive it as an
/// `Arc <RelayState>`.
pub struct RelayState {
// Runtime configuration (the server tripcodes)
config: Config,
// Templates for the HTML pages rendered in `handle_all`
handlebars: Arc <Handlebars <'static>>,
// Key: Server ID
// Value: either the queued client requests or the parked server
request_rendezvous: Mutex <HashMap <String, RequestRendezvous>>,
// Key: Request ID
// Value: the channel used to deliver the response to the parked client.
// Handlers only take the read lock and mutate the DashMap through it;
// the shutdown path in `run_relay` takes the write lock to empty the map.
response_rendezvous: RwLock <DashMap <String, ResponseRendezvous>>,
// `run_relay` broadcasts `true` on this when shutdown begins
shutdown_watch_tx: watch::Sender <bool>,
// Cloned by `handle_http_response` so body-relay tasks can notice shutdown
shutdown_watch_rx: watch::Receiver <bool>,
}
impl From <&ConfigFile> for RelayState {
    /// Builds a fresh relay state from the config file: decoded tripcodes,
    /// loaded templates, empty rendezvous maps, and a shutdown watch
    /// channel that starts out as `false`.
    ///
    /// Panics if the templates cannot be loaded.
    fn from (config_file: &ConfigFile) -> Self {
        let config = Config::from (config_file);
        let handlebars = Arc::new (load_templates ().unwrap ());
        let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false);

        Self {
            config,
            handlebars,
            request_rendezvous: Default::default (),
            response_rendezvous: Default::default (),
            shutdown_watch_tx,
            shutdown_watch_rx,
        }
    }
}
impl RelayState {
    /// Returns the ID of every server that currently has an entry in
    /// `request_rendezvous`, whether that entry is a parked server or a
    /// (possibly empty) queue of client requests.
    pub async fn list_servers (&self) -> Vec <String> {
        let request_rendezvous = self.request_rendezvous.lock ().await;
        request_rendezvous.keys ().cloned ().collect ()
    }
}
/// Wraps `b` in a `200 OK` response with no extra headers.
fn ok_reply <B: Into <Body>> (b: B)
-> Response <Body>
{
    let body: Body = b.into ();

    Response::builder ()
        .status (StatusCode::OK)
        .body (body)
        .unwrap ()
}
fn error_reply (status: StatusCode, b: &str)
-> Response <Body>
{
Response::builder ()
.status (status)
.header ("content-type", "text/plain")
.body (format! ("{}\n", b).into ()).unwrap ()
}
/// Servers will come here and either handle queued requests from parked
/// clients, or park themselves until a request comes in.
///
/// The server is authenticated by hashing `api_key` with BLAKE3 and
/// comparing the result with the tripcode configured for `watcher_code`.
/// Unknown names and wrong keys both get the same `401` reply.
///
/// A successful reply is a MessagePack-encoded list of wrapped requests.
/// If nothing arrives within 30 seconds, the reply is `204` and the server
/// is expected to long-poll again.
async fn handle_http_listen (
    state: Arc <RelayState>,
    watcher_code: String,
    api_key: &[u8],
)
-> Response <Body>
{
    use RequestRendezvous::*;

    let trip_error = || error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey");

    let expected_tripcode = match state.config.server_tripcodes.get (&watcher_code) {
        Some (tripcode) => tripcode,
        None => {
            error! ("Denied http_listen for non-existent server name {}", watcher_code);
            return trip_error ();
        },
    };

    let actual_tripcode = blake3::hash (api_key);
    if *expected_tripcode != actual_tripcode {
        error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ()));
        return trip_error ();
    }

    let (tx, rx) = oneshot::channel ();

    {
        let mut request_rendezvous = state.request_rendezvous.lock ().await;

        // Whatever was in this server's slot is taken out here. A
        // non-empty client queue is answered right away; anything else
        // (nothing, an empty queue, or a previously parked server) is
        // replaced below by this server's own parking spot.
        match request_rendezvous.remove (&watcher_code) {
            Some (ParkedClients (queued)) if ! queued.is_empty () => {
                // 1 or more clients were parked - Make the server
                // handle them immediately
                debug! ("Sending {} parked requests to server {}", queued.len (), watcher_code);
                return ok_reply (rmp_serde::to_vec (&queued).unwrap ());
            },
            _ => (),
        }

        debug! ("Parking server {}", watcher_code);
        request_rendezvous.insert (watcher_code.clone (), ParkedServer (tx));
    }

    // No clients were parked - make the server long-poll
    futures::select! {
        parked = rx.fuse () => match parked {
            Ok (Ok (one_req)) => {
                debug! ("Unparking server {}", watcher_code);
                ok_reply (rmp_serde::to_vec (&vec! [one_req]).unwrap ())
            },
            Ok (Err (RelayError::RelayShuttingDown)) => error_reply (StatusCode::SERVICE_UNAVAILABLE, "Server is shutting down, try again soon"),
            // The sending half was dropped without sending anything
            Err (_) => error_reply (StatusCode::INTERNAL_SERVER_ERROR, "Server error"),
        },
        _ = delay_for (Duration::from_secs (30)).fuse () => {
            debug! ("Timed out http_listen for server {}", watcher_code);
            error_reply (StatusCode::NO_CONTENT, "No requests now, long-poll again")
        },
    }
}
// Servers will come here to stream responses to clients
async fn handle_http_response (
req: Request <Body>,
state: Arc <RelayState>,
req_id: String,
)
-> Response <Body>
{
let (parts, mut body) = req.into_parts ();
let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap ();
// Intercept the body packets here so we can check when the stream
// ends or errors out
#[derive (Debug)]
enum BodyFinishedReason {
StreamFinished,
ClientDisconnected,
}
use BodyFinishedReason::*;
let (mut body_tx, body_rx) = mpsc::channel (2);
let (body_finished_tx, body_finished_rx) = oneshot::channel ();
let mut shutdown_watch_rx = state.shutdown_watch_rx.clone ();
spawn (async move {
if shutdown_watch_rx.recv ().await == Some (false) {
loop {
let item = body.next ().await;
if let Some (item) = item {
if let Ok (bytes) = &item {
trace! ("Relaying {} bytes", bytes.len ());
}
futures::select! {
x = body_tx.send (item).fuse () => if let Err (_) = x {
info! ("Body closed while relaying. (Client hung up?)");
body_finished_tx.send (ClientDisconnected).unwrap ();
break;
},
_ = shutdown_watch_rx.recv ().fuse () => {
debug! ("Closing stream: relay is shutting down");
break;
},
}
}
else {
debug! ("Finished relaying bytes");
body_finished_tx.send (StreamFinished).unwrap ();
break;
}
}
}
else {
debug! ("Can't relay bytes, relay is shutting down");
}
});
let body = Body::wrap_stream (body_rx);
let tx = {
let response_rendezvous = state.response_rendezvous.read ().await;
match response_rendezvous.remove (&req_id) {
None => {
error! ("Server tried to respond to non-existent request");
return error_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous");
},
Some ((_, x)) => x,
}
};
// UKAUFFY4 (Send half)
if tx.send (Ok ((resp_parts, body))).is_err () {
let msg = "Failed to connect to client";
error! (msg);
return error_reply (StatusCode::BAD_GATEWAY, msg);
}
debug! ("Connected server to client for streaming.");
match body_finished_rx.await {
Ok (StreamFinished) => {
error_reply (StatusCode::OK, "StreamFinished")
},
Ok (ClientDisconnected) => {
error_reply (StatusCode::OK, "ClientDisconnected")
},
Err (e) => {
debug! ("body_finished_rx {}", e);
error_reply (StatusCode::OK, "body_finished_rx Err")
},
}
}
/// Clients will come here to start requests, and always park for at least
/// a short amount of time.
///
/// The request is registered in `response_rendezvous` under a fresh ULID,
/// then handed to the server named `watcher_code` if it is parked, or
/// queued for it otherwise. The client then waits up to 30 seconds for the
/// server to respond.
///
/// Replies `404` for an unconfigured server name, `400` if the request
/// cannot be converted for forwarding, and `504` if no response arrives
/// (timeout, relay shutdown, or the response channel being dropped).
async fn handle_http_request (
    req: http::request::Parts,
    uri: String,
    state: Arc <RelayState>,
    watcher_code: String
)
-> Response <Body>
{
    if ! state.config.server_tripcodes.contains_key (&watcher_code) {
        return error_reply (StatusCode::NOT_FOUND, "Unknown server");
    }

    let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) {
        Ok (x) => x,
        Err (_) => return error_reply (StatusCode::BAD_REQUEST, "Bad request"),
    };

    let (tx, rx) = oneshot::channel ();

    let req_id = ulid::Ulid::new ().to_string ();
    {
        let response_rendezvous = state.response_rendezvous.read ().await;
        response_rendezvous.insert (req_id.clone (), tx);
    }

    trace! ("Created request {}", req_id);

    {
        let mut request_rendezvous = state.request_rendezvous.lock ().await;

        let wrapped = http_serde::WrappedRequest {
            id: req_id.clone (),
            req,
        };

        use RequestRendezvous::*;

        let new_rendezvous = match request_rendezvous.remove (&watcher_code) {
            Some (ParkedClients (mut v)) => {
                debug! ("Parking request {} ({} already queued)", req_id, v.len ());
                v.push (wrapped);
                ParkedClients (v)
            },
            Some (ParkedServer (s)) => {
                // If sending to the server fails, queue it
                match s.send (Ok (wrapped)) {
                    Ok (()) => {
                        // TODO: This can actually still fail, if the server
                        // disconnects right as we're sending this.
                        // Then what?
                        debug! (
                            "Sending request {} directly to server {}",
                            req_id,
                            watcher_code,
                        );
                        ParkedClients (vec! [])
                    },
                    // A failed send hands back exactly what we tried to send
                    Err (Ok (wrapped)) => {
                        debug! ("Parking request {}", req_id);
                        ParkedClients (vec! [wrapped])
                    },
                    Err (_) => unreachable! (),
                }
            },
            None => {
                debug! ("Parking request {}", req_id);
                ParkedClients (vec! [wrapped])
            },
        };
        request_rendezvous.insert (watcher_code, new_rendezvous);
    }

    let timeout = tokio::time::delay_for (std::time::Duration::from_secs (30));

    let received = tokio::select! {
        val = rx => Some (val),
        () = timeout => None,
    };

    let received = match received {
        Some (x) => x,
        None => {
            debug! ("Timed out request {}", req_id);
            // Nobody will ever receive on this request's channel again, so
            // drop its sender from the map; otherwise every timed-out
            // request leaks an entry. If the server claimed the entry a
            // moment ago, this is a no-op.
            state.response_rendezvous.read ().await.remove (&req_id);
            return error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded");
        },
    };

    // UKAUFFY4 (Receive half)
    match received {
        Ok (Ok ((parts, body))) => {
            let mut resp = Response::builder ()
                .status (hyper::StatusCode::from (parts.status_code));

            for (k, v) in parts.headers.into_iter () {
                resp = resp.header (&k, v);
            }

            debug! ("Unparked request {}", req_id);

            resp.body (body)
                .unwrap ()
        },
        Ok (Err (RelayError::RelayShuttingDown)) => {
            error_reply (StatusCode::GATEWAY_TIMEOUT, "Relay shutting down")
        },
        Err (_) => {
            debug! ("Responder sender dropped for request {}", req_id);
            error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out")
        },
    }
}
#[instrument (level = "trace", skip (req, state))]
async fn handle_all (req: Request <Body>, state: Arc <RelayState>)
-> Result <Response <Body>, Infallible>
{
let path = req.uri ().path ();
//println! ("{}", path);
debug! ("Request path: {}", path);
let api_key = req.headers ().get ("X-ApiKey");
if req.method () == Method::POST {
// This is stuff the server can use. Clients can't
// POST right now
return Ok (if let Some (request_code) = prefix_match ("/7ZSFUKGV/http_response/", path) {
let request_code = request_code.into ();
handle_http_response (req, state, request_code).await
}
else {
error_reply (StatusCode::BAD_REQUEST, "Can't POST this")
});
}
Ok (if let Some (listen_code) = prefix_match ("/7ZSFUKGV/http_listen/", path) {
let api_key = match api_key {
None => return Ok (error_reply (StatusCode::UNAUTHORIZED, "Can't register as server without an API key")),
Some (x) => x,
};
handle_http_listen (state, listen_code.into (), api_key.as_bytes ()).await
}
else if let Some (rest) = prefix_match ("/frontend/servers/", path) {
if rest == "" {
#[derive (Serialize)]
struct ServerEntry <'a> {
path: &'a str,
name: &'a str,
}
#[derive (Serialize)]
struct ServerListPage <'a> {
servers: Vec <ServerEntry <'a>>,
}
let names = state.list_servers ().await;
//println! ("Found {} servers", names.len ());
let page = ServerListPage {
servers: names.iter ()
.map (|name| ServerEntry {
name: &name,
path: &name,
})
.collect (),
};
let s = state.handlebars.render ("relay_server_list", &page).unwrap ();
ok_reply (s)
}
else if let Some (idx) = rest.find ('/') {
let listen_code = String::from (&rest [0..idx]);
let path = String::from (&rest [idx..]);
let (parts, _) = req.into_parts ();
handle_http_request (parts, path, state, listen_code).await
}
else {
error_reply (StatusCode::BAD_REQUEST, "Bad URI format")
}
}
else if path == "/" {
let s = state.handlebars.render ("relay_root", &()).unwrap ();
ok_reply (s)
}
else if path == "/frontend/relay_up_check" {
error_reply (StatusCode::OK, "Relay is up")
}
else {
error_reply (StatusCode::OK, "Hi")
})
}
/// Creates a strict-mode Handlebars registry and registers the relay's
/// page templates from the `ptth_handlebars/` directory.
///
/// # Errors
///
/// Returns the registration error of the first template file that fails
/// to register.
pub fn load_templates ()
-> Result <Handlebars <'static>, Box <dyn Error>>
{
    // (template name, file name inside `ptth_handlebars/`)
    const TEMPLATES: &[(&str, &str)] = &[
        ("relay_server_list", "relay_server_list.html"),
        ("relay_root", "relay_root.html"),
    ];

    let mut registry = Handlebars::new ();
    registry.set_strict_mode (true);

    for &(name, file_name) in TEMPLATES {
        let template_path = format! ("ptth_handlebars/{}", file_name);
        registry.register_template_file (name, template_path)?;
    }

    Ok (registry)
}
/// Runs the relay's HTTP server on `0.0.0.0:4000` until `shutdown_oneshot`
/// fires or is dropped, then shuts down gracefully.
///
/// Graceful shutdown broadcasts `true` on the shutdown watch channel and
/// sends `RelayShuttingDown` to every parked client and parked server.
///
/// # Errors
///
/// Returns an error if two configured servers share the same tripcode, or
/// if the HTTP server itself fails.
///
/// # Panics
///
/// Panics if the listen address cannot be bound.
pub async fn run_relay (
    state: Arc <RelayState>,
    shutdown_oneshot: oneshot::Receiver <()>
)
-> Result <(), Box <dyn Error>>
{
    // Tripcodes must be unique across servers. This is a mistake in the
    // config file, so report it to the caller as an error, not a panic.
    {
        let mut seen_tripcodes = HashSet::new ();
        for tripcode in state.config.server_tripcodes.values () {
            if ! seen_tripcodes.insert (tripcode) {
                return Err ("Two servers have the same tripcode. That is not allowed.".into ());
            }
        }
    }

    info! ("Loaded {} server tripcodes", state.config.server_tripcodes.len ());

    let addr = SocketAddr::from ((
        [0, 0, 0, 0],
        4000,
    ));

    let make_svc = make_service_fn (|_conn| {
        let state = state.clone ();

        async {
            Ok::<_, Infallible> (service_fn (move |req| {
                let state = state.clone ();

                handle_all (req, state)
            }))
        }
    });

    let server = Server::bind (&addr)
        .serve (make_svc);

    server.with_graceful_shutdown (async {
        use RelayError::*;

        shutdown_oneshot.await.ok ();

        // `state` holds a receiver of its own, so there is always at least
        // one receiver for this broadcast.
        state.shutdown_watch_tx.broadcast (true).unwrap ();

        // Take the write lock so no handler can touch the map while it is
        // emptied, then tell every parked client that the relay is going
        // away.
        let mut response_rendezvous = state.response_rendezvous.write ().await;
        let parked_clients = std::mem::take (&mut *response_rendezvous);
        for (_, sender) in parked_clients.into_iter () {
            sender.send (Err (RelayShuttingDown)).ok ();
        }

        // Same for parked servers. Queued client requests are just dropped;
        // their clients were already notified above.
        let mut request_rendezvous = state.request_rendezvous.lock ().await;
        for (_, x) in request_rendezvous.drain () {
            use RequestRendezvous::*;

            match x {
                ParkedClients (_) => (),
                ParkedServer (sender) => drop (sender.send (Err (RelayShuttingDown))),
            }
        }

        debug! ("Performed all cleanup");
    }).await?;

    Ok (())
}
#[cfg (test)]
mod tests {
// No unit tests are defined in this module yet.
}