diff --git a/src/bin/ptth_relay.rs b/src/bin/ptth_relay.rs index dd3ebd2..8641805 100644 --- a/src/bin/ptth_relay.rs +++ b/src/bin/ptth_relay.rs @@ -3,14 +3,18 @@ use std::{ sync::Arc, }; +use tracing::{info}; + use ptth::relay; use ptth::relay::RelayState; #[tokio::main] async fn main () -> Result <(), Box > { + tracing_subscriber::fmt::init (); + let config_file = ptth::load_toml::load ("config/ptth_relay.toml"); - eprintln! ("ptth_relay Git version: {:?}", ptth::git_version::GIT_VERSION); + info! ("ptth_relay Git version: {:?}", ptth::git_version::GIT_VERSION); relay::run_relay ( Arc::new (RelayState::from (&config_file)), diff --git a/src/http_serde.rs b/src/http_serde.rs index 0d94ad3..bc52547 100644 --- a/src/http_serde.rs +++ b/src/http_serde.rs @@ -4,7 +4,7 @@ use std::{ }; use serde::{Deserialize, Serialize}; -use tokio::sync::mpsc::Receiver; +use tokio::sync::mpsc; // Hyper doesn't seem to make it easy to de/ser requests // and responses and stuff like that, so I do it by hand here. @@ -20,7 +20,7 @@ impl From for Error { } } -#[derive (Deserialize, Serialize)] +#[derive (Debug, Deserialize, Serialize)] pub enum Method { Get, Head, @@ -83,11 +83,15 @@ pub struct WrappedRequest { pub req: RequestParts, } -#[derive (Deserialize, Serialize)] +#[derive (Debug, Deserialize, Serialize)] pub enum StatusCode { - Ok, - NotFound, - PartialContent, + Ok, // 200 + PartialContent, // 206 + + BadRequest, // 400 + Forbidden, // 403 + NotFound, // 404 + MethodNotAllowed, // 405 } impl Default for StatusCode { @@ -100,8 +104,12 @@ impl From for hyper::StatusCode { fn from (x: StatusCode) -> Self { match x { StatusCode::Ok => Self::OK, - StatusCode::NotFound => Self::NOT_FOUND, StatusCode::PartialContent => Self::PARTIAL_CONTENT, + + StatusCode::BadRequest => Self::BAD_REQUEST, + StatusCode::Forbidden => Self::FORBIDDEN, + StatusCode::NotFound => Self::NOT_FOUND, + StatusCode::MethodNotAllowed => Self::METHOD_NOT_ALLOWED, } } } @@ -120,7 +128,7 @@ pub struct ResponseParts { // reqwest and hyper have different Body types for _some reason_ // so I have to do everything myself -type Body = Receiver , std::convert::Infallible>>; +type Body = mpsc::Receiver , std::convert::Infallible>>; #[derive (Default)] pub struct Response { diff --git a/src/relay/mod.rs b/src/relay/mod.rs index d615a80..803e30f 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -10,6 +10,7 @@ use std::{ }; use dashmap::DashMap; +use futures::stream::StreamExt; use handlebars::Handlebars; use hyper::{ Body, @@ -25,11 +26,14 @@ use serde::{ Serialize, }; use tokio::{ + spawn, sync::{ Mutex, + mpsc, oneshot, }, }; +use tracing::{debug, error, info, trace, warn}; use crate::{ http_serde, @@ -153,6 +157,9 @@ fn status_reply > (status: StatusCode, b: B) Response::builder ().status (status).body (b.into ()).unwrap () } +// Servers will come here and either handle queued requests from parked clients, +// or park themselves until a request comes in. + async fn handle_http_listen ( state: Arc , watcher_code: String, @@ -164,7 +171,7 @@ async fn handle_http_listen ( let expected_tripcode = match state.config.server_tripcodes.get (&watcher_code) { None => { - eprintln! ("Denied http_listen for non-existent server name {}", watcher_code); + error! ("Denied http_listen for non-existent server name {}", watcher_code); return trip_error; }, Some (x) => x, @@ -172,7 +179,7 @@ async fn handle_http_listen ( let actual_tripcode = blake3::hash (api_key); if expected_tripcode != &actual_tripcode { - eprintln! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); + error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); return trip_error; } @@ -204,6 +211,8 @@ async fn handle_http_listen ( status_reply (StatusCode::OK, rmp_serde::to_vec (&vec! [one_req]).unwrap ()) } +// Servers will come here to stream responses to clients + async fn handle_http_response ( req: Request , state: Arc , @@ -211,21 +220,63 @@ async fn handle_http_response ( ) -> Response { - let (parts, body) = req.into_parts (); + let (parts, mut body) = req.into_parts (); let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap (); + // Intercept the body packets here so we can check when the stream + // ends or errors out + + let (mut body_tx, body_rx) = mpsc::channel (2); + + spawn (async move { + loop { + let item = body.next ().await; + + if let Some (item) = item { + if let Ok (bytes) = &item { + trace! ("Relaying {} bytes", bytes.len ()); + } + + if body_tx.send (item).await.is_err () { + error! ("Error relaying bytes"); + break; + } + } + else { + debug! ("Finished relaying bytes"); + break; + } + } + }); + + let body = Body::wrap_stream (body_rx); + match state.response_rendezvous.remove (&req_id) { Some ((_, tx)) => { + // UKAUFFY4 (Send half) match tx.send ((resp_parts, body)) { - Ok (()) => status_reply (StatusCode::OK, "Connected to remote client...\n"), - _ => status_reply (StatusCode::BAD_GATEWAY, "Failed to connect to client"), + Ok (()) => { + debug! ("Responding to server"); + status_reply (StatusCode::OK, "http_response completed.") + }, + _ => { + let msg = "Failed to connect to client"; + error! (msg); + status_reply (StatusCode::BAD_GATEWAY, msg) + }, } }, - None => status_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous"), + None => { + error! ("Server tried to respond to non-existent request"); + status_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous") + }, } } +// Clients will come here to start requests, and always park for at least +// a short amount of time. + async fn handle_http_request ( req: http::request::Parts, uri: String, @@ -286,6 +337,7 @@ async fn handle_http_request ( }, }; + // UKAUFFY4 (Receive half) match received { Ok ((parts, body)) => { let mut resp = Response::builder () @@ -393,8 +445,6 @@ pub fn load_templates () Ok (handlebars) } -use tracing::info; - pub async fn run_relay ( state: Arc , shutdown_oneshot: oneshot::Receiver <()> @@ -433,7 +483,6 @@ pub async fn run_relay ( let server = Server::bind (&addr) .serve (make_svc); - info! ("Configured for graceful shutdown"); server.with_graceful_shutdown (async { shutdown_oneshot.await.ok (); @@ -451,30 +500,5 @@ pub async fn run_relay ( #[cfg (test)] mod tests { - use std::time::Duration; - use tokio::{ - runtime::Runtime, - spawn, - sync::oneshot, - time::delay_for, - }; - - #[test] - fn so_crazy_it_might_work () { - let mut rt = Runtime::new ().unwrap (); - - rt.block_on (async { - let (tx, rx) = oneshot::channel (); - - let task_1 = spawn (async move { - delay_for (Duration::from_secs (1)).await; - tx.send (()).unwrap (); - }); - - rx.await.unwrap (); - - task_1.await.unwrap (); - }); - } } diff --git a/src/server/file_server.rs b/src/server/file_server.rs index b150296..1105993 100644 --- a/src/server/file_server.rs +++ b/src/server/file_server.rs @@ -21,6 +21,10 @@ use tokio::{ channel, }, }; +use tracing::{ + debug, error, info, + instrument, +}; use regex::Regex; @@ -33,7 +37,7 @@ fn parse_range_header (range_str: &str) -> (Option , Option ) { static ref RE: Regex = Regex::new (r"^bytes=(\d*)-(\d*)$").expect ("Couldn't compile regex for Range header"); } - println! ("{}", range_str); + debug! ("{}", range_str); let caps = match RE.captures (range_str) { Some (x) => x, @@ -98,6 +102,7 @@ async fn read_dir_entry (entry: DirEntry) -> TemplateDirEntry use std::borrow::Cow; +#[instrument (level = "debug", skip (handlebars, dir))] async fn serve_dir ( handlebars: &Handlebars <'static>, path: Cow <'_, str>, @@ -110,6 +115,8 @@ async fn serve_dir ( entries.push (read_dir_entry (entry).await); } + entries.sort_unstable_by (|a, b| a.file_name.partial_cmp (&b.file_name).unwrap ()); + #[derive (Serialize)] struct TemplateDirPage <'a> { path: Cow <'a, str>, @@ -132,6 +139,7 @@ async fn serve_dir ( resp } +#[instrument (level = "debug", skip (f))] async fn serve_file ( mut f: File, should_send_body: bool, @@ -157,7 +165,7 @@ async fn serve_file ( f.seek (SeekFrom::Start (start)).await.unwrap (); - println! ("Serving range {}-{}", start, end); + info! ("Serving range {}-{}", start, end); if should_send_body { tokio::spawn (async move { @@ -180,18 +188,18 @@ async fn serve_file ( } if tx.send (Ok::<_, Infallible> (buffer)).await.is_err () { - eprintln! ("Send failed while streaming file ({} bytes sent)", bytes_sent); + error! ("Send failed while streaming file ({} bytes sent)", bytes_sent); break; } bytes_left -= bytes_read; if bytes_left == 0 { - eprintln! ("Finished streaming file"); + info! ("Finished"); break; } bytes_sent += bytes_read; - println! ("Sent {} bytes", bytes_sent); + debug! ("Sent {} bytes", bytes_sent); //delay_for (Duration::from_millis (50)).await; } @@ -234,6 +242,7 @@ async fn serve_error ( resp } +#[instrument (level = "debug", skip (handlebars))] pub async fn serve_all ( handlebars: &Handlebars <'static>, root: &Path, @@ -243,7 +252,7 @@ pub async fn serve_all ( ) -> http_serde::Response { - println! ("Client requested {}", uri); + info! ("Client requested {}", uri); let mut range_start = None; let mut range_end = None; @@ -255,7 +264,14 @@ pub async fn serve_all ( range_end = end; } - let should_send_body = matches! (&method, http_serde::Method::Get); + let should_send_body = match &method { + http_serde::Method::Get => true, + http_serde::Method::Head => false, + m => { + debug! ("Unsupported method {:?}", m); + return serve_error (http_serde::StatusCode::MethodNotAllowed, "Unsupported method".into ()).await; + } + }; use percent_encoding::*; diff --git a/src/server/mod.rs b/src/server/mod.rs index 901b485..59050ca 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -16,7 +16,7 @@ use tokio::{ sync::oneshot, time::delay_for, }; -use tracing::info; +use tracing::{debug, error, info}; use crate::{ http_serde, @@ -40,13 +40,14 @@ fn status_reply (c: http_serde::StatusCode, body: &str) -> http_serde::Response } async fn handle_req_resp <'a> ( - state: Arc , + state: &Arc , req_resp: reqwest::Response ) { //println! ("Step 1"); if req_resp.status () != StatusCode::OK { // TODO: Error handling + error! ("http_listen didn't respond with 200 OK"); return; } @@ -54,15 +55,22 @@ async fn handle_req_resp <'a> ( let wrapped_reqs: Vec = match rmp_serde::from_read_ref (&body) { Ok (x) => x, - _ => return, + Err (e) => { + error! ("Can't parse wrapped requests: {:?}", e); + return; + }, }; + debug! ("Unwrapped {} requests", wrapped_reqs.len ()); + for wrapped_req in wrapped_reqs.into_iter () { let state = state.clone (); tokio::spawn (async move { let (req_id, parts) = (wrapped_req.id, wrapped_req.req); + debug! ("Handling request {}", req_id); + let response = if let Some (uri) = prefix_match (&parts.uri, "/files") { let default_root = PathBuf::from ("./"); let file_server_root: &std::path::Path = state.config.file_server_root @@ -78,6 +86,7 @@ async fn handle_req_resp <'a> ( ).await } else { + debug! ("404 not found"); status_reply (http_serde::StatusCode::NotFound, "404 Not Found") }; @@ -94,12 +103,16 @@ async fn handle_req_resp <'a> ( let req = resp_req.build ().unwrap (); - eprintln! ("{:?}", req.headers ()); + debug! ("{:?}", req.headers ()); //println! ("Step 6"); match state.client.execute (req).await { - Ok (r) => eprintln! ("{:?} {:?}", r.status (), r.text ().await.unwrap ()), - Err (e) => eprintln! ("Err: {:?}", e), + Ok (r) => { + let status = r.status (); + let text = r.text ().await.unwrap (); + debug! ("{:?} {:?}", status, text); + }, + Err (e) => error! ("Err: {:?}", e), } }); @@ -134,7 +147,7 @@ pub async fn run_server ( let tripcode = base64::encode (blake3::hash (config_file.api_key.as_bytes ()).as_bytes ()); - println! ("Our tripcode is {}", tripcode); + info! ("Our tripcode is {}", tripcode); let mut headers = reqwest::header::HeaderMap::new (); headers.insert ("X-ApiKey", config_file.api_key.try_into ().unwrap ()); @@ -170,7 +183,7 @@ pub async fn run_server ( } } - info! ("http_listen"); + debug! ("http_listen"); let req_req = state.client.get (&format! ("{}/http_listen/{}", state.config.relay_url, config_file.name)).send (); @@ -186,7 +199,7 @@ pub async fn run_server ( let req_resp = match req_req { Err (e) => { - eprintln! ("Err: {:?}", e); + error! ("Err: {:?}", e); backoff_delay = err_backoff_delay; continue; }, @@ -197,20 +210,18 @@ pub async fn run_server ( }; if req_resp.status () != StatusCode::OK { - eprintln! ("{}", req_resp.status ()); - eprintln! ("{}", String::from_utf8 (req_resp.bytes ().await.unwrap ().to_vec ()).unwrap ()); + error! ("{}", req_resp.status ()); + let body = req_resp.bytes ().await.unwrap (); + let body = String::from_utf8 (body.to_vec ()).unwrap (); + error! ("{}", body); backoff_delay = err_backoff_delay; continue; } - // Spawn another task for each request so we can - // immediately listen for the next connection + // Unpack the requests, spawn them into new tasks, then loop back + // around. - let state = state.clone (); - - tokio::spawn (async move { - handle_req_resp (state, req_resp).await; - }); + handle_req_resp (&state, req_resp).await; } info! ("Exiting"); diff --git a/todo.md b/todo.md index 3f9447b..259d4f9 100644 --- a/todo.md +++ b/todo.md @@ -13,7 +13,11 @@ - Server-side hash? - Log / audit log? -- Prevent directory traversal attacks +- Prevent directory traversal attacks in file_server.rs - Error handling - Reverse proxy to other local servers + +Off-project stuff: + +- Benchmark directory entry sorting