ptth/crates/ptth_relay/src/lib.rs

874 lines
20 KiB
Rust

//! # PTTH Relay
//!
//! The PTTH relay accepts incoming connections from PTTH servers, and
//! acts as a reverse proxy, forwarding incoming requests from HTTP clients
//! to PTTH servers.
#![warn (clippy::pedantic)]
// I don't see the point of writing the type twice if I'm initializing a struct
// and the type is already in the struct definition.
#![allow (clippy::default_trait_access)]
// I'm not sure if I like this one
#![allow (clippy::enum_glob_use)]
// I don't see the point in documenting the errors outside of where the
// error type is defined.
#![allow (clippy::missing_errors_doc)]
// False positive on futures::select! macro
#![allow (clippy::mut_mut)]
use std::{
borrow::Cow,
convert::Infallible,
net::SocketAddr,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use anyhow::bail;
use chrono::{
DateTime,
SecondsFormat,
Utc
};
use dashmap::DashMap;
use futures_util::StreamExt;
use handlebars::Handlebars;
use hyper::{
Body,
Request,
Response,
Server,
StatusCode,
};
use hyper::service::{make_service_fn, service_fn};
use serde::{
Serialize,
};
use tokio::{
sync::{
oneshot,
},
};
use tokio_stream::wrappers::ReceiverStream;
use ptth_core::{
http_serde,
prelude::*,
};
pub mod config;
pub mod errors;
pub mod key_validity;
mod git_version;
mod relay_state;
mod routing;
mod scraper_api;
mod server_endpoint;
pub use config::{
Config,
machine_editable,
};
pub use errors::*;
pub use relay_state::Relay;
use relay_state::{
AuditData,
AuditEvent,
};
use relay_state::{
RejectedServer,
RequestRendezvous,
};
/// Builds a `200 OK` response whose body is `b`.
///
/// No headers are set; callers that need a content type must build the
/// response themselves.
fn ok_reply <B: Into <Body>> (b: B)
-> Result <Response <Body>, http::Error>
{
	let body: Body = b.into ();
	
	Response::builder ()
	.status (StatusCode::OK)
	.body (body)
}
fn error_reply (status: StatusCode, b: &str)
-> Result <Response <Body>, http::Error>
{
Response::builder ()
.status (status)
.header ("content-type", "text/plain")
.body (format! ("{}\n", b).into ())
}
/// Reads the `X-Email` request header as the user's name.
///
/// Returns `None` if the header is missing or `to_str` rejects its value.
fn get_user_name (req: &http::request::Parts)
-> Option <String>
{
	let header = req.headers.get ("X-Email")?;
	let email = header.to_str ().ok ()?;
	
	Some (email.to_string ())
}
/// Clients will come here to start requests, and always park for at least
/// a short amount of time.
///
/// The request is converted to its serializable form, a oneshot sender is
/// stored in `state.response_rendezvous` under a fresh request ID, and the
/// request is parked for `server_name`. Then we wait up to 30 seconds for
/// the response to arrive on the oneshot channel.
///
/// Errors:
///
/// - `UnknownServer` if `state.server_exists` says the server is unknown
/// - `BadRequest` if the request can't be converted
/// - `ServerNeverResponded` if nothing arrives within 30 seconds
/// - `ServerTimedOut` if the sender half was dropped without sending
/// - `RelayShuttingDown` if the relay is shutting down
async fn handle_http_request (
	req: http::request::Parts,
	uri: String,
	state: &Relay,
	server_name: &str
)
-> Result <Response <Body>, RequestError>
{
	use RequestError::*;
	
	// Exactly one ID is generated per request so that every log line
	// and the rendezvous key refer to the same request.
	let req_id = rusty_ulid::generate_ulid_string ();
	debug! ("Created request {}", req_id);
	
	let req_method = req.method.clone ();
	
	if ! state.server_exists (server_name).await {
		return Err (UnknownServer);
	}
	
	let req = http_serde::RequestParts::from_hyper (req.method, uri, req.headers)
	.map_err (|_| BadRequest)?;
	
	let (tx, rx) = oneshot::channel ();
	
	debug! ("Forwarding {}", req_id);
	
	{
		let response_rendezvous = state.response_rendezvous.read ().await;
		response_rendezvous.insert (req_id.clone (), tx);
	}
	
	state.park_client (server_name, req, &req_id).await;
	
	// UKAUFFY4 (Receive half)
	let received = match tokio::time::timeout (Duration::from_secs (30), rx).await
	{
		Err (_) => {
			debug! ("Timed out request {}", req_id);
			
			// Nobody will ever receive on this channel again, so don't
			// leave the sender behind in the map.
			{
				let response_rendezvous = state.response_rendezvous.read ().await;
				response_rendezvous.remove (&req_id);
			}
			
			return Err (ServerNeverResponded);
		}
		Ok (x) => x,
	};
	
	let received = match received {
		Err (_) => {
			debug! ("Responder sender dropped for request {}", req_id);
			return Err (ServerTimedOut);
		},
		Ok (x) => x,
	};
	
	let (parts, body) = match received {
		Err (ShuttingDownError::ShuttingDown) => {
			return Err (RelayShuttingDown);
		},
		Ok (x) => x,
	};
	
	let mut resp = Response::builder ()
	.status (hyper::StatusCode::from (parts.status_code));
	
	if
		req_method == hyper::Method::GET &&
		parts.headers.get ("accept-ranges").is_some ()
	{
		trace! ("Stream restart code could go here");
	}
	
	// Copy the server's headers onto the relayed response
	for (k, v) in parts.headers {
		resp = resp.header (&k, v);
	}
	
	debug! ("Unparked request {}", req_id);
	
	Ok (resp.body (body)?)
}
/// Result of `pretty_print_last_seen`: how long ago something was seen.
#[derive (Debug, PartialEq)]
enum LastSeen {
/// `last_seen` is later than `now`
Negative,
/// Seen less than one minute ago
Connected,
/// Seen a minute or more ago. Holds a relative time ("N m ago",
/// "N h ago") or, after a day or more, an RFC 3339 timestamp.
Description (String),
}
/// Formats `last_seen` relative to `now` for display.
///
/// - Under a minute: "N s ago"
/// - Under an hour: "N m ago"
/// - Under a day: "N h ago", the same tier `pretty_print_last_seen` uses
/// - Otherwise, or if `last_seen` is later than `now`: an RFC 3339
///   timestamp with second precision
fn pretty_print_utc (
	now: DateTime <Utc>,
	last_seen: DateTime <Utc>
) -> String
{
	let dur = now.signed_duration_since (last_seen);
	
	// A timestamp from the future has no sensible relative form
	if dur < chrono::Duration::zero () {
		return last_seen.to_rfc3339_opts (SecondsFormat::Secs, true);
	}
	
	if dur.num_minutes () < 1 {
		return format! ("{} s ago", dur.num_seconds ());
	}
	
	if dur.num_hours () < 1 {
		return format! ("{} m ago", dur.num_minutes ());
	}
	
	if dur.num_days () < 1 {
		return format! ("{} h ago", dur.num_hours ());
	}
	
	last_seen.to_rfc3339_opts (SecondsFormat::Secs, true)
}
// Mnemonic is "now - last_seen"
fn pretty_print_last_seen (
now: DateTime <Utc>,
last_seen: DateTime <Utc>
) -> LastSeen
{
use LastSeen::*;
let dur = now.signed_duration_since (last_seen);
if dur < chrono::Duration::zero () {
return Negative;
}
if dur.num_minutes () < 1 {
return Connected;
}
if dur.num_hours () < 1 {
return Description (format! ("{} m ago", dur.num_minutes ()));
}
if dur.num_days () < 1 {
return Description (format! ("{} h ago", dur.num_hours ()));
}
Description (last_seen.to_rfc3339_opts (SecondsFormat::Secs, true))
}
/// One server in the list built by `handle_server_list_internal`.
#[derive (Serialize)]
struct ServerEntry <'a> {
// Copied from the scraper API's server list entry
name: String,
// Copied from the scraper API's server list entry
display_name: String,
// "Never", "Connected", "Error (negative time)", or a description
// of how long ago the server was seen
last_seen: Cow <'a, str>,
}
/// Data passed to the "server_list" Handlebars template.
#[derive (Serialize)]
struct ServerListPage <'a> {
// True when the config's `iso.dev_mode` is set
dev_mode: bool,
// Result of `git_version::read`
git_version: Option <String>,
servers: Vec <ServerEntry <'a>>,
// Copied from the config
news_url: Option <String>,
// How many servers were classified as `LastSeen::Connected`
connected_server_count: usize,
// Total number of servers in the list
registered_server_count: usize,
// Time the page was generated, RFC 3339 with second precision
date_rfc3339: String,
}
/// Data passed to the "unregistered_servers" Handlebars template.
#[derive (Serialize)]
struct UnregisteredServerListPage {
unregistered_servers: Vec <UnregisteredServer>,
}
/// One entry built by `handle_unregistered_servers_internal` from
/// `state.unregistered_servers`.
#[derive (Serialize)]
struct UnregisteredServer {
name: String,
// Base64 encoding of the tripcode bytes
tripcode: String,
// "Recently", "Error (negative time)", or a description of how
// long ago the server was seen
last_seen: String,
}
/// Data passed to the "audit_log" Handlebars template.
#[derive (Serialize)]
struct AuditLogPage {
// Filled by iterating the audit log in reverse
audit_log: Vec <AuditEntryPretty>,
}
/// One audit log event, pre-formatted as strings for display.
#[derive (Serialize)]
struct AuditEntryPretty {
// The event's time, formatted by `pretty_print_utc`
utc_pretty: String,
// `Debug` formatting of the event's data
data_pretty: String,
}
async fn handle_server_list_internal (state: &Relay)
-> ServerListPage <'static>
{
use LastSeen::*;
let dev_mode;
let news_url;
{
let guard = state.config.read ().await;
dev_mode = guard.iso.dev_mode.is_some ();
news_url = guard.news_url.clone ();
}
let git_version = git_version::read ().await;
let server_list = scraper_api::v1_server_list (&state).await;
let now = Utc::now ();
let registered_server_count = server_list.servers.len ();
let mut connected_server_count = 0;
let servers = server_list.servers.into_iter ()
.map (|x| {
let last_seen = match x.last_seen {
None => "Never".into (),
Some (x) => match pretty_print_last_seen (now, x) {
Negative => "Error (negative time)".into (),
Connected => {
connected_server_count += 1;
"Connected".into ()
},
Description (s) => s.into (),
},
};
ServerEntry {
name: x.name,
display_name: x.display_name,
last_seen,
}
})
.collect ();
let date_rfc3339 = now.to_rfc3339_opts (SecondsFormat::Secs, true);
ServerListPage {
dev_mode,
git_version,
servers,
news_url,
connected_server_count,
registered_server_count,
date_rfc3339,
}
}
async fn handle_unregistered_servers_internal (state: &Relay)
-> UnregisteredServerListPage
{
use LastSeen::*;
let now = Utc::now ();
let mut server_list = state.unregistered_servers.to_vec ().await;
{
let me_config = state.me_config.read ().await;
server_list = server_list.into_iter ()
.filter (|s| ! me_config.servers.contains_key (&s.name))
.collect ();
}
server_list.sort_by_key (|s| {
(s.name.clone (), *s.tripcode.as_bytes (), now - s.seen)
});
server_list.dedup_by_key (|s| {
(s.name.clone (), s.tripcode)
});
let unregistered_servers = server_list.into_iter ()
.map (|x| {
let last_seen = match pretty_print_last_seen (now, x.seen) {
Negative => "Error (negative time)".into (),
Connected => "Recently".into (),
Description (s) => s,
};
let tripcode = base64::encode (x.tripcode.as_bytes ());
UnregisteredServer {
name: x.name,
tripcode,
last_seen,
}
}).collect ();
UnregisteredServerListPage {
unregistered_servers,
}
}
/// Gathers the data for the audit log page.
///
/// Events are emitted in the reverse of the order `to_vec` returns them.
async fn handle_audit_log_internal (state: &Relay)
-> AuditLogPage
{
	let utc_now = Utc::now ();
	
	let events = state.audit_log.to_vec ().await;
	
	let mut audit_log = Vec::new ();
	for event in events.iter ().rev () {
		audit_log.push (AuditEntryPretty {
			utc_pretty: pretty_print_utc (utc_now, event.time_utc),
			data_pretty: format! ("{:?}", event.data),
		});
	}
	
	AuditLogPage {
		audit_log,
	}
}
/// Renders the "server_list" template and returns it as a `200 OK`.
async fn handle_server_list (
	state: &Relay,
	handlebars: Arc <Handlebars <'static>>
) -> Result <Response <Body>, RequestError>
{
	let page = handle_server_list_internal (state).await;
	let html = handlebars.render ("server_list", &page)?;
	let response = ok_reply (html)?;
	
	Ok (response)
}
/// Renders the "unregistered_servers" template and returns it as a
/// `200 OK`.
async fn handle_unregistered_servers (
	state: &Relay,
	handlebars: Arc <Handlebars <'static>>
) -> Result <Response <Body>, RequestError>
{
	let page = handle_unregistered_servers_internal (state).await;
	let html = handlebars.render ("unregistered_servers", &page)?;
	let response = ok_reply (html)?;
	
	Ok (response)
}
/// Renders the "audit_log" template and returns it as a `200 OK`.
///
/// Replies `403 Forbidden` instead if the config sets `hide_audit_log`.
async fn handle_audit_log (
	state: &Relay,
	handlebars: Arc <Handlebars <'static>>
) -> Result <Response <Body>, RequestError>
{
	let hidden = state.config.read ().await.hide_audit_log;
	if hidden {
		let forbidden = error_reply (StatusCode::FORBIDDEN, "Forbidden")?;
		return Ok (forbidden);
	}
	
	let page = handle_audit_log_internal (state).await;
	let html = handlebars.render ("audit_log", &page)?;
	let response = ok_reply (html)?;
	
	Ok (response)
}
async fn handle_endless_sink (req: Request <Body>) -> Result <Response <Body>, http::Error>
{
let (_parts, mut body) = req.into_parts ();
let mut bytes_received = 0;
loop {
let item = body.next ().await;
if let Some (item) = item {
if let Ok (bytes) = &item {
bytes_received += bytes.len ();
}
}
else {
debug! ("Finished sinking debug bytes");
break;
}
}
Ok (ok_reply (format! ("Sank {} bytes\n", bytes_received))?)
}
/// Debug endpoint: streams `gib` GiB of data to the client.
///
/// One random 64 KiB block is generated and sent repeatedly. A spawned
/// task feeds the blocks into the response body through a channel.
///
/// `throttle` is the number of blocks to send per 1-second tick. `None`
/// means no throttling. `Some (0)` is treated as 1 block per tick,
/// because sending zero blocks per tick would never finish and never
/// notice that the client went away.
async fn handle_endless_source (gib: usize, throttle: Option <usize>)
-> Result <Response <Body>, http::Error>
{
	use tokio::sync::mpsc;
	
	let block_bytes = 64 * 1024;
	let num_blocks = (1024 * 1024 * 1024 / block_bytes) * gib;
	
	// Always make progress on every pass of the send loop
	let blocks_per_tick = throttle.unwrap_or (1).max (1);
	
	let (tx, rx) = mpsc::channel (1);
	
	tokio::spawn (async move {
		// The RNG is confined to this block so it's gone before the
		// first await
		let random_block = {
			use rand::RngCore;
			
			let mut rng = rand::thread_rng ();
			let mut block = vec! [0_u8; block_bytes];
			rng.fill_bytes (&mut block);
			block
		};
		
		let mut interval = tokio::time::interval (Duration::from_millis (1000));
		interval.set_missed_tick_behavior (tokio::time::MissedTickBehavior::Skip);
		
		let mut blocks_sent = 0;
		
		while blocks_sent < num_blocks {
			if throttle.is_some () {
				interval.tick ().await;
			}
			
			// Don't let the last burst send more than was asked for
			let burst = blocks_per_tick.min (num_blocks - blocks_sent);
			
			for _ in 0..burst {
				let item = Ok::<_, Infallible> (random_block.clone ());
				if tx.send (item).await.is_err () {
					debug! ("Endless source dropped");
					return;
				}
				
				blocks_sent += 1;
			}
		}
		
		debug! ("Endless source ended");
	});
	
	Response::builder ()
	.status (StatusCode::OK)
	.header ("content-type", "application/octet-stream")
	.body (Body::wrap_stream (ReceiverStream::new (rx)))
}
/// Debug endpoint: replies with a key from `ptth_core::gen_key` as
/// plain text.
///
/// `_state` is unused.
async fn handle_gen_scraper_key (_state: &Relay)
-> Result <Response <Body>, http::Error>
{
	let text = format! ("Random key: {}\n", ptth_core::gen_key ());
	
	Response::builder ()
	.status (StatusCode::OK)
	.header ("content-type", "text/plain")
	.body (Body::from (text))
}
/// Registers a server from a URL-encoded form body.
///
/// The body is limited to 1,024 bytes. The registration is recorded in
/// the audit log, then the server is inserted into the machine-editable
/// config and the config is saved to `data/ptth_relay_me_config.toml`.
///
/// Fails if the body is too big, can't be parsed, or the config can't be
/// saved.
async fn handle_register_server (req: Request <Body>, state: &Relay)
-> Result <(), anyhow::Error>
{
	let (parts, body) = req.into_parts ();
	
	let user = get_user_name (&parts);
	
	let form_data = read_body_limited (body, 1_024).await?;
	let new_server: crate::config::file::Server = serde_urlencoded::from_bytes (&form_data)?;
	
	let event = AuditEvent::new (AuditData::RegisterServer {
		user,
		server: new_server.clone (),
	});
	state.audit_log.push (event).await;
	
	// The write lock is held until the config has been saved
	let mut me_config = state.me_config.write ().await;
	
	let key = new_server.name.clone ();
	me_config.servers.insert (key, new_server);
	me_config.save (Path::new ("data/ptth_relay_me_config.toml")).await?;
	
	Ok (())
}
/// Reads `body` into memory, refusing to buffer more than `limit` bytes.
///
/// Fails if reading a chunk fails, or as soon as the next chunk would
/// push the total past `limit`.
async fn read_body_limited (mut body: Body, limit: usize) -> anyhow::Result <Vec <u8>>
{
	let mut buffer = Vec::new ();
	
	loop {
		let chunk = match body.next ().await {
			None => break,
			Some (chunk) => chunk?,
		};
		
		// Check before copying so the buffer never exceeds the limit
		let new_len = buffer.len () + chunk.len ();
		if new_len > limit {
			bail! ("Body was bigger than limit");
		}
		
		buffer.extend_from_slice (&chunk);
	}
	
	Ok (buffer)
}
#[instrument (level = "trace", skip (req, state, handlebars))]
async fn handle_all (
req: Request <Body>,
state: Arc <Relay>,
handlebars: Arc <Handlebars <'static>>
)
-> Result <Response <Body>, RequestError>
{
use routing::Route::*;
let state = &*state;
// The path is cloned here, so it's okay to consume the request
// later.
let path = req.uri ().path ().to_string ();
trace! ("Request path: {}", path);
let route = routing::route_url (req.method (), &path);
let route = match route {
Ok (x) => x,
Err (e) => {
use routing::Error;
let response = match e {
Error::BadUriFormat => error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?,
Error::MethodNotAllowed => error_reply (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed. Are you POST-ing to a GET-only url, or vice versa?")?,
Error::NotFound => error_reply (StatusCode::OK, "URL routing failed")?,
};
return Ok (response);
},
};
let response = match route {
ClientAuditLog => handle_audit_log (state, handlebars).await?,
ClientRelayIsUp => error_reply (StatusCode::OK, "Relay is up")?,
ClientServerGet {
listen_code,
path,
} => {
let (parts, _) = req.into_parts ();
let user = get_user_name (&parts);
state.audit_log.push (AuditEvent::new (AuditData::WebClientGet {
user,
server_name: listen_code.to_string (),
uri: path.to_string (),
})).await;
handle_http_request (parts, path.to_string (), &state, listen_code).await?
},
ClientServerList => handle_server_list (state, handlebars).await?,
ClientUnregisteredServers => handle_unregistered_servers (state, handlebars).await?,
Debug => {
let s = handlebars.render ("debug", &())?;
ok_reply (s)?
},
DebugEndlessSink => handle_endless_sink (req).await?,
DebugEndlessSource (throttle) => handle_endless_source (1, throttle).await?,
DebugGenKey => handle_gen_scraper_key (state).await?,
DebugMysteriousError => return Err (RequestError::Mysterious),
RegisterServer => {
match handle_register_server (req, state).await {
Ok (_) => Response::builder ()
.status (StatusCode::SEE_OTHER)
.header ("location", "unregistered_servers")
.body (Body::from ("Success. Redirecting..."))?,
Err (e) => error_reply (StatusCode::BAD_REQUEST, &format! ("{:?}", e))?,
}
}
Root => {
let s = handlebars.render ("root", &())?;
ok_reply (s)?
},
Scraper {
rest,
} => scraper_api::handle (req, state, rest).await?,
ServerHttpListen {
listen_code,
} => {
let api_key = req.headers ().get ("X-ApiKey");
let api_key = match api_key {
None => return Ok (error_reply (StatusCode::FORBIDDEN, "Forbidden")?),
Some (x) => x,
};
match check_server_api_key (state, listen_code, api_key.as_bytes ()).await {
Ok (_) => (),
Err (_) => return Ok (error_reply (StatusCode::FORBIDDEN, "Forbidden")?)
}
server_endpoint::handle_listen (state, listen_code.into ()).await?
},
ServerHttpResponse {
request_code,
} => {
let request_code = request_code.into ();
server_endpoint::handle_response (req, state, request_code).await?
},
};
Ok (response)
}
/// Checks that `api_key` hashes to the tripcode registered for `name`.
///
/// The name is looked up in both the human-editable and the
/// machine-editable config, and a match against either one is enough.
///
/// If the name is in neither config, the attempt is pushed onto
/// `state.unregistered_servers` before the error is returned.
async fn check_server_api_key (state: &Relay, name: &str, api_key: &[u8])
-> Result <(), anyhow::Error>
{
	let actual_tripcode = key_validity::BlakeHashWrapper::from_key (api_key);
	
	let expected_human = {
		let config = state.config.read ().await;
		
		config.servers.get (name).map (|s| s.tripcode)
	};
	let expected_machine = {
		let me_config = state.me_config.read ().await;
		
		me_config.servers.get (name).map (|s| s.tripcode)
	};
	
	match (expected_human, expected_machine) {
		(None, None) => {
			let rejected = crate::RejectedServer {
				name: name.to_string (),
				tripcode: *actual_tripcode,
				seen: Utc::now (),
			};
			state.unregistered_servers.push (rejected).await;
			
			bail! ("Denied API request for non-existent server name {}", name);
		},
		(human, machine) => {
			let actual = Some (actual_tripcode);
			
			if actual == human || actual == machine {
				return Ok (());
			}
		},
	}
	
	bail! ("Denied API request for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ()));
}
/// Loads the relay's Handlebars templates from
/// `asset_root/handlebars/relay`.
///
/// Strict mode is turned on. Fails if any template file can't be
/// registered.
fn load_templates (asset_root: &Path)
-> Result <Handlebars <'static>, RelayError>
{
	// (template name, file name)
	const TEMPLATES: [(&str, &str); 5] = [
		("audit_log", "audit_log.hbs"),
		("debug", "debug.hbs"),
		("root", "root.hbs"),
		("server_list", "server_list.hbs"),
		("unregistered_servers", "unregistered_servers.hbs"),
	];
	
	let template_dir = asset_root.join ("handlebars/relay");
	
	let mut handlebars = Handlebars::new ();
	handlebars.set_strict_mode (true);
	
	for (name, file_name) in TEMPLATES.iter () {
		handlebars.register_template_file (name, &template_dir.join (file_name))?;
	}
	
	Ok (handlebars)
}
/// Reloads the human-editable config from `config_reload_path` and
/// swaps it into `state.config`.
///
/// The new config is loaded before the write lock is taken. If loading
/// fails, the old config stays in place.
async fn reload_config (
	state: &Arc <Relay>,
	config_reload_path: &Path
) -> Result <(), ConfigError> {
	// Reload human-editable config
	let new_config = Config::from_file (config_reload_path).await?;
	
	// Reload machine-editable config, if possible
	// let me_config = machine_editable::Config::from_file (Path::new ("data/ptth_relay_me_config.toml")).await.ok ();
	
	let mut config = state.config.write ().await;
	
	trace! ("Reloading config");
	
	// Only log the things that differ from the old config
	let server_count_changed = config.servers.len () != new_config.servers.len ();
	let scraper_api_changed = config.iso.enable_scraper_api != new_config.iso.enable_scraper_api;
	
	if server_count_changed {
		debug! ("Loaded {} server configs", new_config.servers.len ());
	}
	if scraper_api_changed {
		debug! ("enable_scraper_api: {}", new_config.iso.enable_scraper_api);
	}
	
	*config = new_config;
	
	let dev_mode_enabled = config.iso.dev_mode.is_some ();
	if dev_mode_enabled {
		error! ("Dev mode is enabled! This might turn off some security features. If you see this in production, escalate it to someone!");
	}
	
	Ok (())
}
pub async fn run_relay (
state: Arc <Relay>,
asset_root: &Path,
shutdown_oneshot: oneshot::Receiver <()>,
config_reload_path: Option <PathBuf>
)
-> Result <(), RelayError>
{
let handlebars = Arc::new (load_templates (asset_root)?);
if let Some (x) = git_version::read ().await {
info! ("ptth_relay Git version: {:?}", x);
}
else {
info! ("ptth_relay not built from Git");
}
if let Some (config_reload_path) = config_reload_path {
let state_2 = state.clone ();
tokio::spawn (async move {
let mut reload_interval = tokio::time::interval (Duration::from_secs (60));
reload_interval.set_missed_tick_behavior (tokio::time::MissedTickBehavior::Skip);
loop {
reload_interval.tick ().await;
reload_config (&state_2, &config_reload_path).await.ok ();
}
});
}
let make_svc = make_service_fn (|_conn| {
let state = state.clone ();
let handlebars = handlebars.clone ();
async {
Ok::<_, Infallible> (service_fn (move |req| {
let state = state.clone ();
let handlebars = handlebars.clone ();
async {
Ok::<_, Infallible> (handle_all (req, state, handlebars).await.unwrap_or_else (|e| {
use RequestError::*;
error! ("{}", e);
let status_code = match &e {
UnknownServer => StatusCode::NOT_FOUND,
BadRequest => StatusCode::BAD_REQUEST,
ServerNeverResponded | ServerTimedOut => StatusCode::GATEWAY_TIMEOUT,
_ => StatusCode::INTERNAL_SERVER_ERROR,
};
error_reply (status_code, "Error in relay").unwrap ()
}))
}
}))
}
});
let addr = {
let guard = state.config.read ().await;
SocketAddr::from ((
guard.address,
guard.port.unwrap_or (4000),
))
};
let server = Server::bind (&addr)
.serve (make_svc);
state.audit_log.push (AuditEvent::new (AuditData::RelayStart)).await;
trace! ("Serving relay on {:?}", addr);
server.with_graceful_shutdown (async {
use ShuttingDownError::ShuttingDown;
shutdown_oneshot.await.ok ();
state.shutdown_watch_tx.send (true).expect ("Can't broadcast graceful shutdown");
let mut response_rendezvous = state.response_rendezvous.write ().await;
let mut swapped = DashMap::default ();
std::mem::swap (&mut swapped, &mut response_rendezvous);
for (_, sender) in swapped {
sender.send (Err (ShuttingDown)).ok ();
}
let mut request_rendezvous = state.request_rendezvous.lock ().await;
for (_, x) in request_rendezvous.drain () {
use RequestRendezvous::*;
match x {
ParkedClients (_) => (),
ParkedServer (sender) => drop (sender.send (Err (ShuttingDown))),
}
}
debug! ("Performed all cleanup");
}).await?;
Ok (())
}
#[cfg (test)]
mod tests;