918 lines
21 KiB
Rust
918 lines
21 KiB
Rust
//! # PTTH Relay
|
|
//!
|
|
//! The PTTH relay accepts incoming connections from PTTH servers, and
|
|
//! acts as a reverse proxy, forwarding incoming requests from HTTP clients
|
|
//! to PTTH servers.
|
|
|
|
#![warn (clippy::pedantic)]
|
|
|
|
// I don't see the point of writing the type twice if I'm initializing a struct
|
|
// and the type is already in the struct definition.
|
|
#![allow (clippy::default_trait_access)]
|
|
|
|
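// Several functions glob-import enum variants (e.g. `use RequestError::*;`)
// to keep match arms short. Not yet sure this lint should stay disabled.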
|
|
#![allow (clippy::enum_glob_use)]
|
|
|
|
// I don't see the point in documenting the errors outside of where the
|
|
// error type is defined.
|
|
#![allow (clippy::missing_errors_doc)]
|
|
|
|
// False positive on futures::select! macro
|
|
#![allow (clippy::mut_mut)]
|
|
|
|
use std::{
|
|
borrow::Cow,
|
|
convert::Infallible,
|
|
net::SocketAddr,
|
|
path::{Path, PathBuf},
|
|
sync::Arc,
|
|
time::Duration,
|
|
};
|
|
|
|
use anyhow::bail;
|
|
use chrono::{
|
|
DateTime,
|
|
SecondsFormat,
|
|
Utc
|
|
};
|
|
use dashmap::DashMap;
|
|
use futures_util::StreamExt;
|
|
use handlebars::Handlebars;
|
|
use hyper::{
|
|
Body,
|
|
Request,
|
|
Response,
|
|
Server,
|
|
StatusCode,
|
|
};
|
|
use hyper::service::{make_service_fn, service_fn};
|
|
use serde::{
|
|
Serialize,
|
|
};
|
|
use tokio::{
|
|
sync::{
|
|
oneshot,
|
|
},
|
|
};
|
|
use tokio_stream::wrappers::ReceiverStream;
|
|
|
|
use ptth_core::{
|
|
http_serde,
|
|
prelude::*,
|
|
};
|
|
|
|
pub mod config;
|
|
pub mod errors;
|
|
pub mod key_validity;
|
|
|
|
mod git_version;
|
|
mod relay_state;
|
|
mod routing;
|
|
mod scraper_api;
|
|
mod server_endpoint;
|
|
|
|
pub use config::{
|
|
Config,
|
|
machine_editable,
|
|
};
|
|
pub use errors::*;
|
|
pub use relay_state::Relay;
|
|
use relay_state::{
|
|
AuditData,
|
|
AuditEvent,
|
|
};
|
|
|
|
use relay_state::{
|
|
RejectedServer,
|
|
RequestRendezvous,
|
|
};
|
|
|
|
fn ok_reply <B: Into <Body>> (b: B)
|
|
-> Result <Response <Body>, http::Error>
|
|
{
|
|
Response::builder ().status (StatusCode::OK).body (b.into ())
|
|
}
|
|
|
|
fn error_reply (status: StatusCode, b: &str)
|
|
-> Result <Response <Body>, http::Error>
|
|
{
|
|
Response::builder ()
|
|
.status (status)
|
|
.header ("content-type", "text/plain")
|
|
.body (format! ("{}\n", b).into ())
|
|
}
|
|
|
|
fn get_user_name (req: &http::request::Parts)
|
|
-> Option <String>
|
|
{
|
|
req.headers.get ("X-Email").and_then (|x| Some (x.to_str ().ok ()?.to_string ()))
|
|
}
|
|
|
|
/// Clients will come here to start requests, and always park for at least
|
|
/// a short amount of time.
|
|
|
|
async fn handle_http_request (
|
|
req: http::request::Parts,
|
|
uri: String,
|
|
state: &Relay,
|
|
server_name: &str
|
|
)
|
|
-> Result <Response <Body>, RequestError>
|
|
{
|
|
use RequestError::*;
|
|
|
|
let req_id = rusty_ulid::generate_ulid_string ();
|
|
debug! ("Created request {}", req_id);
|
|
|
|
let req_method = req.method.clone ();
|
|
|
|
if ! state.server_exists (server_name).await {
|
|
return Err (UnknownServer);
|
|
}
|
|
|
|
let req = http_serde::RequestParts::from_hyper (req.method, uri.clone (), req.headers)
|
|
.map_err (|_| BadRequest)?;
|
|
|
|
let (tx, rx) = oneshot::channel ();
|
|
|
|
let tx = relay_state::ResponseRendezvous {
|
|
timeout: Instant::now () + Duration::from_secs (120),
|
|
tx,
|
|
};
|
|
|
|
let req_id = rusty_ulid::generate_ulid_string ();
|
|
|
|
debug! ("Forwarding {}", req_id);
|
|
|
|
{
|
|
let response_rendezvous = state.response_rendezvous.read ().await;
|
|
response_rendezvous.insert (req_id.clone (), tx);
|
|
}
|
|
|
|
state.park_client (server_name, req, &req_id).await;
|
|
|
|
// UKAUFFY4 (Receive half)
|
|
let received = match tokio::time::timeout (Duration::from_secs (30), rx).await
|
|
{
|
|
Err (_) => {
|
|
debug! ("Timed out request {}", req_id);
|
|
return Err (ServerNeverResponded);
|
|
}
|
|
Ok (x) => x,
|
|
};
|
|
|
|
let received = match received {
|
|
Err (_) => {
|
|
debug! ("Responder sender dropped for request {}", req_id);
|
|
return Err (ServerTimedOut);
|
|
},
|
|
Ok (x) => x,
|
|
};
|
|
|
|
let (parts, body) = match received {
|
|
Err (ShuttingDownError::ShuttingDown) => {
|
|
return Err (RelayShuttingDown);
|
|
},
|
|
Ok (x) => x,
|
|
};
|
|
|
|
let mut resp = Response::builder ()
|
|
.status (hyper::StatusCode::from (parts.status_code));
|
|
|
|
if
|
|
req_method == hyper::Method::GET &&
|
|
parts.headers.get ("accept-ranges").is_some ()
|
|
{
|
|
trace! ("Stream restart code could go here");
|
|
}
|
|
|
|
for (k, v) in parts.headers {
|
|
resp = resp.header (&k, v);
|
|
}
|
|
|
|
debug! ("Unparked request {}", req_id);
|
|
|
|
Ok (resp.body (body)?)
|
|
}
|
|
|
|
/// Coarse summary of how long ago a server was last seen.
///
/// Produced by `pretty_print_last_seen`.
#[derive (PartialEq, Debug)]
enum LastSeen {
	/// The last-seen time is later than "now"
	Negative,
	/// Seen less than one minute ago
	Connected,
	/// Anything older, already formatted as display text
	Description (String),
}
|
|
|
|
fn pretty_print_utc (
|
|
now: DateTime <Utc>,
|
|
last_seen: DateTime <Utc>
|
|
) -> String
|
|
{
|
|
let dur = now.signed_duration_since (last_seen);
|
|
|
|
if dur < chrono::Duration::zero () {
|
|
return last_seen.to_rfc3339_opts (SecondsFormat::Secs, true);
|
|
}
|
|
|
|
if dur.num_minutes () < 1 {
|
|
return format! ("{} s ago", dur.num_seconds ());
|
|
}
|
|
|
|
if dur.num_hours () < 1 {
|
|
return format! ("{} m ago", dur.num_minutes ());
|
|
}
|
|
|
|
last_seen.to_rfc3339_opts (SecondsFormat::Secs, true)
|
|
}
|
|
|
|
// Mnemonic is "now - last_seen"
|
|
|
|
fn pretty_print_last_seen (
|
|
now: DateTime <Utc>,
|
|
last_seen: DateTime <Utc>
|
|
) -> LastSeen
|
|
{
|
|
use LastSeen::*;
|
|
|
|
let dur = now.signed_duration_since (last_seen);
|
|
|
|
if dur < chrono::Duration::zero () {
|
|
return Negative;
|
|
}
|
|
|
|
if dur.num_minutes () < 1 {
|
|
return Connected;
|
|
}
|
|
|
|
if dur.num_hours () < 1 {
|
|
return Description (format! ("{} m ago", dur.num_minutes ()));
|
|
}
|
|
|
|
if dur.num_days () < 1 {
|
|
return Description (format! ("{} h ago", dur.num_hours ()));
|
|
}
|
|
|
|
Description (last_seen.to_rfc3339_opts (SecondsFormat::Secs, true))
|
|
}
|
|
|
|
/// One row of the `server_list` page.
#[derive (Serialize)]
struct ServerEntry <'a> {
	name: String,
	display_name: String,
	// Display text: "Never", "Connected", "Error (negative time)", or a
	// relative/absolute time from `pretty_print_last_seen`
	last_seen: Cow <'a, str>,
}
|
|
|
|
/// Template data for the `server_list` page.
#[derive (Serialize)]
struct ServerListPage <'a> {
	// True when the config's `iso.dev_mode` is set
	dev_mode: bool,
	// Result of `git_version::read`, if any
	git_version: Option <String>,
	servers: Vec <ServerEntry <'a>>,
	// Copied from the config's `news_url`
	news_url: Option <String>,
	// Number of servers whose last-seen status is `Connected`
	connected_server_count: usize,
	// Total number of servers in the scraper API's server list
	registered_server_count: usize,
	// Page generation time, RFC 3339 with second precision
	date_rfc3339: String,
}
|
|
|
|
/// Template data for the `unregistered_servers` page.
#[derive (Serialize)]
struct UnregisteredServerListPage {
	unregistered_servers: Vec <UnregisteredServer>,
}
|
|
|
|
/// One row of the `unregistered_servers` page: a server that was rejected
/// because its name was not found in the config.
#[derive (Serialize)]
struct UnregisteredServer {
	// Name the rejected server presented
	name: String,
	// Base64 of the tripcode (BLAKE hash of the server's API key)
	tripcode: String,
	// Display text for when the server was last seen
	last_seen: String,
}
|
|
|
|
/// Template data for the `audit_log` page.
#[derive (Serialize)]
struct AuditLogPage {
	// Entries in display order, built by `handle_audit_log_internal`
	audit_log: Vec <AuditEntryPretty>,
}
|
|
|
|
/// One audit log entry, pre-formatted for display.
#[derive (Serialize)]
struct AuditEntryPretty {
	// Event time, formatted by `pretty_print_utc`
	utc_pretty: String,
	// `Debug` formatting of the event's data
	data_pretty: String,
}
|
|
|
|
async fn handle_server_list_internal (state: &Relay)
|
|
-> ServerListPage <'static>
|
|
{
|
|
use LastSeen::*;
|
|
|
|
let dev_mode;
|
|
let news_url;
|
|
|
|
{
|
|
let guard = state.config.read ().await;
|
|
dev_mode = guard.iso.dev_mode.is_some ();
|
|
news_url = guard.news_url.clone ();
|
|
}
|
|
let git_version = git_version::read ().await;
|
|
|
|
let server_list = scraper_api::v1_server_list (&state).await;
|
|
|
|
let now = Utc::now ();
|
|
|
|
let registered_server_count = server_list.servers.len ();
|
|
let mut connected_server_count = 0;
|
|
|
|
let servers = server_list.servers.into_iter ()
|
|
.map (|x| {
|
|
let last_seen = match x.last_seen {
|
|
None => "Never".into (),
|
|
Some (x) => match pretty_print_last_seen (now, x) {
|
|
Negative => "Error (negative time)".into (),
|
|
Connected => {
|
|
connected_server_count += 1;
|
|
"Connected".into ()
|
|
},
|
|
Description (s) => s.into (),
|
|
},
|
|
};
|
|
|
|
ServerEntry {
|
|
name: x.name,
|
|
display_name: x.display_name,
|
|
last_seen,
|
|
}
|
|
})
|
|
.collect ();
|
|
|
|
let date_rfc3339 = now.to_rfc3339_opts (SecondsFormat::Secs, true);
|
|
|
|
ServerListPage {
|
|
dev_mode,
|
|
git_version,
|
|
servers,
|
|
news_url,
|
|
connected_server_count,
|
|
registered_server_count,
|
|
date_rfc3339,
|
|
}
|
|
}
|
|
|
|
async fn handle_unregistered_servers_internal (state: &Relay)
|
|
-> UnregisteredServerListPage
|
|
{
|
|
use LastSeen::*;
|
|
|
|
let now = Utc::now ();
|
|
|
|
let mut server_list = state.unregistered_servers.to_vec ().await;
|
|
|
|
{
|
|
let me_config = state.me_config.read ().await;
|
|
server_list = server_list.into_iter ()
|
|
.filter (|s| ! me_config.servers.contains_key (&s.name))
|
|
.collect ();
|
|
}
|
|
|
|
server_list.sort_by_key (|s| {
|
|
(s.name.clone (), *s.tripcode.as_bytes (), now - s.seen)
|
|
});
|
|
server_list.dedup_by_key (|s| {
|
|
(s.name.clone (), s.tripcode)
|
|
});
|
|
|
|
let unregistered_servers = server_list.into_iter ()
|
|
.map (|x| {
|
|
let last_seen = match pretty_print_last_seen (now, x.seen) {
|
|
Negative => "Error (negative time)".into (),
|
|
Connected => "Recently".into (),
|
|
Description (s) => s,
|
|
};
|
|
|
|
let tripcode = base64::encode (x.tripcode.as_bytes ());
|
|
|
|
UnregisteredServer {
|
|
name: x.name,
|
|
tripcode,
|
|
last_seen,
|
|
}
|
|
}).collect ();
|
|
|
|
UnregisteredServerListPage {
|
|
unregistered_servers,
|
|
}
|
|
}
|
|
|
|
async fn handle_audit_log_internal (state: &Relay)
|
|
-> AuditLogPage
|
|
{
|
|
let utc_now = Utc::now ();
|
|
let audit_log = state.audit_log.to_vec ().await
|
|
.iter ().rev ().map (|e| {
|
|
AuditEntryPretty {
|
|
utc_pretty: pretty_print_utc (utc_now, e.time_utc),
|
|
data_pretty: format! ("{:?}", e.data),
|
|
}
|
|
}).collect ();
|
|
|
|
AuditLogPage {
|
|
audit_log,
|
|
}
|
|
}
|
|
|
|
async fn handle_server_list (
|
|
state: &Relay,
|
|
handlebars: Arc <Handlebars <'static>>
|
|
) -> Result <Response <Body>, RequestError>
|
|
{
|
|
let page = handle_server_list_internal (&state).await;
|
|
|
|
let s = handlebars.render ("server_list", &page)?;
|
|
Ok (ok_reply (s)?)
|
|
}
|
|
|
|
async fn handle_unregistered_servers (
|
|
state: &Relay,
|
|
handlebars: Arc <Handlebars <'static>>
|
|
) -> Result <Response <Body>, RequestError>
|
|
{
|
|
let page = handle_unregistered_servers_internal (&state).await;
|
|
|
|
let s = handlebars.render ("unregistered_servers", &page)?;
|
|
Ok (ok_reply (s)?)
|
|
}
|
|
|
|
async fn handle_audit_log (
|
|
state: &Relay,
|
|
handlebars: Arc <Handlebars <'static>>
|
|
) -> Result <Response <Body>, RequestError>
|
|
{
|
|
{
|
|
let cfg = state.config.read ().await;
|
|
if cfg.hide_audit_log {
|
|
return Ok (error_reply (StatusCode::FORBIDDEN, "Forbidden")?);
|
|
}
|
|
}
|
|
|
|
let page = handle_audit_log_internal (state).await;
|
|
|
|
let s = handlebars.render ("audit_log", &page)?;
|
|
Ok (ok_reply (s)?)
|
|
}
|
|
|
|
async fn handle_endless_sink (req: Request <Body>) -> Result <Response <Body>, http::Error>
|
|
{
|
|
let (_parts, mut body) = req.into_parts ();
|
|
|
|
let mut bytes_received = 0;
|
|
|
|
loop {
|
|
let item = body.next ().await;
|
|
|
|
if let Some (item) = item {
|
|
if let Ok (bytes) = &item {
|
|
bytes_received += bytes.len ();
|
|
}
|
|
}
|
|
else {
|
|
debug! ("Finished sinking debug bytes");
|
|
break;
|
|
}
|
|
}
|
|
|
|
Ok (ok_reply (format! ("Sank {} bytes\n", bytes_received))?)
|
|
}
|
|
|
|
async fn handle_endless_source (gib: usize, throttle: Option <usize>)
|
|
-> Result <Response <Body>, http::Error>
|
|
{
|
|
use tokio::sync::mpsc;
|
|
|
|
let block_bytes = 64 * 1024;
|
|
let num_blocks = (1024 * 1024 * 1024 / block_bytes) * gib;
|
|
|
|
let (tx, rx) = mpsc::channel (1);
|
|
|
|
tokio::spawn (async move {
|
|
let random_block = {
|
|
use rand::RngCore;
|
|
|
|
let mut rng = rand::thread_rng ();
|
|
let mut block = vec! [0_u8; 64 * 1024];
|
|
rng.fill_bytes (&mut block);
|
|
block
|
|
};
|
|
|
|
let mut interval = tokio::time::interval (Duration::from_millis (1000));
|
|
interval.set_missed_tick_behavior (tokio::time::MissedTickBehavior::Skip);
|
|
let mut blocks_sent = 0;
|
|
|
|
while blocks_sent < num_blocks {
|
|
if throttle.is_some () {
|
|
interval.tick ().await;
|
|
}
|
|
|
|
for _ in 0..throttle.unwrap_or (1) {
|
|
let item = Ok::<_, Infallible> (random_block.clone ());
|
|
if tx.send (item).await.is_err () {
|
|
debug! ("Endless source dropped");
|
|
return;
|
|
}
|
|
blocks_sent += 1;
|
|
}
|
|
}
|
|
|
|
debug! ("Endless source ended");
|
|
});
|
|
|
|
Response::builder ()
|
|
.status (StatusCode::OK)
|
|
.header ("content-type", "application/octet-stream")
|
|
.body (Body::wrap_stream (ReceiverStream::new (rx)))
|
|
}
|
|
|
|
async fn handle_gen_scraper_key (_state: &Relay)
|
|
-> Result <Response <Body>, http::Error>
|
|
{
|
|
let key = ptth_core::gen_key ();
|
|
|
|
let body = format! ("Random key: {}\n", key);
|
|
|
|
Response::builder ()
|
|
.status (StatusCode::OK)
|
|
.header ("content-type", "text/plain")
|
|
.body (Body::from (body))
|
|
}
|
|
|
|
async fn handle_register_server (req: Request <Body>, state: &Relay)
|
|
-> Result <(), anyhow::Error>
|
|
{
|
|
let (parts, body) = req.into_parts ();
|
|
let user = get_user_name (&parts);
|
|
|
|
let form_data = read_body_limited (body, 1_024).await?;
|
|
let server: crate::config::file::Server = serde_urlencoded::from_bytes (&form_data)?;
|
|
|
|
state.audit_log.push (AuditEvent::new (AuditData::RegisterServer {
|
|
user,
|
|
server: server.clone (),
|
|
})).await;
|
|
|
|
{
|
|
let mut me_config = state.me_config.write ().await;
|
|
|
|
me_config.servers.insert (server.name.clone (), server);
|
|
me_config.save (Path::new ("data/ptth_relay_me_config.toml")).await?;
|
|
}
|
|
|
|
Ok (())
|
|
}
|
|
|
|
async fn read_body_limited (mut body: Body, limit: usize) -> anyhow::Result <Vec <u8>>
|
|
{
|
|
let mut buffer = vec! [];
|
|
while let Some (chunk) = body.next ().await {
|
|
let chunk = chunk?;
|
|
|
|
if buffer.len () + chunk.len () > limit {
|
|
bail! ("Body was bigger than limit");
|
|
}
|
|
|
|
buffer.extend_from_slice (&chunk);
|
|
}
|
|
|
|
Ok (buffer)
|
|
}
|
|
|
|
#[instrument (level = "trace", skip (req, state, handlebars))]
|
|
async fn handle_all (
|
|
req: Request <Body>,
|
|
state: Arc <Relay>,
|
|
handlebars: Arc <Handlebars <'static>>
|
|
)
|
|
-> Result <Response <Body>, RequestError>
|
|
{
|
|
use routing::Route::*;
|
|
|
|
let state = &*state;
|
|
|
|
// The path is cloned here, so it's okay to consume the request
|
|
// later.
|
|
let path = req.uri ().path ().to_string ();
|
|
trace! ("Request path: {}", path);
|
|
|
|
let route = routing::route_url (req.method (), &path);
|
|
let route = match route {
|
|
Ok (x) => x,
|
|
Err (e) => {
|
|
use routing::Error;
|
|
|
|
let response = match e {
|
|
Error::BadUriFormat => error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?,
|
|
Error::MethodNotAllowed => error_reply (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed. Are you POST-ing to a GET-only url, or vice versa?")?,
|
|
Error::NotFound => error_reply (StatusCode::OK, "URL routing failed")?,
|
|
};
|
|
return Ok (response);
|
|
},
|
|
};
|
|
let response = match route {
|
|
ClientAuditLog => handle_audit_log (state, handlebars).await?,
|
|
ClientRelayIsUp => error_reply (StatusCode::OK, "Relay is up")?,
|
|
ClientServerGet {
|
|
listen_code,
|
|
path,
|
|
} => {
|
|
let (parts, _) = req.into_parts ();
|
|
|
|
let user = get_user_name (&parts);
|
|
state.audit_log.push (AuditEvent::new (AuditData::WebClientGet {
|
|
user,
|
|
server_name: listen_code.to_string (),
|
|
uri: path.to_string (),
|
|
})).await;
|
|
|
|
handle_http_request (parts, path.to_string (), &state, listen_code).await?
|
|
},
|
|
ClientServerList => handle_server_list (state, handlebars).await?,
|
|
ClientUnregisteredServers => handle_unregistered_servers (state, handlebars).await?,
|
|
Debug => {
|
|
let s = handlebars.render ("debug", &())?;
|
|
ok_reply (s)?
|
|
},
|
|
DebugEndlessSink => handle_endless_sink (req).await?,
|
|
DebugEndlessSource (throttle) => handle_endless_source (1, throttle).await?,
|
|
DebugGenKey => handle_gen_scraper_key (state).await?,
|
|
DebugMysteriousError => return Err (RequestError::Mysterious),
|
|
RegisterServer => {
|
|
match handle_register_server (req, state).await {
|
|
Ok (_) => Response::builder ()
|
|
.status (StatusCode::SEE_OTHER)
|
|
.header ("location", "unregistered_servers")
|
|
.body (Body::from ("Success. Redirecting..."))?,
|
|
Err (e) => error_reply (StatusCode::BAD_REQUEST, &format! ("{:?}", e))?,
|
|
}
|
|
}
|
|
Root => {
|
|
let s = handlebars.render ("root", &())?;
|
|
ok_reply (s)?
|
|
},
|
|
Scraper {
|
|
rest,
|
|
} => scraper_api::handle (req, state, rest).await?,
|
|
ServerHttpListen {
|
|
listen_code,
|
|
} => {
|
|
let api_key = req.headers ().get ("X-ApiKey");
|
|
|
|
let api_key = match api_key {
|
|
None => return Ok (error_reply (StatusCode::FORBIDDEN, "Forbidden")?),
|
|
Some (x) => x,
|
|
};
|
|
|
|
match check_server_api_key (state, listen_code, api_key.as_bytes ()).await {
|
|
Ok (_) => (),
|
|
Err (_) => return Ok (error_reply (StatusCode::FORBIDDEN, "Forbidden")?)
|
|
}
|
|
|
|
server_endpoint::handle_listen (state, listen_code.into ()).await?
|
|
},
|
|
ServerHttpResponse {
|
|
request_code,
|
|
} => {
|
|
let request_code = request_code.into ();
|
|
server_endpoint::handle_response (req, state, request_code).await?
|
|
},
|
|
};
|
|
Ok (response)
|
|
}
|
|
|
|
async fn check_server_api_key (state: &Relay, name: &str, api_key: &[u8])
|
|
-> Result <(), anyhow::Error>
|
|
{
|
|
let actual_tripcode = key_validity::BlakeHashWrapper::from_key (api_key);
|
|
|
|
let expected_human = {
|
|
let config = state.config.read ().await;
|
|
|
|
config.servers.get (name).map (|s| s.tripcode)
|
|
};
|
|
|
|
let expected_machine = {
|
|
let me_config = state.me_config.read ().await;
|
|
|
|
me_config.servers.get (name).map (|s| s.tripcode)
|
|
};
|
|
|
|
if expected_machine.is_none () && expected_human.is_none () {
|
|
state.unregistered_servers.push (crate::RejectedServer {
|
|
name: name.to_string (),
|
|
tripcode: *actual_tripcode,
|
|
seen: Utc::now (),
|
|
}).await;
|
|
bail! ("Denied API request for non-existent server name {}", name);
|
|
}
|
|
|
|
if Some (actual_tripcode) == expected_human {
|
|
return Ok (());
|
|
}
|
|
if Some (actual_tripcode) == expected_machine {
|
|
return Ok (());
|
|
}
|
|
|
|
bail! ("Denied API request for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ()));
|
|
}
|
|
|
|
fn load_templates (asset_root: &Path)
|
|
-> Result <Handlebars <'static>, RelayError>
|
|
{
|
|
let mut handlebars = Handlebars::new ();
|
|
handlebars.set_strict_mode (true);
|
|
|
|
let asset_root = asset_root.join ("handlebars/relay");
|
|
|
|
for (k, v) in &[
|
|
("audit_log", "audit_log.hbs"),
|
|
("debug", "debug.hbs"),
|
|
("root", "root.hbs"),
|
|
("server_list", "server_list.hbs"),
|
|
("unregistered_servers", "unregistered_servers.hbs"),
|
|
] {
|
|
handlebars.register_template_file (k, &asset_root.join (v))?;
|
|
}
|
|
|
|
Ok (handlebars)
|
|
}
|
|
|
|
async fn reload_config (
|
|
state: &Arc <Relay>,
|
|
config_reload_path: &Path
|
|
) -> Result <(), ConfigError> {
|
|
// Reload human-editable config
|
|
let new_config = Config::from_file (config_reload_path).await?;
|
|
|
|
// Reload machine-editable config, if possible
|
|
// let me_config = machine_editable::Config::from_file (Path::new ("data/ptth_relay_me_config.toml")).await.ok ();
|
|
|
|
let mut config = state.config.write ().await;
|
|
|
|
trace! ("Reloading config");
|
|
if config.servers.len () != new_config.servers.len () {
|
|
debug! ("Loaded {} server configs", new_config.servers.len ());
|
|
}
|
|
if config.iso.enable_scraper_api != new_config.iso.enable_scraper_api {
|
|
debug! ("enable_scraper_api: {}", new_config.iso.enable_scraper_api);
|
|
}
|
|
|
|
(*config) = new_config;
|
|
|
|
if config.iso.dev_mode.is_some () {
|
|
error! ("Dev mode is enabled! This might turn off some security features. If you see this in production, escalate it to someone!");
|
|
}
|
|
|
|
Ok (())
|
|
}
|
|
|
|
pub async fn run_relay (
|
|
state: Arc <Relay>,
|
|
asset_root: &Path,
|
|
shutdown_oneshot: oneshot::Receiver <()>,
|
|
config_reload_path: Option <PathBuf>
|
|
)
|
|
-> Result <(), RelayError>
|
|
{
|
|
let handlebars = Arc::new (load_templates (asset_root)?);
|
|
|
|
if let Some (x) = git_version::read ().await {
|
|
info! ("ptth_relay Git version: {:?}", x);
|
|
}
|
|
else {
|
|
info! ("ptth_relay not built from Git");
|
|
}
|
|
|
|
if let Some (config_reload_path) = config_reload_path {
|
|
let state_2 = state.clone ();
|
|
tokio::spawn (async move {
|
|
let mut reload_interval = tokio::time::interval (Duration::from_secs (60));
|
|
reload_interval.set_missed_tick_behavior (tokio::time::MissedTickBehavior::Skip);
|
|
|
|
loop {
|
|
reload_interval.tick ().await;
|
|
reload_config (&state_2, &config_reload_path).await.ok ();
|
|
}
|
|
});
|
|
}
|
|
|
|
// Set a task to periodically sweep and time-out requests where the client
|
|
// and server are never going to rendezvous
|
|
|
|
let state_2 = Arc::clone (&state);
|
|
tokio::spawn (async move {
|
|
let mut interval = tokio::time::interval (Duration::from_secs (60));
|
|
interval.set_missed_tick_behavior (tokio::time::MissedTickBehavior::Skip);
|
|
|
|
loop {
|
|
use std::convert::TryFrom;
|
|
|
|
use rusty_ulid::Ulid;
|
|
|
|
interval.tick ().await;
|
|
|
|
{
|
|
let timeout_ms = Utc::now ().timestamp () - 120_000;
|
|
if let Ok (timeout_ms) = u64::try_from (timeout_ms) {
|
|
let timeout_ulid = Ulid::from_timestamp_with_rng (timeout_ms, &mut rand::thread_rng ()).to_string ();
|
|
|
|
let mut request_rendezvous = state_2.request_rendezvous.lock ().await;
|
|
request_rendezvous.iter_mut ()
|
|
.for_each (|(k, v)| {
|
|
match v {
|
|
RequestRendezvous::ParkedServer (_) => (),
|
|
RequestRendezvous::ParkedClients (requests) => requests.retain (|req| req.id.as_str () >= timeout_ulid.as_str ()),
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
{
|
|
let now = Instant::now ();
|
|
let response_rendezvous = state_2.response_rendezvous.read ().await;
|
|
response_rendezvous.retain (|_, v| v.timeout >= now);
|
|
}
|
|
}
|
|
});
|
|
|
|
let make_svc = make_service_fn (|_conn| {
|
|
let state = state.clone ();
|
|
let handlebars = handlebars.clone ();
|
|
|
|
async {
|
|
Ok::<_, Infallible> (service_fn (move |req| {
|
|
let state = state.clone ();
|
|
let handlebars = handlebars.clone ();
|
|
|
|
async {
|
|
Ok::<_, Infallible> (handle_all (req, state, handlebars).await.unwrap_or_else (|e| {
|
|
use RequestError::*;
|
|
error! ("{}", e);
|
|
|
|
let status_code = match &e {
|
|
UnknownServer => StatusCode::NOT_FOUND,
|
|
BadRequest => StatusCode::BAD_REQUEST,
|
|
ServerNeverResponded | ServerTimedOut => StatusCode::GATEWAY_TIMEOUT,
|
|
|
|
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
|
};
|
|
|
|
error_reply (status_code, "Error in relay").unwrap ()
|
|
}))
|
|
}
|
|
}))
|
|
}
|
|
});
|
|
|
|
let addr = {
|
|
let guard = state.config.read ().await;
|
|
|
|
SocketAddr::from ((
|
|
guard.address,
|
|
guard.port.unwrap_or (4000),
|
|
))
|
|
};
|
|
|
|
let server = Server::bind (&addr)
|
|
.serve (make_svc);
|
|
|
|
state.audit_log.push (AuditEvent::new (AuditData::RelayStart)).await;
|
|
|
|
trace! ("Serving relay on {:?}", addr);
|
|
|
|
server.with_graceful_shutdown (async {
|
|
use ShuttingDownError::ShuttingDown;
|
|
|
|
shutdown_oneshot.await.ok ();
|
|
|
|
state.shutdown_watch_tx.send (true).expect ("Can't broadcast graceful shutdown");
|
|
|
|
let mut response_rendezvous = state.response_rendezvous.write ().await;
|
|
let mut swapped = DashMap::default ();
|
|
|
|
std::mem::swap (&mut swapped, &mut response_rendezvous);
|
|
|
|
for (_, sender) in swapped {
|
|
sender.tx.send (Err (ShuttingDown)).ok ();
|
|
}
|
|
|
|
let mut request_rendezvous = state.request_rendezvous.lock ().await;
|
|
|
|
for (_, x) in request_rendezvous.drain () {
|
|
use RequestRendezvous::*;
|
|
|
|
match x {
|
|
ParkedClients (_) => (),
|
|
ParkedServer (sender) => drop (sender.send (Err (ShuttingDown))),
|
|
}
|
|
}
|
|
|
|
debug! ("Performed all cleanup");
|
|
}).await?;
|
|
|
|
Ok (())
|
|
}
|
|
|
|
#[cfg (test)]
|
|
mod tests;
|