pub mod watcher; use std::{ error::Error, convert::Infallible, net::SocketAddr, sync::{ Arc }, time::{Duration}, }; use futures::channel::oneshot; use hyper::{ Body, Method, Request, Response, Server, StatusCode, }; use hyper::service::{make_service_fn, service_fn}; use tokio::{ sync::Mutex, time::delay_for, }; use crate::{ http_serde, }; use watcher::*; enum Message { HttpRequestResponse (http_serde::WrappedRequest), HttpResponseResponseStream ((http_serde::ResponseParts, Body)), } #[derive (Default)] struct ServerState { watchers: Arc >>, } fn status_reply > (status: StatusCode, b: B) -> Response { Response::builder ().status (status).body (b.into ()).unwrap () } async fn handle_http_listen (state: Arc , watcher_code: String) -> Response { //println! ("Step 1"); match Watchers::long_poll (state.watchers.clone (), watcher_code).await { Some (Message::HttpRequestResponse (parts)) => { //println! ("Step 3"); status_reply (StatusCode::OK, rmp_serde::to_vec (&parts).unwrap ()) }, _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "no\n"), } } async fn handle_http_response ( req: Request , state: Arc , req_id: String, ) -> Response { //println! ("Step 6"); let (parts, body) = req.into_parts (); let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap (); { let mut watchers = state.watchers.lock ().await; //println! ("Step 7"); if ! watchers.wake_one (Message::HttpResponseResponseStream ((resp_parts, body)), &req_id) { println! ("Step 8 (bad thing)"); status_reply (StatusCode::BAD_REQUEST, "A bad thing happened.\n") } else { //println! ("Step 8"); status_reply (StatusCode::OK, "ok\n") } } } async fn handle_http_request ( req: http::request::Parts, uri: String, state: Arc , watcher_code: String ) -> Response { let parts = { let id = ulid::Ulid::new ().to_string (); let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { Ok (x) => x, _ => return status_reply (StatusCode::BAD_REQUEST, "Bad request"), }; http_serde::WrappedRequest { id, req, } }; //println! ("Step 2 {}", parts.id); let (s, r) = oneshot::channel (); let timeout = Duration::from_secs (5); let id_2 = parts.id.clone (); { let mut that = state.watchers.lock ().await; that.add_watcher_with_id (s, id_2) } let req_id = parts.id.clone (); tokio::spawn (async move { { let mut watchers = state.watchers.lock ().await; //println! ("Step 3"); if ! watchers.wake_one (Message::HttpRequestResponse (parts), &watcher_code) { watchers.remove_watcher (&req_id); } } delay_for (timeout).await; { let mut that = state.watchers.lock ().await; that.remove_watcher (&req_id); } }); match r.await { Ok (Message::HttpResponseResponseStream ((resp_parts, body))) => { //println! ("Step 7"); let mut resp = Response::builder () .status (hyper::StatusCode::from (resp_parts.status_code)); for (k, v) in resp_parts.headers.into_iter () { resp = resp.header (&k, v); } resp .body (body) .unwrap () }, _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "server didn't reply in time or somethin'"), } } fn prefix_match <'a> (hay: &'a str, needle: &str) -> Option <&'a str> { if hay.starts_with (needle) { Some (&hay [needle.len ()..]) } else { None } } async fn handle_all (req: Request , state: Arc ) -> Result , Infallible> { let path = req.uri ().path (); //println! ("{}", path); if req.method () == Method::POST { // This is stuff the server can use. Clients can't // POST right now return Ok (if let Some (request_code) = prefix_match (path, "/7ZSFUKGV_http_response/") { let request_code = request_code.into (); handle_http_response (req, state, request_code).await } else { status_reply (StatusCode::BAD_REQUEST, "Can't POST this\n") }); } if let Some (listen_code) = prefix_match (path, "/7ZSFUKGV_http_listen/") { Ok (handle_http_listen (state, listen_code.into ()).await) } else if let Some (rest) = prefix_match (path, "/servers/") { if let Some (idx) = rest.find ('/') { let listen_code = String::from (&rest [0..idx]); let path = String::from (&rest [idx..]); let (parts, _) = req.into_parts (); Ok (handle_http_request (parts, path, state, listen_code).await) } else { Ok (status_reply (StatusCode::BAD_REQUEST, "Bad URI format")) } } else if path == "/relay_up_check" { Ok (status_reply (StatusCode::OK, "Relay is up\n")) } else { Ok (status_reply (StatusCode::OK, "Hi\n")) } } pub async fn main () -> Result <(), Box > { let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); let state = Arc::new (ServerState::default ()); let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { Ok::<_, Infallible> (service_fn (move |req| { let state = state.clone (); handle_all (req, state) })) } }); let server = Server::bind (&addr).serve (make_svc); server.await?; Ok (()) }