From 6b5208fdb44065e87ff4e35c90a4590f383bc31a Mon Sep 17 00:00:00 2001 From: _ <_@_> Date: Fri, 30 Oct 2020 17:57:36 -0500 Subject: [PATCH] :recycle: Move the relay binary into the lib what am i doing --- src/bin/relay.rs | 272 +------------------------------------ src/lib.rs | 2 +- src/relay/mod.rs | 271 ++++++++++++++++++++++++++++++++++++ src/{ => relay}/watcher.rs | 0 4 files changed, 274 insertions(+), 271 deletions(-) create mode 100644 src/relay/mod.rs rename src/{ => relay}/watcher.rs (100%) diff --git a/src/bin/relay.rs b/src/bin/relay.rs index 972928f..51bfde7 100644 --- a/src/bin/relay.rs +++ b/src/bin/relay.rs @@ -1,274 +1,6 @@ -use std::{ - collections::*, - error::Error, - convert::{ - Infallible, - TryFrom, - }, - iter::FromIterator, - net::SocketAddr, - sync::{ - Arc - }, - time::{Duration}, -}; - -use futures::channel::oneshot; -use hyper::{ - Body, - Method, - Request, - Response, - Server, - StatusCode, -}; -use hyper::service::{make_service_fn, service_fn}; - -use tokio::{ - sync::Mutex, - time::delay_for, -}; - -use ptth::{ - http_serde, - watcher::Watchers, -}; - -enum Message { - Meow, - HttpRequestResponse (http_serde::WrappedRequest), - HttpResponseResponseStream ((http_serde::ResponseParts, Body)), -} - -#[derive (Default)] -struct ServerState { - watchers: Arc >>, -} - -fn status_reply > (status: StatusCode, b: B) --> Response -{ - Response::builder ().status (status).body (b.into ()).unwrap () -} - -async fn handle_watch (state: Arc , watcher_code: String) --> Response -{ - match Watchers::long_poll (state.watchers.clone (), watcher_code).await { - None => status_reply (StatusCode::OK, "no\n"), - Some (_) => status_reply (StatusCode::OK, "actually, yes\n"), - } -} - -async fn handle_wake (state: Arc , watcher_code: String) --> Response -{ - let mut watchers = state.watchers.lock ().await; - - if watchers.wake_one (Message::Meow, &watcher_code) { - status_reply (StatusCode::OK, "ok\n") - } - else { - status_reply (StatusCode::BAD_REQUEST, "no\n") - } -} - -async fn handle_http_listen (state: Arc , watcher_code: String) --> Response -{ - //println! ("Step 1"); - match Watchers::long_poll (state.watchers.clone (), watcher_code).await { - Some (Message::HttpRequestResponse (parts)) => { - println! ("Step 3"); - status_reply (StatusCode::OK, rmp_serde::to_vec (&parts).unwrap ()) - }, - _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "no\n"), - } -} - -async fn handle_http_response ( - req: Request , - state: Arc , - req_id: String, -) - -> Response -{ - println! ("Step 6"); - let (parts, body) = req.into_parts (); - let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (ptth::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap (); - - { - let mut watchers = state.watchers.lock ().await; - - println! ("Step 7"); - if ! watchers.wake_one (Message::HttpResponseResponseStream ((resp_parts, body)), &req_id) - { - println! ("Step 8 (bad thing)"); - status_reply (StatusCode::BAD_REQUEST, "A bad thing happened.\n") - } - else { - println! ("Step 8"); - status_reply (StatusCode::OK, "ok\n") - } - } -} - -async fn handle_http_request ( - req: http::request::Parts, - uri: String, - state: Arc , - watcher_code: String -) - -> Response -{ - let parts = { - let id = ulid::Ulid::new ().to_string (); - let method = match ptth::http_serde::Method::try_from (req.method) { - Ok (x) => x, - _ => return status_reply (StatusCode::BAD_REQUEST, "Method not supported"), - }; - let headers = HashMap::from_iter ( - req.headers.into_iter () - .filter_map (|(k, v)| k.map (|k| (k, v))) - .map (|(k, v)| (String::from (k.as_str ()), v.as_bytes ().to_vec ())) - ); - - http_serde::WrappedRequest { - id, - req: http_serde::RequestParts { - method, - uri, - headers, - }, - } - }; - - println! ("Step 2 {}", parts.id); - - let (s, r) = oneshot::channel (); - let timeout = Duration::from_secs (5); - - let id_2 = parts.id.clone (); - { - let mut that = state.watchers.lock ().await; - that.add_watcher_with_id (s, id_2) - } - - let req_id = parts.id.clone (); - - tokio::spawn (async move { - { - let mut watchers = state.watchers.lock ().await; - - println! ("Step 3"); - if ! watchers.wake_one (Message::HttpRequestResponse (parts), &watcher_code) { - watchers.remove_watcher (&req_id); - } - } - - delay_for (timeout).await; - { - let mut that = state.watchers.lock ().await; - that.remove_watcher (&req_id); - } - }); - - match r.await { - Ok (Message::HttpResponseResponseStream ((resp_parts, body))) => { - println! ("Step 7"); - - let mut resp = Response::builder () - .status (hyper::StatusCode::from (resp_parts.status_code)); - - for (k, v) in resp_parts.headers.into_iter () { - resp = resp.header (&k, v); - } - - resp - .body (body) - .unwrap () - }, - _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "server didn't reply in time or somethin'"), - } -} - -fn prefix_match <'a> (hay: &'a str, needle: &str) -> Option <&'a str> -{ - if hay.starts_with (needle) { - Some (&hay [needle.len ()..]) - } - else { - None - } -} - -async fn handle_all (req: Request , state: Arc ) --> Result , Infallible> -{ - let path = req.uri ().path (); - //println! ("{}", path); - - if req.method () == Method::POST { - return Ok (if let Some (request_code) = prefix_match (path, "/http_response/") { - let request_code = request_code.into (); - handle_http_response (req, state, request_code).await - } - else { - status_reply (StatusCode::BAD_REQUEST, "Can't POST this\n") - }); - } - - if let Some (watch_code) = prefix_match (path, "/watch/") { - Ok (handle_watch (state, watch_code.into ()).await) - } - else if let Some (watch_code) = prefix_match (path, "/wake/") { - Ok (handle_wake (state, watch_code.into ()).await) - } - else if let Some (listen_code) = prefix_match (path, "/http_listen/") { - Ok (handle_http_listen (state, listen_code.into ()).await) - } - else if let Some (rest) = prefix_match (path, "/http_request/") { - if let Some (idx) = rest.find ('/') { - let listen_code = String::from (&rest [0..idx]); - let path = String::from (&rest [idx..]); - let (parts, _) = req.into_parts (); - - Ok (handle_http_request (parts, path, state, listen_code).await) - } - else { - Ok (status_reply (StatusCode::BAD_REQUEST, "Bad URI format")) - } - } - else { - Ok (status_reply (StatusCode::OK, "Hi\n")) - } -} - -async fn relay_main () -> Result <(), Box > { - let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); - - let state = Arc::new (ServerState::default ()); - - let make_svc = make_service_fn (|_conn| { - let state = state.clone (); - - async { - Ok::<_, Infallible> (service_fn (move |req| { - let state = state.clone (); - - handle_all (req, state) - })) - } - }); - - let server = Server::bind (&addr).serve (make_svc); - - server.await?; - - Ok (()) -} +use std::error::Error; #[tokio::main] async fn main () -> Result <(), Box > { - relay_main ().await + ptth::relay::relay_main ().await } diff --git a/src/lib.rs b/src/lib.rs index f130975..ec6a2d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ pub mod file_server; pub mod http_serde; -pub mod watcher; +pub mod relay; pub const PTTH_MAGIC_HEADER: &str = "X-PTTH-2LJYXWC4"; diff --git a/src/relay/mod.rs b/src/relay/mod.rs new file mode 100644 index 0000000..61b5b68 --- /dev/null +++ b/src/relay/mod.rs @@ -0,0 +1,271 @@ +pub mod watcher; + +use std::{ + collections::*, + error::Error, + convert::{ + Infallible, + TryFrom, + }, + iter::FromIterator, + net::SocketAddr, + sync::{ + Arc + }, + time::{Duration}, +}; + +use futures::channel::oneshot; +use hyper::{ + Body, + Method, + Request, + Response, + Server, + StatusCode, +}; +use hyper::service::{make_service_fn, service_fn}; + +use tokio::{ + sync::Mutex, + time::delay_for, +}; + +use crate::{ + http_serde, +}; +use watcher::*; + +enum Message { + Meow, + HttpRequestResponse (http_serde::WrappedRequest), + HttpResponseResponseStream ((http_serde::ResponseParts, Body)), +} + +#[derive (Default)] +struct ServerState { + watchers: Arc >>, +} + +fn status_reply > (status: StatusCode, b: B) +-> Response +{ + Response::builder ().status (status).body (b.into ()).unwrap () +} + +async fn handle_watch (state: Arc , watcher_code: String) +-> Response +{ + match Watchers::long_poll (state.watchers.clone (), watcher_code).await { + None => status_reply (StatusCode::OK, "no\n"), + Some (_) => status_reply (StatusCode::OK, "actually, yes\n"), + } +} + +async fn handle_wake (state: Arc , watcher_code: String) +-> Response +{ + let mut watchers = state.watchers.lock ().await; + + if watchers.wake_one (Message::Meow, &watcher_code) { + status_reply (StatusCode::OK, "ok\n") + } + else { + status_reply (StatusCode::BAD_REQUEST, "no\n") + } +} + +async fn handle_http_listen (state: Arc , watcher_code: String) +-> Response +{ + //println! ("Step 1"); + match Watchers::long_poll (state.watchers.clone (), watcher_code).await { + Some (Message::HttpRequestResponse (parts)) => { + println! ("Step 3"); + status_reply (StatusCode::OK, rmp_serde::to_vec (&parts).unwrap ()) + }, + _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "no\n"), + } +} + +async fn handle_http_response ( + req: Request , + state: Arc , + req_id: String, +) + -> Response +{ + println! ("Step 6"); + let (parts, body) = req.into_parts (); + let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap (); + + { + let mut watchers = state.watchers.lock ().await; + + println! ("Step 7"); + if ! watchers.wake_one (Message::HttpResponseResponseStream ((resp_parts, body)), &req_id) + { + println! ("Step 8 (bad thing)"); + status_reply (StatusCode::BAD_REQUEST, "A bad thing happened.\n") + } + else { + println! ("Step 8"); + status_reply (StatusCode::OK, "ok\n") + } + } +} + +async fn handle_http_request ( + req: http::request::Parts, + uri: String, + state: Arc , + watcher_code: String +) + -> Response +{ + let parts = { + let id = ulid::Ulid::new ().to_string (); + let method = match http_serde::Method::try_from (req.method) { + Ok (x) => x, + _ => return status_reply (StatusCode::BAD_REQUEST, "Method not supported"), + }; + let headers = HashMap::from_iter ( + req.headers.into_iter () + .filter_map (|(k, v)| k.map (|k| (k, v))) + .map (|(k, v)| (String::from (k.as_str ()), v.as_bytes ().to_vec ())) + ); + + http_serde::WrappedRequest { + id, + req: http_serde::RequestParts { + method, + uri, + headers, + }, + } + }; + + println! ("Step 2 {}", parts.id); + + let (s, r) = oneshot::channel (); + let timeout = Duration::from_secs (5); + + let id_2 = parts.id.clone (); + { + let mut that = state.watchers.lock ().await; + that.add_watcher_with_id (s, id_2) + } + + let req_id = parts.id.clone (); + + tokio::spawn (async move { + { + let mut watchers = state.watchers.lock ().await; + + println! ("Step 3"); + if ! watchers.wake_one (Message::HttpRequestResponse (parts), &watcher_code) { + watchers.remove_watcher (&req_id); + } + } + + delay_for (timeout).await; + { + let mut that = state.watchers.lock ().await; + that.remove_watcher (&req_id); + } + }); + + match r.await { + Ok (Message::HttpResponseResponseStream ((resp_parts, body))) => { + println! ("Step 7"); + + let mut resp = Response::builder () + .status (hyper::StatusCode::from (resp_parts.status_code)); + + for (k, v) in resp_parts.headers.into_iter () { + resp = resp.header (&k, v); + } + + resp + .body (body) + .unwrap () + }, + _ => status_reply (StatusCode::GATEWAY_TIMEOUT, "server didn't reply in time or somethin'"), + } +} + +fn prefix_match <'a> (hay: &'a str, needle: &str) -> Option <&'a str> +{ + if hay.starts_with (needle) { + Some (&hay [needle.len ()..]) + } + else { + None + } +} + +async fn handle_all (req: Request , state: Arc ) +-> Result , Infallible> +{ + let path = req.uri ().path (); + //println! ("{}", path); + + if req.method () == Method::POST { + return Ok (if let Some (request_code) = prefix_match (path, "/http_response/") { + let request_code = request_code.into (); + handle_http_response (req, state, request_code).await + } + else { + status_reply (StatusCode::BAD_REQUEST, "Can't POST this\n") + }); + } + + if let Some (watch_code) = prefix_match (path, "/watch/") { + Ok (handle_watch (state, watch_code.into ()).await) + } + else if let Some (watch_code) = prefix_match (path, "/wake/") { + Ok (handle_wake (state, watch_code.into ()).await) + } + else if let Some (listen_code) = prefix_match (path, "/http_listen/") { + Ok (handle_http_listen (state, listen_code.into ()).await) + } + else if let Some (rest) = prefix_match (path, "/http_request/") { + if let Some (idx) = rest.find ('/') { + let listen_code = String::from (&rest [0..idx]); + let path = String::from (&rest [idx..]); + let (parts, _) = req.into_parts (); + + Ok (handle_http_request (parts, path, state, listen_code).await) + } + else { + Ok (status_reply (StatusCode::BAD_REQUEST, "Bad URI format")) + } + } + else { + Ok (status_reply (StatusCode::OK, "Hi\n")) + } +} + +pub async fn relay_main () -> Result <(), Box > { + let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); + + let state = Arc::new (ServerState::default ()); + + let make_svc = make_service_fn (|_conn| { + let state = state.clone (); + + async { + Ok::<_, Infallible> (service_fn (move |req| { + let state = state.clone (); + + handle_all (req, state) + })) + } + }); + + let server = Server::bind (&addr).serve (make_svc); + + server.await?; + + Ok (()) +} diff --git a/src/watcher.rs b/src/relay/watcher.rs similarity index 100% rename from src/watcher.rs rename to src/relay/watcher.rs