ptth/crates/ptth_server/src/lib.rs

350 lines
8.1 KiB
Rust

#![warn (clippy::pedantic)]
// I don't see the point in documenting the errors outside of where the
// error type is defined.
#![allow (clippy::missing_errors_doc)]
// False positive on futures::select! macro
#![allow (clippy::mut_mut)]
use std::{
path::PathBuf,
sync::Arc,
time::Duration,
};
use futures::FutureExt;
use reqwest::Client;
use serde::Deserialize;
use tokio::{
sync::{
oneshot,
},
time::delay_for,
};
use ptth_core::{
http_serde,
prelude::*,
};
pub mod errors;
pub mod file_server;
pub mod load_toml;
use errors::ServerError;
// Thanks to https://github.com/robsheldon/bad-passwords-index
// Raw bytes of `bad_passwords.txt`, embedded into the binary at compile time
// by `include_bytes!`. `password_is_bad` searches this blob for the candidate
// password.
const BAD_PASSWORDS: &[u8] = include_bytes! ("bad_passwords.txt");
/// Returns `true` if `password` appears in the embedded bad-password list.
///
/// The password is ASCII-lowercased first, then searched for as a plain
/// substring of the whole `BAD_PASSWORDS` blob, so a password that is merely
/// a fragment of a listed entry is also rejected.
#[must_use]
pub fn password_is_bad (mut password: String) -> bool {
    password.make_ascii_lowercase ();

    // The candidate password is the single search pattern; the list is the
    // haystack.
    let needles = [password];
    let matcher = aho_corasick::AhoCorasick::new (&needles);

    matcher.find (BAD_PASSWORDS).is_some ()
}
// Everything a request handler needs. `run_server` builds one instance and
// shares it between tasks behind an `Arc`.
struct State {
// State handed to `file_server::serve_all` for every request.
file_server: file_server::State,
// Relay settings; see `Config`.
config: Config,
// HTTP client used for both the `http_listen` long poll and the
// `http_response` POSTs. `run_server` builds it with the `X-ApiKey` default
// header and a 40-second timeout.
client: Client,
}
// Unwraps one request from PTTH format, lets file_server answer it, then
// wraps the answer back up and streams it to the relay.
//
// Errors from serving the request, encoding the response parts, building the
// POST, or reading the relay's reply body are returned to the caller.
// A failure to send the POST itself is only logged, and the function still
// returns Ok.
async fn handle_one_req (
    state: &Arc <State>,
    wrapped_req: http_serde::WrappedRequest
) -> Result <(), ServerError>
{
    let req_id = wrapped_req.id;
    let parts = wrapped_req.req;

    debug! ("Handling request {}", req_id);

    // Serve from the current directory when no root was configured.
    let fallback_root = PathBuf::from ("./");
    let file_server_root: &std::path::Path = match &state.file_server.config.file_server_root {
        Some (root) => root,
        None => &fallback_root,
    };

    let response = file_server::serve_all (
        &state.file_server,
        file_server_root,
        parts.method,
        &parts.uri,
        &parts.headers,
    ).await?;

    // The response head travels MessagePack-encoded and base64'd in the PTTH
    // magic header; the body (if any) is streamed as the POST body.
    let response_url = format! ("{}/http_response/{}", state.config.relay_url, req_id);
    let encoded_parts = base64::encode (
        rmp_serde::to_vec (&response.parts)
            .map_err (ServerError::MessagePackEncodeResponse)?
    );

    let resp_req = state.client
        .post (&response_url)
        .header (ptth_core::PTTH_MAGIC_HEADER, encoded_parts);

    let resp_req = match response.content_length {
        Some (length) => resp_req.header ("Content-Length", length.to_string ()),
        None => resp_req,
    };

    let resp_req = match response.body {
        Some (body) => resp_req.body (reqwest::Body::wrap_stream (body)),
        None => resp_req,
    };

    let req = resp_req.build ().map_err (ServerError::Step5Responding)?;

    debug! ("{:?}", req.headers ());

    let relay_reply = match state.client.execute (req).await {
        Ok (r) => r,
        Err (e) => {
            if e.is_request () {
                warn! ("Error while POSTing response. Client probably hung up.");
            }
            else {
                error! ("Err: {:?}", e);
            }

            // Sending failed; that was logged above and is not treated as
            // an error of this function.
            return Ok (());
        },
    };

    let status = relay_reply.status ();
    let text = relay_reply.text ().await.map_err (ServerError::Step7AfterResponse)?;
    debug! ("{:?} {:?}", status, text);

    Ok (())
}
async fn handle_req_resp (
state: &Arc <State>,
req_resp: reqwest::Response
) -> Result <(), ServerError>
{
//println! ("Step 1");
let body = req_resp.bytes ().await.map_err (ServerError::CantCollectWrappedRequests)?;
let wrapped_reqs: Vec <http_serde::WrappedRequest> = match rmp_serde::from_read_ref (&body)
{
Ok (x) => x,
Err (e) => {
error! ("Can't parse wrapped requests: {:?}", e);
return Err (ServerError::CantParseWrappedRequests (e));
},
};
debug! ("Unwrapped {} requests", wrapped_reqs.len ());
for wrapped_req in wrapped_reqs {
let state = state.clone ();
// These have to detach, so we won't be able to catch the join errors.
tokio::spawn (async move {
handle_one_req (&state, wrapped_req).await
});
}
Ok (())
}
/// Settings that `run_server` consumes at start-up.
///
/// Derives `Deserialize`, so it can be produced by any serde deserializer.
#[derive (Default, Deserialize)]
pub struct ConfigFile {
/// Name of this server. `run_server` logs it and hands it to
/// `file_server::metrics::Startup::new`.
pub name: String,
/// Secret sent to the relay in the `X-ApiKey` header of every request.
/// `run_server` refuses to start if `password_is_bad` rejects it.
pub api_key: String,
/// Base URL of the relay; `/http_listen/...` and `/http_response/...` are
/// appended to it.
pub relay_url: String,
/// Directory to serve files from. When `None`, requests are served from
/// `./`.
pub file_server_root: Option <PathBuf>,
}
impl ConfigFile {
    /// Returns the base64 encoding of the BLAKE3 hash of `api_key`.
    ///
    /// `run_server` logs this value at start-up in place of the key itself.
    #[must_use]
    pub fn tripcode (&self) -> String {
        let digest = blake3::hash (self.api_key.as_bytes ());

        base64::encode (digest.as_bytes ())
    }
}
/// Relay settings kept in `State` for the lifetime of the server.
#[derive (Default)]
pub struct Config {
/// Base URL of the relay; `/http_listen/...` and `/http_response/...` are
/// appended to it.
pub relay_url: String,
}
pub async fn run_server (
config_file: ConfigFile,
shutdown_oneshot: oneshot::Receiver <()>,
hidden_path: Option <PathBuf>,
asset_root: Option <PathBuf>
)
-> Result <(), ServerError>
{
use std::{
convert::TryInto,
};
use arc_swap::ArcSwap;
use http::status::StatusCode;
let asset_root = asset_root.unwrap_or_else (PathBuf::new);
if password_is_bad (config_file.api_key.clone ()) {
return Err (ServerError::WeakApiKey);
}
info! ("Server name is {}", config_file.name);
info! ("Tripcode is {}", config_file.tripcode ());
let mut headers = reqwest::header::HeaderMap::new ();
headers.insert ("X-ApiKey", config_file.api_key.try_into ().map_err (ServerError::ApiKeyInvalid)?);
let client = Client::builder ()
.default_headers (headers)
.timeout (Duration::from_secs (40))
.build ().map_err (ServerError::CantBuildHttpClient)?;
let handlebars = file_server::load_templates (&asset_root)?;
let metrics_startup = file_server::metrics::Startup::new (config_file.name);
let metrics_interval = Arc::new (ArcSwap::default ());
let interval_writer = Arc::clone (&metrics_interval);
tokio::spawn (async move {
file_server::metrics::Interval::monitor (interval_writer).await;
});
let state = Arc::new (State {
file_server: file_server::State {
config: file_server::Config {
file_server_root: config_file.file_server_root,
},
handlebars,
metrics_startup,
metrics_interval,
hidden_path,
},
config: Config {
relay_url: config_file.relay_url,
},
client,
});
let mut backoff_delay = 0;
let mut shutdown_oneshot = shutdown_oneshot.fuse ();
loop {
// TODO: Extract loop body to function?
if backoff_delay > 0 {
let mut delay = delay_for (Duration::from_millis (backoff_delay)).fuse ();
if futures::select! (
_ = delay => false,
_ = shutdown_oneshot => true,
) {
info! ("Received graceful shutdown");
break;
}
}
debug! ("http_listen");
let req_req = state.client.get (&format! ("{}/http_listen/{}", state.config.relay_url, state.file_server.metrics_startup.server_name)).send ();
let err_backoff_delay = std::cmp::min (30_000, backoff_delay * 2 + 500);
let req_req = futures::select! {
r = req_req.fuse () => r,
_ = shutdown_oneshot => {
info! ("Received graceful shutdown");
break;
},
};
let req_resp = match req_req {
Err (e) => {
if e.is_timeout () {
error! ("Client-side timeout. Is an overly-aggressive firewall closing long-lived connections? Is the network flakey?");
}
else {
error! ("Err: {:?}", e);
if backoff_delay != err_backoff_delay {
error! ("Non-timeout issue, increasing backoff_delay");
backoff_delay = err_backoff_delay;
}
}
continue;
},
Ok (x) => x,
};
if req_resp.status () == StatusCode::NO_CONTENT {
debug! ("http_listen long poll timed out on the server, good.");
continue;
}
else if req_resp.status () != StatusCode::OK {
error! ("{}", req_resp.status ());
let body = req_resp.bytes ().await.map_err (ServerError::Step3CollectBody)?;
let body = String::from_utf8 (body.to_vec ()).map_err (ServerError::Step3ErrorResponseNotUtf8)?;
error! ("{}", body);
if backoff_delay != err_backoff_delay {
error! ("Non-timeout issue, increasing backoff_delay");
backoff_delay = err_backoff_delay;
}
continue;
}
// Unpack the requests, spawn them into new tasks, then loop back
// around.
if handle_req_resp (&state, req_resp).await.is_err () {
backoff_delay = err_backoff_delay;
continue;
}
if backoff_delay != 0 {
debug! ("backoff_delay = 0");
backoff_delay = 0;
}
}
info! ("Exiting");
Ok (())
}
#[cfg (test)]
mod tests {
    use super::*;

    // Pins the tripcode of a fixed API key to a known value.
    #[test]
    fn tripcode_algo () {
        let expected = "A9rPwZyY89Ag4TJjMoyYA2NeGOm99Je6rq1s0rg8PfY=".to_string ();

        let config = ConfigFile {
            name: String::from ("TestName"),
            api_key: String::from ("PlaypenCausalPlatformCommodeImproveCatalyze"),
            relay_url: String::new (),
            file_server_root: None,
        };

        assert_eq! (config.tripcode (), expected);
    }

    // Known-weak passwords must be rejected, in any letter case; a freshly
    // generated random one must be accepted.
    #[test]
    fn check_bad_passwords () {
        use rand::prelude::*;

        let known_bad = [
            "",
            " ",
            "user",
            "password",
            "pAsSwOrD",
            "secret",
            "123123",
        ];

        for pw in known_bad.iter () {
            assert! (password_is_bad (String::from (*pw)));
        }

        // 32 random bytes, base64-encoded.
        let mut entropy = [0u8; 32];
        thread_rng ().fill_bytes (&mut entropy);
        let good_password = base64::encode (entropy);

        assert! (! password_is_bad (good_password));
    }
}