diff --git a/crates/ptth_relay/src/lib.rs b/crates/ptth_relay/src/lib.rs index 1b94650..c3ed79f 100644 --- a/crates/ptth_relay/src/lib.rs +++ b/crates/ptth_relay/src/lib.rs @@ -21,9 +21,7 @@ use std::{ iter::FromIterator, net::SocketAddr, path::{Path, PathBuf}, - sync::{ - Arc, - }, + sync::Arc, time::Duration, }; @@ -33,10 +31,6 @@ use chrono::{ Utc }; use dashmap::DashMap; -use futures::{ - FutureExt, - stream::StreamExt, -}; use handlebars::Handlebars; use hyper::{ Body, @@ -51,15 +45,12 @@ use serde::{ Serialize, }; use tokio::{ - spawn, sync::{ Mutex, - mpsc, oneshot, RwLock, watch, }, - time::delay_for, }; use ptth_core::{ @@ -71,6 +62,7 @@ use ptth_core::{ pub mod config; pub mod errors; pub mod git_version; +mod server_endpoint; pub use config::Config; pub use errors::*; @@ -180,191 +172,6 @@ fn error_reply (status: StatusCode, b: &str) .body (format! ("{}\n", b).into ()) } -// Servers will come here and either handle queued requests from parked clients, -// or park themselves until a request comes in. - -async fn handle_http_listen ( - state: Arc , - watcher_code: String, - api_key: &[u8], -) --> Result , RequestError> -{ - use RequestRendezvous::*; - - let trip_error = || Ok (error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey")?); - - let expected_tripcode = { - let config = state.config.read ().await; - - match config.servers.get (&watcher_code) { - None => { - error! ("Denied http_listen for non-existent server name {}", watcher_code); - return trip_error (); - }, - Some (x) => (*x).tripcode, - } - }; - let actual_tripcode = blake3::hash (api_key); - - if expected_tripcode != actual_tripcode { - error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); - return trip_error (); - } - - // End of early returns - - { - let mut server_status = state.server_status.lock ().await; - - let mut status = server_status.entry (watcher_code.clone ()).or_insert_with (Default::default); - - status.last_seen = Utc::now (); - } - - let (tx, rx) = oneshot::channel (); - - { - let mut request_rendezvous = state.request_rendezvous.lock ().await; - - if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code) - { - if ! v.is_empty () { - // 1 or more clients were parked - Make the server - // handle them immediately - - debug! ("Sending {} parked requests to server {}", v.len (), watcher_code); - return Ok (ok_reply (rmp_serde::to_vec (&v)?)?); - } - } - - debug! ("Parking server {}", watcher_code); - request_rendezvous.insert (watcher_code.clone (), ParkedServer (tx)); - } - - // No clients were parked - make the server long-poll - - futures::select! { - x = rx.fuse () => match x { - Ok (Ok (one_req)) => { - debug! ("Unparking server {}", watcher_code); - Ok (ok_reply (rmp_serde::to_vec (&vec! [one_req])?)?) - }, - Ok (Err (ShuttingDownError::ShuttingDown)) => Ok (error_reply (StatusCode::SERVICE_UNAVAILABLE, "Server is shutting down, try again soon")?), - Err (_) => Ok (error_reply (StatusCode::INTERNAL_SERVER_ERROR, "Server error")?), - }, - _ = delay_for (Duration::from_secs (30)).fuse () => { - debug! ("Timed out http_listen for server {}", watcher_code); - return Ok (error_reply (StatusCode::NO_CONTENT, "No requests now, long-poll again")?) - } - } -} - -// Servers will come here to stream responses to clients - -async fn handle_http_response ( - req: Request , - state: Arc , - req_id: String, -) --> Result , HandleHttpResponseError> -{ - #[derive (Debug)] - enum BodyFinishedReason { - StreamFinished, - ClientDisconnected, - } - use BodyFinishedReason::*; - use HandleHttpResponseError::*; - - let (parts, mut body) = req.into_parts (); - - let magic_header = parts.headers.get (ptth_core::PTTH_MAGIC_HEADER).ok_or (MissingPtthMagicHeader)?; - - let magic_header = base64::decode (magic_header).map_err (PtthMagicHeaderNotBase64)?; - - let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&magic_header).map_err (PtthMagicHeaderNotMsgPack)?; - - // Intercept the body packets here so we can check when the stream - // ends or errors out - - let (mut body_tx, body_rx) = mpsc::channel (2); - let (body_finished_tx, body_finished_rx) = oneshot::channel (); - let mut shutdown_watch_rx = state.shutdown_watch_rx.clone (); - - let relay_task = spawn (async move { - if shutdown_watch_rx.recv ().await == Some (false) { - loop { - let item = body.next ().await; - - if let Some (item) = item { - if let Ok (bytes) = &item { - trace! ("Relaying {} bytes", bytes.len ()); - } - - futures::select! { - x = body_tx.send (item).fuse () => if let Err (_) = x { - info! ("Body closed while relaying. (Client hung up?)"); - body_finished_tx.send (ClientDisconnected).map_err (|_| LostServer)?; - break; - }, - _ = shutdown_watch_rx.recv ().fuse () => { - debug! ("Closing stream: relay is shutting down"); - break; - }, - } - } - else { - debug! ("Finished relaying bytes"); - body_finished_tx.send (StreamFinished).map_err (|_| LostServer)?; - break; - } - } - } - else { - debug! ("Can't relay bytes, relay is shutting down"); - } - - Ok::<(), HandleHttpResponseError> (()) - }); - - let body = Body::wrap_stream (body_rx); - - let tx = { - let response_rendezvous = state.response_rendezvous.read ().await; - match response_rendezvous.remove (&req_id) { - None => { - error! ("Server tried to respond to non-existent request"); - return Ok (error_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous")?); - }, - Some ((_, x)) => x, - } - }; - - // UKAUFFY4 (Send half) - if tx.send (Ok ((resp_parts, body))).is_err () { - let msg = "Failed to connect to client"; - error! (msg); - return Ok (error_reply (StatusCode::BAD_GATEWAY, msg)?); - } - - relay_task.await??; - - debug! ("Connected server to client for streaming."); - match body_finished_rx.await { - Ok (StreamFinished) => { - Ok (error_reply (StatusCode::OK, "StreamFinished")?) - }, - Ok (ClientDisconnected) => { - Ok (error_reply (StatusCode::OK, "ClientDisconnected")?) - }, - Err (e) => { - debug! ("body_finished_rx {}", e); - Ok (error_reply (StatusCode::OK, "body_finished_rx Err")?) - }, - } -} - // Clients will come here to start requests, and always park for at least // a short amount of time. @@ -616,7 +423,7 @@ async fn handle_all (req: Request , state: Arc ) return if let Some (request_code) = prefix_match ("/7ZSFUKGV/http_response/", path) { let request_code = request_code.into (); - Ok (handle_http_response (req, state, request_code).await?) + Ok (server_endpoint::handle_response (req, state, request_code).await?) } else { Ok (error_reply (StatusCode::BAD_REQUEST, "Can't POST this")?) @@ -628,7 +435,7 @@ async fn handle_all (req: Request , state: Arc ) None => return Ok (error_reply (StatusCode::UNAUTHORIZED, "Can't register as server without an API key")?), Some (x) => x, }; - handle_http_listen (state, listen_code.into (), api_key.as_bytes ()).await + server_endpoint::handle_listen (state, listen_code.into (), api_key.as_bytes ()).await } else if let Some (rest) = prefix_match ("/frontend/servers/", path) { if rest == "" { diff --git a/crates/ptth_relay/src/server_endpoint.rs b/crates/ptth_relay/src/server_endpoint.rs new file mode 100644 index 0000000..dc4083a --- /dev/null +++ b/crates/ptth_relay/src/server_endpoint.rs @@ -0,0 +1,227 @@ +use std::{ + sync::Arc, + time::Duration, +}; + +use chrono::Utc; +use futures::{ + FutureExt, + stream::StreamExt, +}; +use hyper::{ + Body, + Response, + Request, + StatusCode, +}; +use tokio::{ + spawn, + sync::{ + mpsc, + oneshot, + }, + time::delay_for, +}; + +use ptth_core::{ + http_serde, + prelude::*, +}; + +use super::{ + error_reply, + errors::{ + RequestError, + ShuttingDownError, + }, + HandleHttpResponseError, + ok_reply, + RelayState, +}; + +// Servers will come here and either handle queued requests from parked clients, +// or park themselves until a request comes in. +// Step 1 + +pub async fn handle_listen ( + state: Arc , + watcher_code: String, + api_key: &[u8], +) +-> Result , RequestError> +{ + use super::RequestRendezvous::*; + + let trip_error = || Ok (error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey")?); + + let expected_tripcode = { + let config = state.config.read ().await; + + match config.servers.get (&watcher_code) { + None => { + error! ("Denied http_listen for non-existent server name {}", watcher_code); + return trip_error (); + }, + Some (x) => (*x).tripcode, + } + }; + let actual_tripcode = blake3::hash (api_key); + + if expected_tripcode != actual_tripcode { + error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); + return trip_error (); + } + + // End of early returns + + { + let mut server_status = state.server_status.lock ().await; + + let mut status = server_status.entry (watcher_code.clone ()).or_insert_with (Default::default); + + status.last_seen = Utc::now (); + } + + let (tx, rx) = oneshot::channel (); + + { + let mut request_rendezvous = state.request_rendezvous.lock ().await; + + if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code) + { + if ! v.is_empty () { + // 1 or more clients were parked - Make the server + // handle them immediately + + debug! ("Sending {} parked requests to server {}", v.len (), watcher_code); + return Ok (ok_reply (rmp_serde::to_vec (&v)?)?); + } + } + + debug! ("Parking server {}", watcher_code); + request_rendezvous.insert (watcher_code.clone (), ParkedServer (tx)); + } + + // No clients were parked - make the server long-poll + + futures::select! { + x = rx.fuse () => match x { + Ok (Ok (one_req)) => { + debug! ("Unparking server {}", watcher_code); + Ok (ok_reply (rmp_serde::to_vec (&vec! [one_req])?)?) + }, + Ok (Err (ShuttingDownError::ShuttingDown)) => Ok (error_reply (StatusCode::SERVICE_UNAVAILABLE, "Server is shutting down, try again soon")?), + Err (_) => Ok (error_reply (StatusCode::INTERNAL_SERVER_ERROR, "Server error")?), + }, + _ = delay_for (Duration::from_secs (30)).fuse () => { + debug! ("Timed out http_listen for server {}", watcher_code); + return Ok (error_reply (StatusCode::NO_CONTENT, "No requests now, long-poll again")?) + } + } +} + +// Servers will come here to stream responses to clients +// Step 5 + +pub async fn handle_response ( + req: Request , + state: Arc , + req_id: String, +) +-> Result , HandleHttpResponseError> +{ + #[derive (Debug)] + enum BodyFinishedReason { + StreamFinished, + ClientDisconnected, + } + use BodyFinishedReason::*; + use HandleHttpResponseError::*; + + let (parts, mut body) = req.into_parts (); + + let magic_header = parts.headers.get (ptth_core::PTTH_MAGIC_HEADER).ok_or (MissingPtthMagicHeader)?; + + let magic_header = base64::decode (magic_header).map_err (PtthMagicHeaderNotBase64)?; + + let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&magic_header).map_err (PtthMagicHeaderNotMsgPack)?; + + // Intercept the body packets here so we can check when the stream + // ends or errors out + + let (mut body_tx, body_rx) = mpsc::channel (2); + let (body_finished_tx, body_finished_rx) = oneshot::channel (); + let mut shutdown_watch_rx = state.shutdown_watch_rx.clone (); + + let relay_task = spawn (async move { + if shutdown_watch_rx.recv ().await == Some (false) { + loop { + let item = body.next ().await; + + if let Some (item) = item { + if let Ok (bytes) = &item { + trace! ("Relaying {} bytes", bytes.len ()); + } + + futures::select! { + x = body_tx.send (item).fuse () => if let Err (_) = x { + info! ("Body closed while relaying. (Client hung up?)"); + body_finished_tx.send (ClientDisconnected).map_err (|_| LostServer)?; + break; + }, + _ = shutdown_watch_rx.recv ().fuse () => { + debug! ("Closing stream: relay is shutting down"); + break; + }, + } + } + else { + debug! ("Finished relaying bytes"); + body_finished_tx.send (StreamFinished).map_err (|_| LostServer)?; + break; + } + } + } + else { + debug! ("Can't relay bytes, relay is shutting down"); + } + + Ok::<(), HandleHttpResponseError> (()) + }); + + let body = Body::wrap_stream (body_rx); + + let tx = { + let response_rendezvous = state.response_rendezvous.read ().await; + match response_rendezvous.remove (&req_id) { + None => { + error! ("Server tried to respond to non-existent request"); + return Ok (error_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous")?); + }, + Some ((_, x)) => x, + } + }; + + // UKAUFFY4 (Send half) + if tx.send (Ok ((resp_parts, body))).is_err () { + let msg = "Failed to connect to client"; + error! (msg); + return Ok (error_reply (StatusCode::BAD_GATEWAY, msg)?); + } + + relay_task.await??; + + debug! ("Connected server to client for streaming."); + match body_finished_rx.await { + Ok (StreamFinished) => { + Ok (error_reply (StatusCode::OK, "StreamFinished")?) + }, + Ok (ClientDisconnected) => { + Ok (error_reply (StatusCode::OK, "ClientDisconnected")?) + }, + Err (e) => { + debug! ("body_finished_rx {}", e); + Ok (error_reply (StatusCode::OK, "body_finished_rx Err")?) + }, + } +}