ptth/crates/ptth_relay/src/server_endpoint.rs

230 lines
5.7 KiB
Rust
Raw Normal View History

use std::{
sync::Arc,
time::Duration,
};
use chrono::Utc;
use futures::{
FutureExt,
stream::StreamExt,
};
use hyper::{
Body,
Response,
Request,
StatusCode,
};
use tokio::{
spawn,
sync::{
mpsc,
oneshot,
},
time::delay_for,
};
use ptth_core::{
http_serde,
prelude::*,
};
use super::{
error_reply,
errors::{
RequestError,
ShuttingDownError,
},
HandleHttpResponseError,
ok_reply,
RelayState,
};
// Servers will come here and either handle queued requests from parked clients,
// or park themselves until a request comes in.
// Step 1
pub async fn handle_listen (
state: Arc <RelayState>,
watcher_code: String,
api_key: &[u8],
)
-> Result <Response <Body>, RequestError>
{
use super::RequestRendezvous::*;
let trip_error = || Ok (error_reply (StatusCode::UNAUTHORIZED, "Bad X-ApiKey")?);
let expected_tripcode = {
let config = state.config.read ().await;
match config.servers.get (&watcher_code) {
None => {
error! ("Denied http_listen for non-existent server name {}", watcher_code);
return trip_error ();
},
Some (x) => *(*x).tripcode,
}
};
let actual_tripcode = blake3::hash (api_key);
if expected_tripcode != actual_tripcode {
error! ("Denied http_listen for bad tripcode {}", base64::encode (actual_tripcode.as_bytes ()));
return trip_error ();
}
// End of early returns
{
// TODO: Move into relay_state.rs
let mut server_status = state.server_status.lock ().await;
let mut status = server_status.entry (watcher_code.clone ()).or_insert_with (Default::default);
status.last_seen = Utc::now ();
}
let (tx, rx) = oneshot::channel ();
{
let mut request_rendezvous = state.request_rendezvous.lock ().await;
if let Some (ParkedClients (v)) = request_rendezvous.remove (&watcher_code)
{
if ! v.is_empty () {
// 1 or more clients were parked - Make the server
// handle them immediately
debug! ("Sending {} parked requests to server {}", v.len (), watcher_code);
return Ok (ok_reply (rmp_serde::to_vec (&v)?)?);
}
}
trace! ("Parking server {}", watcher_code);
request_rendezvous.insert (watcher_code.clone (), ParkedServer (tx));
}
// No clients were parked - make the server long-poll
futures::select! {
x = rx.fuse () => match x {
Ok (Ok (one_req)) => {
trace! ("Unparking server {}", watcher_code);
Ok (ok_reply (rmp_serde::to_vec (&vec! [one_req])?)?)
},
Ok (Err (ShuttingDownError::ShuttingDown)) => Ok (error_reply (StatusCode::SERVICE_UNAVAILABLE, "Server is shutting down, try again soon")?),
Err (_) => Ok (error_reply (StatusCode::INTERNAL_SERVER_ERROR, "Server error")?),
},
_ = delay_for (Duration::from_secs (30)).fuse () => {
trace! ("Timed out http_listen for server {}", watcher_code);
return Ok (error_reply (StatusCode::NO_CONTENT, "No requests now, long-poll again")?)
}
}
}
// Servers will come here to stream responses to clients
// Step 5
pub async fn handle_response (
req: Request <Body>,
state: Arc <RelayState>,
req_id: String,
)
-> Result <Response <Body>, HandleHttpResponseError>
{
#[derive (Debug)]
enum BodyFinishedReason {
StreamFinished,
ClientDisconnected,
}
use BodyFinishedReason::*;
use HandleHttpResponseError::*;
let (parts, mut body) = req.into_parts ();
let magic_header = parts.headers.get (ptth_core::PTTH_MAGIC_HEADER).ok_or (MissingPtthMagicHeader)?;
let magic_header = base64::decode (magic_header).map_err (PtthMagicHeaderNotBase64)?;
let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&magic_header).map_err (PtthMagicHeaderNotMsgPack)?;
// Intercept the body packets here so we can check when the stream
// ends or errors out
let (mut body_tx, body_rx) = mpsc::channel (2);
let (body_finished_tx, body_finished_rx) = oneshot::channel ();
let mut shutdown_watch_rx = state.shutdown_watch_rx.clone ();
let relay_task = spawn (async move {
if shutdown_watch_rx.recv ().await == Some (false) {
loop {
let item = body.next ().await;
if let Some (item) = item {
if let Ok (bytes) = &item {
trace! ("Relaying {} bytes", bytes.len ());
}
futures::select! {
x = body_tx.send (item).fuse () => if let Err (_) = x {
info! ("Body closed while relaying. (Client hung up?)");
body_finished_tx.send (ClientDisconnected).map_err (|_| LostServer)?;
break;
},
_ = shutdown_watch_rx.recv ().fuse () => {
debug! ("Closing stream: relay is shutting down");
break;
},
}
}
else {
trace! ("Finished relaying bytes");
body_finished_tx.send (StreamFinished).map_err (|_| LostServer)?;
break;
}
}
}
else {
debug! ("Can't relay bytes, relay is shutting down");
}
Ok::<(), HandleHttpResponseError> (())
});
let body = Body::wrap_stream (body_rx);
let tx = {
let response_rendezvous = state.response_rendezvous.read ().await;
match response_rendezvous.remove (&req_id) {
None => {
error! ("Server tried to respond to non-existent request");
return Ok (error_reply (StatusCode::BAD_REQUEST, "Request ID not found in response_rendezvous")?);
},
Some ((_, x)) => x,
}
};
// UKAUFFY4 (Send half)
if tx.send (Ok ((resp_parts, body))).is_err () {
let msg = "Failed to connect to client";
error! (msg);
return Ok (error_reply (StatusCode::BAD_GATEWAY, msg)?);
}
relay_task.await??;
trace! ("Connected server to client for streaming.");
match body_finished_rx.await {
Ok (StreamFinished) => {
Ok (error_reply (StatusCode::OK, "StreamFinished")?)
},
Ok (ClientDisconnected) => {
Ok (error_reply (StatusCode::OK, "ClientDisconnected")?)
},
Err (e) => {
debug! ("body_finished_rx {}", e);
Ok (error_reply (StatusCode::OK, "body_finished_rx Err")?)
},
}
}