♻️ Move the relay binary into the lib

what am i doing
main
_ 2020-10-30 17:57:36 -05:00
parent 8d3fc71dcf
commit 6b5208fdb4
4 changed files with 274 additions and 271 deletions

View File

@ -1,274 +1,6 @@
use std::{ use std::error::Error;
collections::*,
error::Error,
convert::{
Infallible,
TryFrom,
},
iter::FromIterator,
net::SocketAddr,
sync::{
Arc
},
time::{Duration},
};
use futures::channel::oneshot;
use hyper::{
Body,
Method,
Request,
Response,
Server,
StatusCode,
};
use hyper::service::{make_service_fn, service_fn};
use tokio::{
sync::Mutex,
time::delay_for,
};
use ptth::{
http_serde,
watcher::Watchers,
};
enum Message {
Meow,
HttpRequestResponse (http_serde::WrappedRequest),
HttpResponseResponseStream ((http_serde::ResponseParts, Body)),
}
#[derive (Default)]
struct ServerState {
watchers: Arc <Mutex <Watchers <Message>>>,
}
fn status_reply <B: Into <Body>> (status: StatusCode, b: B)
-> Response <Body>
{
Response::builder ().status (status).body (b.into ()).unwrap ()
}
async fn handle_watch (state: Arc <ServerState>, watcher_code: String)
-> Response <Body>
{
match Watchers::long_poll (state.watchers.clone (), watcher_code).await {
None => status_reply (StatusCode::OK, "no\n"),
Some (_) => status_reply (StatusCode::OK, "actually, yes\n"),
}
}
async fn handle_wake (state: Arc <ServerState>, watcher_code: String)
-> Response <Body>
{
let mut watchers = state.watchers.lock ().await;
if watchers.wake_one (Message::Meow, &watcher_code) {
status_reply (StatusCode::OK, "ok\n")
}
else {
status_reply (StatusCode::BAD_REQUEST, "no\n")
}
}
async fn handle_http_listen (state: Arc <ServerState>, watcher_code: String)
-> Response <Body>
{
//println! ("Step 1");
match Watchers::long_poll (state.watchers.clone (), watcher_code).await {
Some (Message::HttpRequestResponse (parts)) => {
println! ("Step 3");
status_reply (StatusCode::OK, rmp_serde::to_vec (&parts).unwrap ())
},
_ => status_reply (StatusCode::GATEWAY_TIMEOUT, "no\n"),
}
}
async fn handle_http_response (
req: Request <Body>,
state: Arc <ServerState>,
req_id: String,
)
-> Response <Body>
{
println! ("Step 6");
let (parts, body) = req.into_parts ();
let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (ptth::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap ();
{
let mut watchers = state.watchers.lock ().await;
println! ("Step 7");
if ! watchers.wake_one (Message::HttpResponseResponseStream ((resp_parts, body)), &req_id)
{
println! ("Step 8 (bad thing)");
status_reply (StatusCode::BAD_REQUEST, "A bad thing happened.\n")
}
else {
println! ("Step 8");
status_reply (StatusCode::OK, "ok\n")
}
}
}
async fn handle_http_request (
req: http::request::Parts,
uri: String,
state: Arc <ServerState>,
watcher_code: String
)
-> Response <Body>
{
let parts = {
let id = ulid::Ulid::new ().to_string ();
let method = match ptth::http_serde::Method::try_from (req.method) {
Ok (x) => x,
_ => return status_reply (StatusCode::BAD_REQUEST, "Method not supported"),
};
let headers = HashMap::from_iter (
req.headers.into_iter ()
.filter_map (|(k, v)| k.map (|k| (k, v)))
.map (|(k, v)| (String::from (k.as_str ()), v.as_bytes ().to_vec ()))
);
http_serde::WrappedRequest {
id,
req: http_serde::RequestParts {
method,
uri,
headers,
},
}
};
println! ("Step 2 {}", parts.id);
let (s, r) = oneshot::channel ();
let timeout = Duration::from_secs (5);
let id_2 = parts.id.clone ();
{
let mut that = state.watchers.lock ().await;
that.add_watcher_with_id (s, id_2)
}
let req_id = parts.id.clone ();
tokio::spawn (async move {
{
let mut watchers = state.watchers.lock ().await;
println! ("Step 3");
if ! watchers.wake_one (Message::HttpRequestResponse (parts), &watcher_code) {
watchers.remove_watcher (&req_id);
}
}
delay_for (timeout).await;
{
let mut that = state.watchers.lock ().await;
that.remove_watcher (&req_id);
}
});
match r.await {
Ok (Message::HttpResponseResponseStream ((resp_parts, body))) => {
println! ("Step 7");
let mut resp = Response::builder ()
.status (hyper::StatusCode::from (resp_parts.status_code));
for (k, v) in resp_parts.headers.into_iter () {
resp = resp.header (&k, v);
}
resp
.body (body)
.unwrap ()
},
_ => status_reply (StatusCode::GATEWAY_TIMEOUT, "server didn't reply in time or somethin'"),
}
}
fn prefix_match <'a> (hay: &'a str, needle: &str) -> Option <&'a str>
{
if hay.starts_with (needle) {
Some (&hay [needle.len ()..])
}
else {
None
}
}
async fn handle_all (req: Request <Body>, state: Arc <ServerState>)
-> Result <Response <Body>, Infallible>
{
let path = req.uri ().path ();
//println! ("{}", path);
if req.method () == Method::POST {
return Ok (if let Some (request_code) = prefix_match (path, "/http_response/") {
let request_code = request_code.into ();
handle_http_response (req, state, request_code).await
}
else {
status_reply (StatusCode::BAD_REQUEST, "Can't POST this\n")
});
}
if let Some (watch_code) = prefix_match (path, "/watch/") {
Ok (handle_watch (state, watch_code.into ()).await)
}
else if let Some (watch_code) = prefix_match (path, "/wake/") {
Ok (handle_wake (state, watch_code.into ()).await)
}
else if let Some (listen_code) = prefix_match (path, "/http_listen/") {
Ok (handle_http_listen (state, listen_code.into ()).await)
}
else if let Some (rest) = prefix_match (path, "/http_request/") {
if let Some (idx) = rest.find ('/') {
let listen_code = String::from (&rest [0..idx]);
let path = String::from (&rest [idx..]);
let (parts, _) = req.into_parts ();
Ok (handle_http_request (parts, path, state, listen_code).await)
}
else {
Ok (status_reply (StatusCode::BAD_REQUEST, "Bad URI format"))
}
}
else {
Ok (status_reply (StatusCode::OK, "Hi\n"))
}
}
async fn relay_main () -> Result <(), Box <dyn Error>> {
let addr = SocketAddr::from(([0, 0, 0, 0], 4000));
let state = Arc::new (ServerState::default ());
let make_svc = make_service_fn (|_conn| {
let state = state.clone ();
async {
Ok::<_, Infallible> (service_fn (move |req| {
let state = state.clone ();
handle_all (req, state)
}))
}
});
let server = Server::bind (&addr).serve (make_svc);
server.await?;
Ok (())
}
#[tokio::main] #[tokio::main]
async fn main () -> Result <(), Box <dyn Error>> { async fn main () -> Result <(), Box <dyn Error>> {
relay_main ().await ptth::relay::relay_main ().await
} }

View File

@ -1,5 +1,5 @@
pub mod file_server; pub mod file_server;
pub mod http_serde; pub mod http_serde;
pub mod watcher; pub mod relay;
pub const PTTH_MAGIC_HEADER: &str = "X-PTTH-2LJYXWC4"; pub const PTTH_MAGIC_HEADER: &str = "X-PTTH-2LJYXWC4";

271
src/relay/mod.rs Normal file
View File

@ -0,0 +1,271 @@
pub mod watcher;
use std::{
collections::*,
error::Error,
convert::{
Infallible,
TryFrom,
},
iter::FromIterator,
net::SocketAddr,
sync::{
Arc
},
time::{Duration},
};
use futures::channel::oneshot;
use hyper::{
Body,
Method,
Request,
Response,
Server,
StatusCode,
};
use hyper::service::{make_service_fn, service_fn};
use tokio::{
sync::Mutex,
time::delay_for,
};
use crate::{
http_serde,
};
use watcher::*;
enum Message {
Meow,
HttpRequestResponse (http_serde::WrappedRequest),
HttpResponseResponseStream ((http_serde::ResponseParts, Body)),
}
#[derive (Default)]
struct ServerState {
watchers: Arc <Mutex <Watchers <Message>>>,
}
fn status_reply <B: Into <Body>> (status: StatusCode, b: B)
-> Response <Body>
{
Response::builder ().status (status).body (b.into ()).unwrap ()
}
async fn handle_watch (state: Arc <ServerState>, watcher_code: String)
-> Response <Body>
{
match Watchers::long_poll (state.watchers.clone (), watcher_code).await {
None => status_reply (StatusCode::OK, "no\n"),
Some (_) => status_reply (StatusCode::OK, "actually, yes\n"),
}
}
async fn handle_wake (state: Arc <ServerState>, watcher_code: String)
-> Response <Body>
{
let mut watchers = state.watchers.lock ().await;
if watchers.wake_one (Message::Meow, &watcher_code) {
status_reply (StatusCode::OK, "ok\n")
}
else {
status_reply (StatusCode::BAD_REQUEST, "no\n")
}
}
async fn handle_http_listen (state: Arc <ServerState>, watcher_code: String)
-> Response <Body>
{
//println! ("Step 1");
match Watchers::long_poll (state.watchers.clone (), watcher_code).await {
Some (Message::HttpRequestResponse (parts)) => {
println! ("Step 3");
status_reply (StatusCode::OK, rmp_serde::to_vec (&parts).unwrap ())
},
_ => status_reply (StatusCode::GATEWAY_TIMEOUT, "no\n"),
}
}
async fn handle_http_response (
req: Request <Body>,
state: Arc <ServerState>,
req_id: String,
)
-> Response <Body>
{
println! ("Step 6");
let (parts, body) = req.into_parts ();
let resp_parts: http_serde::ResponseParts = rmp_serde::from_read_ref (&base64::decode (parts.headers.get (crate::PTTH_MAGIC_HEADER).unwrap ()).unwrap ()).unwrap ();
{
let mut watchers = state.watchers.lock ().await;
println! ("Step 7");
if ! watchers.wake_one (Message::HttpResponseResponseStream ((resp_parts, body)), &req_id)
{
println! ("Step 8 (bad thing)");
status_reply (StatusCode::BAD_REQUEST, "A bad thing happened.\n")
}
else {
println! ("Step 8");
status_reply (StatusCode::OK, "ok\n")
}
}
}
async fn handle_http_request (
req: http::request::Parts,
uri: String,
state: Arc <ServerState>,
watcher_code: String
)
-> Response <Body>
{
let parts = {
let id = ulid::Ulid::new ().to_string ();
let method = match http_serde::Method::try_from (req.method) {
Ok (x) => x,
_ => return status_reply (StatusCode::BAD_REQUEST, "Method not supported"),
};
let headers = HashMap::from_iter (
req.headers.into_iter ()
.filter_map (|(k, v)| k.map (|k| (k, v)))
.map (|(k, v)| (String::from (k.as_str ()), v.as_bytes ().to_vec ()))
);
http_serde::WrappedRequest {
id,
req: http_serde::RequestParts {
method,
uri,
headers,
},
}
};
println! ("Step 2 {}", parts.id);
let (s, r) = oneshot::channel ();
let timeout = Duration::from_secs (5);
let id_2 = parts.id.clone ();
{
let mut that = state.watchers.lock ().await;
that.add_watcher_with_id (s, id_2)
}
let req_id = parts.id.clone ();
tokio::spawn (async move {
{
let mut watchers = state.watchers.lock ().await;
println! ("Step 3");
if ! watchers.wake_one (Message::HttpRequestResponse (parts), &watcher_code) {
watchers.remove_watcher (&req_id);
}
}
delay_for (timeout).await;
{
let mut that = state.watchers.lock ().await;
that.remove_watcher (&req_id);
}
});
match r.await {
Ok (Message::HttpResponseResponseStream ((resp_parts, body))) => {
println! ("Step 7");
let mut resp = Response::builder ()
.status (hyper::StatusCode::from (resp_parts.status_code));
for (k, v) in resp_parts.headers.into_iter () {
resp = resp.header (&k, v);
}
resp
.body (body)
.unwrap ()
},
_ => status_reply (StatusCode::GATEWAY_TIMEOUT, "server didn't reply in time or somethin'"),
}
}
fn prefix_match <'a> (hay: &'a str, needle: &str) -> Option <&'a str>
{
if hay.starts_with (needle) {
Some (&hay [needle.len ()..])
}
else {
None
}
}
async fn handle_all (req: Request <Body>, state: Arc <ServerState>)
-> Result <Response <Body>, Infallible>
{
let path = req.uri ().path ();
//println! ("{}", path);
if req.method () == Method::POST {
return Ok (if let Some (request_code) = prefix_match (path, "/http_response/") {
let request_code = request_code.into ();
handle_http_response (req, state, request_code).await
}
else {
status_reply (StatusCode::BAD_REQUEST, "Can't POST this\n")
});
}
if let Some (watch_code) = prefix_match (path, "/watch/") {
Ok (handle_watch (state, watch_code.into ()).await)
}
else if let Some (watch_code) = prefix_match (path, "/wake/") {
Ok (handle_wake (state, watch_code.into ()).await)
}
else if let Some (listen_code) = prefix_match (path, "/http_listen/") {
Ok (handle_http_listen (state, listen_code.into ()).await)
}
else if let Some (rest) = prefix_match (path, "/http_request/") {
if let Some (idx) = rest.find ('/') {
let listen_code = String::from (&rest [0..idx]);
let path = String::from (&rest [idx..]);
let (parts, _) = req.into_parts ();
Ok (handle_http_request (parts, path, state, listen_code).await)
}
else {
Ok (status_reply (StatusCode::BAD_REQUEST, "Bad URI format"))
}
}
else {
Ok (status_reply (StatusCode::OK, "Hi\n"))
}
}
pub async fn relay_main () -> Result <(), Box <dyn Error>> {
let addr = SocketAddr::from(([0, 0, 0, 0], 4000));
let state = Arc::new (ServerState::default ());
let make_svc = make_service_fn (|_conn| {
let state = state.clone ();
async {
Ok::<_, Infallible> (service_fn (move |req| {
let state = state.clone ();
handle_all (req, state)
}))
}
});
let server = Server::bind (&addr).serve (make_svc);
server.await?;
Ok (())
}