pub mod watcher; use std::{ error::Error, convert::Infallible, net::SocketAddr, sync::{ Arc }, time::{Duration}, }; use futures::channel::oneshot; use handlebars::Handlebars; use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; use hyper::service::{make_service_fn, service_fn}; use serde::Serialize; use tokio::{ sync::Mutex, time::delay_for, }; use crate::{ http_serde, }; use watcher::*; enum Message { HttpRequestResponse (http_serde::WrappedRequest), HttpResponseResponseStream ((http_serde::ResponseParts, Body)), } #[derive (Default)] struct ServerState { handlebars: Arc >, watchers: Arc >>, } fn status_reply > (status: StatusCode, b: B) -> Response { Response::builder ().status (status).body (b.into ()).unwrap () } async fn handle_http_listen (state: Arc , watcher_code: String) -> Response { //println! ("Step 1"); match Watchers::long_poll (state.watchers.clone (), watcher_code).await { Some (Message::HttpRequestResponse (parts)) => { //println! ("Step 3"); status_reply (StatusCode::OK, rmp_serde::to_vec (&parts).unwrap ()) }, _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "no\n"), } } async fn handle_http_response ( req: Request , state: Arc , req_id: String, ) -> Response { //println! ("Step 6"); let (parts, body) = req.into_parts (); let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap (); { let mut watchers = state.watchers.lock ().await; //println! ("Step 7"); if ! watchers.wake_one (Message::HttpResponseResponseStream ((resp_parts, body)), &req_id) { println! ("Step 8 (bad thing)"); status_reply (StatusCode::BAD_REQUEST, "A bad thing happened.\n") } else { //println! ("Step 8"); status_reply (StatusCode::OK, "ok\n") } } } async fn handle_http_request ( req: http::request::Parts, uri: String, state: Arc , watcher_code: String ) -> Response { let parts = { let id = ulid::Ulid::new ().to_string (); let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { Ok (x) => x, _ => return status_reply (StatusCode::BAD_REQUEST, "Bad request"), }; http_serde::WrappedRequest { id, req, } }; //println! ("Step 2 {}", parts.id); let (s, r) = oneshot::channel (); let timeout = Duration::from_secs (5); let id_2 = parts.id.clone (); { let mut that = state.watchers.lock ().await; that.add_watcher_with_id (s, id_2) } let req_id = parts.id.clone (); tokio::spawn (async move { { let mut watchers = state.watchers.lock ().await; //println! ("Step 3"); if ! watchers.wake_one (Message::HttpRequestResponse (parts), &watcher_code) { watchers.remove_watcher (&req_id); } } delay_for (timeout).await; { let mut that = state.watchers.lock ().await; that.remove_watcher (&req_id); } }); match r.await { Ok (Message::HttpResponseResponseStream ((resp_parts, body))) => { //println! ("Step 7"); let mut resp = Response::builder () .status (hyper::StatusCode::from (resp_parts.status_code)); for (k, v) in resp_parts.headers.into_iter () { resp = resp.header (&k, v); } resp .body (body) .unwrap () }, _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "server didn't reply in time or somethin'"), } } fn prefix_match <'a> (hay: &'a str, needle: &str) -> Option <&'a str> { if hay.starts_with (needle) { Some (&hay [needle.len ()..]) } else { None } } async fn handle_all (req: Request , state: Arc ) -> Result , Infallible> { let path = req.uri ().path (); //println! ("{}", path); if req.method () == Method::POST { // This is stuff the server can use. Clients can't // POST right now return Ok (if let Some (request_code) = prefix_match (path, "/7ZSFUKGV_http_response/") { let request_code = request_code.into (); handle_http_response (req, state, request_code).await } else { status_reply (StatusCode::BAD_REQUEST, "Can't POST this\n") }); } Ok (if let Some (listen_code) = prefix_match (path, "/7ZSFUKGV_http_listen/") { handle_http_listen (state, listen_code.into ()).await } else if let Some (rest) = prefix_match (path, "/servers/") { if rest == "" { #[derive (Serialize)] struct ServerEntry <'a> { path: &'a str, name: &'a str, } #[derive (Serialize)] struct ServerListPage <'a> { servers: Vec >, } let names: Vec <_> = { state.watchers.lock ().await.senders.iter () .map (|(k, _)| (*k).clone ()) .collect () }; //println! ("Found {} servers", names.len ()); let page = ServerListPage { servers: names.iter () .map (|name| ServerEntry { name: &name, path: &name, }) .collect (), }; let s = state.handlebars.render ("relay_server_list", &page).unwrap (); status_reply (StatusCode::OK, s) } else if let Some (idx) = rest.find ('/') { let listen_code = String::from (&rest [0..idx]); let path = String::from (&rest [idx..]); let (parts, _) = req.into_parts (); handle_http_request (parts, path, state, listen_code).await } else { status_reply (StatusCode::BAD_REQUEST, "Bad URI format") } } else if path == "/relay_up_check" { status_reply (StatusCode::OK, "Relay is up\n") } else { status_reply (StatusCode::OK, "Hi\n") }) } pub fn load_templates () -> Result , Box > { let mut handlebars = Handlebars::new (); handlebars.set_strict_mode (true); for (k, v) in vec! [ ("relay_server_list", "relay_server_list.html"), ].into_iter () { handlebars.register_template_file (k, format! ("ptth_handlebars/{}", v))?; } Ok (handlebars) } pub async fn main () -> Result <(), Box > { let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); let state = ServerState { handlebars: Arc::new (load_templates ()?), watchers: Default::default (), }; let state = Arc::new (state); let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { Ok::<_, Infallible> (service_fn (move |req| { let state = state.clone (); handle_all (req, state) })) } }); let server = Server::bind (&addr).serve (make_svc); server.await?; Ok (()) }