use std::{ sync::Arc, time::Duration, }; use chrono::Utc; use futures::{ FutureExt, stream::StreamExt, }; use hyper::{ Body, Response, Request, StatusCode, }; use tokio::{ spawn, sync::{ mpsc, oneshot, }, }; use tokio_stream::wrappers::ReceiverStream; use ptth_core::{ http_serde, prelude::*, }; use super::{ error_reply, errors::{ RequestError, ShuttingDownError, }, HandleHttpResponseError, ok_reply, RelayState, }; // Servers will come here and either handle queued requests from parked clients, // or park themselves until a request comes in. // Step 1 pub async fn handle_listen ( state: Arc , watcher_code: String, api_key: &[u8], ) -> Result , RequestError> { use super::RequestRendezvous::*; let trip_error = || Ok (error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey")?); let expected_tripcode = { let config = state.config.read ().await; match config.servers.get (&watcher_code) { None => { error! ("Denied http_listen for non-existent server name {}", watcher_code); return trip_error (); }, Some (x) => *(*x).tripcode, } }; let actual_tripcode = blake3::hash (api_key); if expected_tripcode != actual_tripcode { error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ())); return trip_error (); } // End of early returns { // TODO: Move into relay_state.rs let mut server_status = state.server_status.lock ().await; let mut status = server_status.entry (watcher_code.clone ()).or_insert_with (Default::default); status.last_seen = Utc::now (); } let (tx, rx) = oneshot::channel (); { let mut request_rendezvous = state.request_rendezvous.lock ().await; if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code) { if ! v.is_empty () { // 1 or more clients were parked - Make the server // handle them immediately debug! ("Sending {} parked requests to server {}", v.len (), watcher_code); return Ok (ok_reply (rmp_serde::to_vec (&v)?)?); } } trace! ("Parking server {}", watcher_code); request_rendezvous.insert (watcher_code.clone (), ParkedServer (tx)); } // No clients were parked - make the server long-poll futures::select! { x = rx.fuse () => match x { Ok (Ok (one_req)) => { trace! ("Unparking server {}", watcher_code); Ok (ok_reply (rmp_serde::to_vec (&vec! [one_req])?)?) }, Ok (Err (ShuttingDownError::ShuttingDown)) => Ok (error_reply (StatusCode::SERVICE_UNAVAILABLE, "Server is shutting down, try again soon")?), Err (_) => Ok (error_reply (StatusCode::INTERNAL_SERVER_ERROR, "Server error")?), }, _ = tokio::time::sleep (Duration::from_secs (30)).fuse () => { trace! ("Timed out http_listen for server {}", watcher_code); return Ok (error_reply (StatusCode::NO_CONTENT, "No requests now, long-poll again")?) } } } // Servers will come here to stream responses to clients // Step 5 pub async fn handle_response ( req: Request , state: Arc , req_id: String, ) -> Result , HandleHttpResponseError> { #[derive (Debug)] enum BodyFinishedReason { StreamFinished, ClientDisconnected, } use BodyFinishedReason::*; use HandleHttpResponseError::*; let (parts, mut body) = req.into_parts (); let magic_header = parts.headers.get (ptth_core::PTTH_MAGIC_HEADER).ok_or (MissingPtthMagicHeader)?; let magic_header = base64::decode (magic_header).map_err (PtthMagicHeaderNotBase64)?; let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&magic_header).map_err (PtthMagicHeaderNotMsgPack)?; // Intercept the body packets here so we can check when the stream // ends or errors out let (body_tx, body_rx) = mpsc::channel (2); let (body_finished_tx, body_finished_rx) = oneshot::channel (); let mut shutdown_watch_rx = state.shutdown_watch_rx.clone (); let relay_task = spawn (async move { if *shutdown_watch_rx.borrow () == false { loop { let item = body.next ().await; if let Some (item) = item { if let Ok (bytes) = &item { trace! ("Relaying {} bytes", bytes.len ()); } futures::select! { x = body_tx.send (item).fuse () => if let Err (_) = x { info! ("Body closed while relaying. (Client hung up?)"); body_finished_tx.send (ClientDisconnected).map_err (|_| LostServer)?; break; }, _ = shutdown_watch_rx.changed ().fuse () => { debug! ("Closing stream: relay is shutting down"); break; }, } } else { trace! ("Finished relaying bytes"); body_finished_tx.send (StreamFinished).map_err (|_| LostServer)?; break; } } } else { debug! ("Can't relay bytes, relay is shutting down"); } Ok::<(), HandleHttpResponseError> (()) }); let body = Body::wrap_stream (ReceiverStream::new (body_rx)); let tx = { let response_rendezvous = state.response_rendezvous.read ().await; match response_rendezvous.remove (&req_id) { None => { error! ("Server tried to respond to non-existent request"); return Ok (error_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous")?); }, Some ((_, x)) => x, } }; // UKAUFFY4 (Send half) if tx.send (Ok ((resp_parts, body))).is_err () { let msg = "Failed to connect to client"; error! (msg); return Ok (error_reply (StatusCode::BAD_GATEWAY, msg)?); } relay_task.await??; trace! ("Connected server to client for streaming."); match body_finished_rx.await { Ok (StreamFinished) => { Ok (error_reply (StatusCode::OK, "StreamFinished")?) }, Ok (ClientDisconnected) => { Ok (error_reply (StatusCode::OK, "ClientDisconnected")?) }, Err (e) => { debug! ("body_finished_rx {}", e); Ok (error_reply (StatusCode::OK, "body_finished_rx Err")?) }, } }