// ptth/crates/ptth_relay/src/lib.rs

#![warn (clippy::pedantic)]
// I don't see the point of writing the type twice if I'm initializing a struct
// and the type is already in the struct definition.
#![allow (clippy::default_trait_access)]
// I'm not sure if I like this one
#![allow (clippy::enum_glob_use)]
// I don't see the point in documenting the errors outside of where the
// error type is defined.
#![allow (clippy::missing_errors_doc)]
// False positive on futures::select! macro
#![allow (clippy::mut_mut)]
use std::{
borrow::Cow,
collections::HashMap,
convert::TryFrom,
iter::FromIterator,
net::SocketAddr,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use chrono::{
DateTime,
SecondsFormat,
Utc
};
use dashmap::DashMap;
use handlebars::Handlebars;
use hyper::{
Body,
Method,
Request,
Response,
Server,
StatusCode,
};
use hyper::service::{make_service_fn, service_fn};
use serde::{
Serialize,
};
use tokio::{
sync::{
Mutex,
oneshot,
RwLock,
watch,
},
};
use ptth_core::{
http_serde,
prefix_match,
prelude::*,
};
pub mod config;
pub mod errors;
pub mod git_version;
pub mod key_validity;
mod server_endpoint;
pub use config::Config;
pub use errors::*;
/*
Here's what we need to handle:
When a request comes in:
- Park the client in response_rendezvous
- Look up the server ID in request_rendezvous
- If a server is parked, unpark it and send the request
- Otherwise, queue the request
When a server comes to listen:
- Look up the server ID in request_rendezvous
- Either return all pending requests, or park the server
When a server comes to respond:
- Look up the parked client in response_rendezvous
- Unpark the client and begin streaming
So we need these lookups to be fast:
- Server IDs, under which either 1 server or 0 or more clients
can be parked
- Request IDs, under which exactly 1 client is parked
*/
/// What is waiting under a given server ID: either client requests queued
/// for that server, or one server parked until the next request arrives.
enum RequestRendezvous {
	/// Client requests queued until the server comes to listen.
	ParkedClients (Vec <http_serde::WrappedRequest>),
	/// A listening server. Sending on this hands it one request, or tells
	/// it the relay is shutting down.
	ParkedServer (
		oneshot::Sender <
			Result <http_serde::WrappedRequest, ShuttingDownError>
		>
	),
}

/// Sender that delivers a server's response (head plus body), or a
/// shutdown notice, to the client parked under a request ID.
type ResponseRendezvous = oneshot::Sender <
	Result <(http_serde::ResponseParts, Body), ShuttingDownError>
>;
/// Per-server bookkeeping shown on the server list page.
#[derive (Clone)]
pub struct ServerStatus {
	/// Timestamp rendered as the "last seen" column of the server list.
	last_seen: DateTime <Utc>,
}

impl Default for ServerStatus {
	/// A status stamped with the current time.
	fn default () -> Self {
		let last_seen = Utc::now ();
		Self { last_seen }
	}
}
/// Shared state for one relay instance, passed as `Arc<RelayState>` to every
/// request handler.
pub struct RelayState {
	/// Current config. `reload_config` replaces it wholesale.
	config: RwLock <Config>,
	/// Compiled page templates (see `load_templates`).
	handlebars: Arc <Handlebars <'static>>,
	
	/// Keyed by server ID: queued client requests, or a parked server.
	request_rendezvous: Mutex <HashMap <String, RequestRendezvous>>,
	/// Keyed by server ID: status shown on the server list page.
	server_status: Mutex <HashMap <String, ServerStatus>>,
	
	/// Keyed by request ID: the client waiting for that request's response.
	// In this file the RwLock is only write-locked during shutdown, to take
	// the whole map; per-request inserts go through the DashMap under a
	// read lock.
	response_rendezvous: RwLock <DashMap <String, ResponseRendezvous>>,
	
	/// Broadcasts `true` once graceful shutdown begins.
	shutdown_watch_tx: watch::Sender <bool>,
	// NOTE(review): presumably cloned by code that needs to observe shutdown;
	// holding it here also keeps the channel open for the broadcast.
	shutdown_watch_rx: watch::Receiver <bool>,
}
impl TryFrom <Config> for RelayState {
type Error = RelayError;
fn try_from (config: Config) -> Result <Self, Self::Error> {
let (shutdown_watch_tx, shutdown_watch_rx) = watch::channel (false);
Ok (Self {
config: config.into (),
handlebars: Arc::new (load_templates (&PathBuf::new ())?),
request_rendezvous: Default::default (),
server_status: Default::default (),
response_rendezvous: Default::default (),
shutdown_watch_tx,
shutdown_watch_rx,
})
}
}
impl RelayState {
	/// Returns the ID of every server that currently has a request
	/// rendezvous entry (a parked server or a queue of clients, possibly
	/// empty). The order is unspecified.
	pub async fn list_servers (&self) -> Vec <String> {
		self.request_rendezvous.lock ().await
		.keys ()
		.cloned ()
		.collect ()
	}
}
/// Builds a `200 OK` response carrying `b` as its body.
fn ok_reply <B: Into <Body>> (b: B)
-> Result <Response <Body>, http::Error>
{
	let body: Body = b.into ();
	Response::builder ()
	.status (StatusCode::OK)
	.body (body)
}
fn error_reply (status: StatusCode, b: &str)
-> Result <Response <Body>, http::Error>
{
Response::builder ()
.status (status)
.header ("content-type", "text/plain")
.body (format! ("{}\n", b).into ())
}
// Clients will come here to start requests, and always park for at least
// a short amount of time.
async fn handle_http_request (
req: http::request::Parts,
uri: String,
state: Arc <RelayState>,
watcher_code: String
)
-> Result <Response <Body>, http::Error>
{
{
let config = state.config.read ().await;
if ! config.servers.contains_key (&watcher_code) {
return error_reply (StatusCode::NOT_FOUND, "Unknown server");
}
}
let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) {
Ok (x) => x,
Err (_) => return error_reply (StatusCode::BAD_REQUEST, "Bad request"),
};
let (tx, rx) = oneshot::channel ();
let req_id = ulid::Ulid::new ().to_string ();
{
let response_rendezvous = state.response_rendezvous.read ().await;
response_rendezvous.insert (req_id.clone (), tx);
}
trace! ("Created request {}", req_id);
{
use RequestRendezvous::*;
let mut request_rendezvous = state.request_rendezvous.lock ().await;
let wrapped = http_serde::WrappedRequest {
id: req_id.clone (),
req,
};
let new_rendezvous = match request_rendezvous.remove (&watcher_code) {
Some (ParkedClients (mut v)) => {
debug! ("Parking request {} ({} already queued)", req_id, v.len ());
v.push (wrapped);
ParkedClients (v)
},
Some (ParkedServer (s)) => {
// If sending to the server fails, queue it
match s.send (Ok (wrapped)) {
Ok (()) => {
// TODO: This can actually still fail, if the server
// disconnects right as we're sending this.
// Then what?
debug! (
"Sending request {} directly to server {}",
req_id,
watcher_code,
);
ParkedClients (vec! [])
},
Err (Ok (wrapped)) => {
debug! ("Parking request {}", req_id);
ParkedClients (vec! [wrapped])
},
Err (_) => unreachable! (),
}
},
None => {
debug! ("Parking request {}", req_id);
ParkedClients (vec! [wrapped])
},
};
request_rendezvous.insert (watcher_code, new_rendezvous);
}
let timeout = tokio::time::delay_for (std::time::Duration::from_secs (30));
let received = tokio::select! {
val = rx => val,
() = timeout => {
debug! ("Timed out request {}", req_id);
return error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server never responded")
},
};
// UKAUFFY4 (Receive half)
match received {
Ok (Ok ((parts, body))) => {
let mut resp = Response::builder ()
.status (hyper::StatusCode::from (parts.status_code));
for (k, v) in parts.headers {
resp = resp.header (&k, v);
}
debug! ("Unparked request {}", req_id);
resp.body (body)
},
Ok (Err (ShuttingDownError::ShuttingDown)) => {
error_reply (StatusCode::GATEWAY_TIMEOUT, "Relay shutting down")
},
Err (_) => {
debug! ("Responder sender dropped for request {}", req_id);
error_reply (StatusCode::GATEWAY_TIMEOUT, "Remote server timed out")
},
}
}
/// Coarse classification of how long ago a server was last seen, as
/// produced by `pretty_print_last_seen`.
#[derive (Debug, PartialEq)]
enum LastSeen {
	/// `last_seen` is after `now`; the list page shows this as an error.
	Negative,
	/// Seen less than a minute ago.
	Connected,
	/// Relative age ("N m ago" / "N h ago") or an absolute timestamp.
	Description (String),
}
// Mnemonic is "now - last_seen"
fn pretty_print_last_seen (
now: DateTime <Utc>,
last_seen: DateTime <Utc>
) -> LastSeen
{
use LastSeen::*;
let dur = now.signed_duration_since (last_seen);
if dur < chrono::Duration::zero () {
return Negative;
}
if dur.num_minutes () < 1 {
return Connected;
}
if dur.num_hours () < 1 {
return Description (format! ("{} m ago", dur.num_minutes ()));
}
if dur.num_days () < 1 {
return Description (format! ("{} h ago", dur.num_hours ()));
}
Description (last_seen.to_rfc3339_opts (SecondsFormat::Secs, true))
}
/// One row of the server list page.
#[derive (Serialize)]
struct ServerEntry <'a> {
	/// Server ID, the key used in `/frontend/servers/<id>/...` routes.
	id: String,
	/// Name to display; `handle_server_list_internal` falls back to the ID.
	display_name: String,
	/// Human-readable last-seen text ("Never", "Connected", "5 m ago", ...).
	last_seen: Cow <'a, str>,
}

/// Template context for the `relay_server_list` page.
#[derive (Serialize)]
struct ServerListPage <'a> {
	/// Whether the current config has dev mode enabled.
	dev_mode: bool,
	/// Result of `git_version::read_git_version`; `None` if unavailable.
	git_version: Option <String>,
	/// Rows in display order.
	servers: Vec <ServerEntry <'a>>,
}
async fn handle_server_list_internal (state: &Arc <RelayState>)
-> ServerListPage <'static>
{
let dev_mode;
let display_names: HashMap <String, String> = {
let guard = state.config.read ().await;
dev_mode = guard.iso.dev_mode.is_some ();
let servers = (*guard).servers.iter ()
.map (|(k, v)| {
let display_name = v.display_name
.clone ()
.unwrap_or_else (|| k.clone ());
(k.clone (), display_name)
});
HashMap::from_iter (servers)
};
let server_statuses = {
let guard = state.server_status.lock ().await;
(*guard).clone ()
};
let now = Utc::now ();
let mut servers: Vec <_> = display_names.into_iter ()
.map (|(id, display_name)| {
use LastSeen::*;
let status = match server_statuses.get (&id) {
None => return ServerEntry {
display_name,
id,
last_seen: "Never".into (),
},
Some (x) => x,
};
let last_seen = match pretty_print_last_seen (now, status.last_seen) {
Negative => "Error (negative time)".into (),
Connected => "Connected".into (),
Description (s) => s.into (),
};
ServerEntry {
display_name,
id,
last_seen,
}
})
.collect ();
servers.sort_by (|a, b| a.display_name.cmp (&b.display_name));
ServerListPage {
dev_mode,
git_version: git_version::read_git_version ().await,
servers,
}
}
/// Renders the `relay_server_list` template as a `200 OK` HTML page.
async fn handle_server_list (
	state: Arc <RelayState>
) -> Result <Response <Body>, RequestError>
{
	let page = handle_server_list_internal (&state).await;
	let html = state.handlebars.render ("relay_server_list", &page)?;
	let response = ok_reply (html)?;
	Ok (response)
}
#[instrument (level = "trace", skip (req, state))]
async fn handle_scraper_api_v1 (
req: Request <Body>,
state: Arc <RelayState>,
path_rest: &str
)
-> Result <Response <Body>, RequestError>
{
use key_validity::KeyValidity;
let api_key = req.headers ().get ("X-ApiKey");
let api_key = match api_key {
None => return Ok (error_reply (StatusCode::FORBIDDEN, "Can't run scraper without an API key")?),
Some (x) => x,
};
let bad_key = || error_reply (StatusCode::FORBIDDEN, "403 Forbidden");
{
let config = state.config.read ().await;
let dev_mode = match &config.iso.dev_mode {
None => return Ok (bad_key ()?),
Some (x) => x,
};
let expected_key = match &dev_mode.scraper_key {
None => return Ok (bad_key ()?),
Some (x) => x,
};
let now = chrono::Utc::now ();
match expected_key.is_valid (now, api_key.as_bytes ()) {
KeyValidity::Valid => (),
KeyValidity::WrongKey (bad_hash) => {
error! ("Bad scraper key with hash {:?}", bad_hash);
return Ok (bad_key ()?);
}
err => {
error! ("Bad scraper key {:?}", err);
return Ok (bad_key ()?);
},
}
}
if path_rest == "test" {
Ok (error_reply (StatusCode::OK, "You're valid!")?)
}
else {
Ok (error_reply (StatusCode::NOT_FOUND, "Unknown API endpoint")?)
}
}
/// Entry point for `/scraper/...`.
///
/// Everything gets 403 unless the config sets `enable_scraper_auth`.
/// Otherwise the request is dispatched by version prefix; `v1/` and `api/`
/// both go to the v1 handler, and anything else gets 404.
#[instrument (level = "trace", skip (req, state))]
async fn handle_scraper_api (
	req: Request <Body>,
	state: Arc <RelayState>,
	path_rest: &str
)
-> Result <Response <Body>, RequestError>
{
	let enabled = state.config.read ().await.iso.enable_scraper_auth;
	if ! enabled {
		return Ok (error_reply (StatusCode::FORBIDDEN, "Scraper API disabled")?);
	}
	
	let v1_rest = prefix_match ("v1/", path_rest)
	.or_else (|| prefix_match ("api/", path_rest));
	
	match v1_rest {
		Some (rest) => handle_scraper_api_v1 (req, state, rest).await,
		None => Ok (error_reply (StatusCode::NOT_FOUND, "Unknown scraper API version")?),
	}
}
/// Top-level router for every HTTP request the relay receives.
///
/// - `POST /7ZSFUKGV/http_response/<request ID>`: a server delivering a
///   response. Any other POST gets 400.
/// - `/7ZSFUKGV/http_listen/<server ID>`: a server parking to wait for
///   requests. Requires an `X-ApiKey` header.
/// - `/frontend/servers/`: the server list page.
/// - `/frontend/servers/<server ID>/<path>`: a client request forwarded to
///   that server. The client's request body is not forwarded.
/// - `/`: the relay root page.
/// - `/frontend/relay_up_check`: a plain-text liveness check.
/// - `/frontend/test_mysterious_error`: deliberately returns an error.
/// - `/scraper/...`: the scraper API.
///
/// Any other path gets a plain 200 "Hi".
#[instrument (level = "trace", skip (req, state))]
async fn handle_all (req: Request <Body>, state: Arc <RelayState>)
-> Result <Response <Body>, RequestError>
{
	let path = req.uri ().path ().to_string ();
	debug! ("Request path: {}", path);
	
	if req.method () == Method::POST {
		// This is stuff the server can use. Clients can't
		// POST right now
		return if let Some (request_code) = prefix_match ("/7ZSFUKGV/http_response/", &path) {
			let request_code = request_code.into ();
			Ok (server_endpoint::handle_response (req, state, request_code).await?)
		}
		else {
			Ok (error_reply (StatusCode::BAD_REQUEST, "Can't POST this")?)
		};
	}
	
	if let Some (listen_code) = prefix_match ("/7ZSFUKGV/http_listen/", &path) {
		let api_key = match req.headers ().get ("X-ApiKey") {
			None => return Ok (error_reply (StatusCode::FORBIDDEN, "Can't run server without an API key")?),
			Some (x) => x,
		};
		server_endpoint::handle_listen (state, listen_code.into (), api_key.as_bytes ()).await
	}
	else if let Some (rest) = prefix_match ("/frontend/servers/", &path) {
		if rest.is_empty () {
			Ok (handle_server_list (state).await?)
		}
		else if let Some (idx) = rest.find ('/') {
			// Split "<server ID>/<path>"; the forwarded path keeps its leading '/'
			let (listen_code, server_path) = rest.split_at (idx);
			let listen_code = listen_code.to_string ();
			let server_path = server_path.to_string ();
			let (parts, _) = req.into_parts ();
			Ok (handle_http_request (parts, server_path, state, listen_code).await?)
		}
		else {
			Ok (error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?)
		}
	}
	else if path == "/" {
		let s = state.handlebars.render ("relay_root", &())?;
		Ok (ok_reply (s)?)
	}
	else if path == "/frontend/relay_up_check" {
		Ok (error_reply (StatusCode::OK, "Relay is up")?)
	}
	else if path == "/frontend/test_mysterious_error" {
		Err (RequestError::Mysterious)
	}
	else if let Some (rest) = prefix_match ("/scraper/", &path) {
		handle_scraper_api (req, state, rest).await
	}
	else {
		Ok (error_reply (StatusCode::OK, "Hi")?)
	}
}
/// Loads the relay's page templates from `<asset_root>/handlebars/relay`
/// and registers them as `relay_server_list` and `relay_root`.
///
/// Strict mode is enabled, so rendering with a missing field is an error
/// rather than silently producing empty output.
pub fn load_templates (asset_root: &Path)
-> Result <Handlebars <'static>, RelayError>
{
	let template_dir = asset_root.join ("handlebars/relay");
	let templates = [
		("relay_server_list", "relay_server_list.html"),
		("relay_root", "relay_root.html"),
	];
	
	let mut handlebars = Handlebars::new ();
	handlebars.set_strict_mode (true);
	
	for &(name, file_name) in templates.iter () {
		handlebars.register_template_file (name, &template_dir.join (file_name))?;
	}
	
	Ok (handlebars)
}
/// Re-reads the config file at `config_reload_path` and swaps it into
/// `state`. If reading or parsing fails, the current config is left as-is.
async fn reload_config (
	state: &Arc <RelayState>,
	config_reload_path: &Path
) -> Result <(), ConfigError> {
	// Read and parse before taking the write lock, so readers aren't
	// blocked on file I/O
	let new_config = Config::from_file (config_reload_path).await?;
	
	let mut config = state.config.write ().await;
	*config = new_config;
	
	debug! ("Loaded {} server configs", config.servers.len ());
	debug! ("enable_scraper_auth: {}", config.iso.enable_scraper_auth);
	
	if config.iso.dev_mode.is_some () {
		error! ("Dev mode is enabled! This might turn off some security features. If you see this in production, escalate it to someone!");
	}
	
	Ok (())
}
pub async fn run_relay (
state: Arc <RelayState>,
shutdown_oneshot: oneshot::Receiver <()>,
config_reload_path: Option <PathBuf>
)
-> Result <(), RelayError>
{
if let Some (config_reload_path) = config_reload_path {
let state_2 = state.clone ();
tokio::spawn (async move {
let mut reload_interval = tokio::time::interval (Duration::from_secs (60));
loop {
reload_interval.tick ().await;
reload_config (&state_2, &config_reload_path).await.ok ();
}
});
}
let make_svc = make_service_fn (|_conn| {
let state = state.clone ();
async {
Ok::<_, RequestError> (service_fn (move |req| {
let state = state.clone ();
handle_all (req, state)
}))
}
});
let addr = SocketAddr::from ((
[0, 0, 0, 0],
state.config.read ().await.port.unwrap_or (4000),
));
let server = Server::bind (&addr)
.serve (make_svc);
server.with_graceful_shutdown (async {
use ShuttingDownError::ShuttingDown;
shutdown_oneshot.await.ok ();
state.shutdown_watch_tx.broadcast (true).expect ("Can't broadcast graceful shutdown");
let mut response_rendezvous = state.response_rendezvous.write ().await;
let mut swapped = DashMap::default ();
std::mem::swap (&mut swapped, &mut response_rendezvous);
for (_, sender) in swapped {
sender.send (Err (ShuttingDown)).ok ();
}
let mut request_rendezvous = state.request_rendezvous.lock ().await;
for (_, x) in request_rendezvous.drain () {
use RequestRendezvous::*;
match x {
ParkedClients (_) => (),
ParkedServer (sender) => drop (sender.send (Err (ShuttingDown))),
}
}
debug! ("Performed all cleanup");
}).await?;
Ok (())
}
#[cfg (test)]
mod tests;