#![warn (clippy::pedantic)] // I don't see the point in documenting the errors outside of where the // error type is defined. #![allow (clippy::missing_errors_doc)] // False positive on futures::select! macro #![allow (clippy::mut_mut)] use std::{ path::PathBuf, sync::Arc, time::Duration, }; use futures::FutureExt; use reqwest::Client; use serde::Deserialize; use tokio::{ sync::{ oneshot, }, }; use tokio_stream::wrappers::ReceiverStream; use ptth_core::{ http_serde, prelude::*, }; pub mod errors; pub mod file_server; pub mod load_toml; use errors::ServerError; // Thanks to https://github.com/robsheldon/bad-passwords-index const BAD_PASSWORDS: &[u8] = include_bytes! ("bad_passwords.txt"); #[must_use] pub fn password_is_bad (mut password: String) -> bool { password.make_ascii_lowercase (); let ac = aho_corasick::AhoCorasick::new (&[ password ]); ac.find (BAD_PASSWORDS).is_some () } struct State { file_server: file_server::State, config: Config, client: Client, } // Unwrap a request from PTTH format and pass it into file_server. // When file_server responds, wrap it back up and stream it to the relay. async fn handle_one_req ( state: &Arc , wrapped_req: http_serde::WrappedRequest ) -> Result <(), ServerError> { let (req_id, parts) = (wrapped_req.id, wrapped_req.req); debug! ("Handling request {}", req_id); let default_root = PathBuf::from ("./"); let file_server_root: &std::path::Path = state.file_server.config.file_server_root .as_ref () .unwrap_or (&default_root); let response = file_server::serve_all ( &state.file_server, file_server_root, parts.method, &parts.uri, &parts.headers, ).await?; let mut resp_req = state.client .post (&format! ("{}/http_response/{}", state.config.relay_url, req_id)) .header (ptth_core::PTTH_MAGIC_HEADER, base64::encode (rmp_serde::to_vec (&response.parts).map_err (ServerError::MessagePackEncodeResponse)?)); if let Some (length) = response.content_length { resp_req = resp_req.header ("Content-Length", length.to_string ()); } if let Some (body) = response.body { resp_req = resp_req.body (reqwest::Body::wrap_stream (ReceiverStream::new (body))); } let req = resp_req.build ().map_err (ServerError::Step5Responding)?; debug! ("{:?}", req.headers ()); //println! ("Step 6"); match state.client.execute (req).await { Ok (r) => { let status = r.status (); let text = r.text ().await.map_err (ServerError::Step7AfterResponse)?; debug! ("{:?} {:?}", status, text); }, Err (e) => { if e.is_request () { warn! ("Error while POSTing response. Client probably hung up."); } else { error! ("Err: {:?}", e); } }, } Ok::<(), ServerError> (()) } async fn handle_req_resp ( state: &Arc , req_resp: reqwest::Response ) -> Result <(), ServerError> { //println! ("Step 1"); let body = req_resp.bytes ().await.map_err (ServerError::CantCollectWrappedRequests)?; let wrapped_reqs: Vec = match rmp_serde::from_read_ref (&body) { Ok (x) => x, Err (e) => { error! ("Can't parse wrapped requests: {:?}", e); return Err (ServerError::CantParseWrappedRequests (e)); }, }; debug! ("Unwrapped {} requests", wrapped_reqs.len ()); for wrapped_req in wrapped_reqs { let state = state.clone (); // These have to detach, so we won't be able to catch the join errors. tokio::spawn (async move { handle_one_req (&state, wrapped_req).await }); } Ok (()) } #[derive (Default, Deserialize)] pub struct ConfigFile { pub name: String, pub api_key: String, pub relay_url: String, pub file_server_root: Option , } impl ConfigFile { #[must_use] pub fn tripcode (&self) -> String { base64::encode (blake3::hash (self.api_key.as_bytes ()).as_bytes ()) } } #[derive (Default)] pub struct Config { pub relay_url: String, } pub async fn run_server ( config_file: ConfigFile, shutdown_oneshot: oneshot::Receiver <()>, hidden_path: Option , asset_root: Option ) -> Result <(), ServerError> { use std::{ convert::TryInto, }; use arc_swap::ArcSwap; use http::status::StatusCode; let asset_root = asset_root.unwrap_or_else (PathBuf::new); if password_is_bad (config_file.api_key.clone ()) { return Err (ServerError::WeakApiKey); } info! ("Server name is {}", config_file.name); info! ("Tripcode is {}", config_file.tripcode ()); let mut headers = reqwest::header::HeaderMap::new (); headers.insert ("X-ApiKey", config_file.api_key.try_into ().map_err (ServerError::ApiKeyInvalid)?); let client = Client::builder () .default_headers (headers) .timeout (Duration::from_secs (40)) .build ().map_err (ServerError::CantBuildHttpClient)?; let handlebars = file_server::load_templates (&asset_root)?; let metrics_startup = file_server::metrics::Startup::new (config_file.name); let metrics_interval = Arc::new (ArcSwap::default ()); let interval_writer = Arc::clone (&metrics_interval); tokio::spawn (async move { file_server::metrics::Interval::monitor (interval_writer).await; }); let state = Arc::new (State { file_server: file_server::State { config: file_server::Config { file_server_root: config_file.file_server_root, }, handlebars, metrics_startup, metrics_interval, hidden_path, }, config: Config { relay_url: config_file.relay_url, }, client, }); let mut backoff_delay = 0; let mut shutdown_oneshot = shutdown_oneshot.fuse (); loop { // TODO: Extract loop body to function? if backoff_delay > 0 { let sleep = tokio::time::sleep (Duration::from_millis (backoff_delay)); tokio::pin! (sleep); tokio::select! { _ = &mut sleep => {}, _ = &mut shutdown_oneshot => { info! ("Received graceful shutdown"); break; }, } } debug! ("http_listen"); let req_req = state.client.get (&format! ("{}/http_listen/{}", state.config.relay_url, state.file_server.metrics_startup.server_name)).send (); let err_backoff_delay = std::cmp::min (30_000, backoff_delay * 2 + 500); let req_req = futures::select! { r = req_req.fuse () => r, _ = shutdown_oneshot => { info! ("Received graceful shutdown"); break; }, }; let req_resp = match req_req { Err (e) => { if e.is_timeout () { error! ("Client-side timeout. Is an overly-aggressive firewall closing long-lived connections? Is the network flakey?"); } else { error! ("Err: {:?}", e); if backoff_delay != err_backoff_delay { error! ("Non-timeout issue, increasing backoff_delay"); backoff_delay = err_backoff_delay; } } continue; }, Ok (x) => x, }; if req_resp.status () == StatusCode::NO_CONTENT { debug! ("http_listen long poll timed out on the server, good."); continue; } else if req_resp.status () != StatusCode::OK { error! ("{}", req_resp.status ()); let body = req_resp.bytes ().await.map_err (ServerError::Step3CollectBody)?; let body = String::from_utf8 (body.to_vec ()).map_err (ServerError::Step3ErrorResponseNotUtf8)?; error! ("{}", body); if backoff_delay != err_backoff_delay { error! ("Non-timeout issue, increasing backoff_delay"); backoff_delay = err_backoff_delay; } continue; } // Unpack the requests, spawn them into new tasks, then loop back // around. if handle_req_resp (&state, req_resp).await.is_err () { backoff_delay = err_backoff_delay; continue; } if backoff_delay != 0 { debug! ("backoff_delay = 0"); backoff_delay = 0; } } info! ("Exiting"); Ok (()) } #[cfg (test)] mod tests { use super::*; #[test] fn tripcode_algo () { let config = ConfigFile { name: "TestName".into (), api_key: "PlaypenCausalPlatformCommodeImproveCatalyze".into (), relay_url: "".into (), file_server_root: None, }; assert_eq! (config.tripcode (), "A9rPwZyY89Ag4TJjMoyYA2NeGOm99Je6rq1s0rg8PfY=".to_string ()); } #[test] fn check_bad_passwords () { for pw in &[ "", " ", "user", "password", "pAsSwOrD", "secret", "123123", ] { assert! (password_is_bad (pw.to_string ())); } use rand::prelude::*; let mut entropy = [0u8; 32]; thread_rng ().fill_bytes (&mut entropy); let good_password = base64::encode (entropy); assert! (! password_is_bad (good_password)); } }