From ba17f11297782f2a5ccba0eb700501965832e369 Mon Sep 17 00:00:00 2001 From: _ <_@_> Date: Fri, 30 Oct 2020 20:35:39 -0500 Subject: [PATCH] :recycle: Extract http_serde::Body so I can make the file server standalone --- src/bin/file_server.rs | 113 ++++++++++++++++++++++++++++++++++++++ src/http_serde.rs | 47 ++++++++++++++-- src/relay/mod.rs | 25 +++------ src/server/file_server.rs | 43 ++++++++++++--- src/server/mod.rs | 4 +- 5 files changed, 201 insertions(+), 31 deletions(-) create mode 100644 src/bin/file_server.rs diff --git a/src/bin/file_server.rs b/src/bin/file_server.rs new file mode 100644 index 0000000..fda269d --- /dev/null +++ b/src/bin/file_server.rs @@ -0,0 +1,113 @@ +use std::{ + convert::Infallible, + error::Error, + path::PathBuf, + sync::Arc, + net::SocketAddr, +}; + +use hyper::{ + Body, + Request, + Response, + Server, + service::{ + make_service_fn, + service_fn, + }, + StatusCode, +}; + +#[derive (Default)] +struct ServerState { + // Pass +} + +fn status_reply > (status: StatusCode, b: B) +-> Response +{ + Response::builder ().status (status).body (b.into ()).unwrap () +} + +fn prefix_match <'a> (hay: &'a str, needle: &str) -> Option <&'a str> +{ + if hay.starts_with (needle) { + Some (&hay [needle.len ()..]) + } + else { + None + } +} + +async fn handle_all (req: Request , _state: Arc ) +-> Result , Infallible> +{ + use ptth::{ + http_serde::RequestParts, + server::file_server, + }; + + let path = req.uri ().path (); + //println! ("{}", path); + + if let Some (path) = prefix_match (path, "/files") { + let root = PathBuf::from ("./"); + + let path = path.into (); + + let (parts, _) = req.into_parts (); + + let ptth_req = match RequestParts::from_hyper (parts.method, path, parts.headers) { + Ok (x) => x, + _ => return Ok (status_reply (StatusCode::BAD_REQUEST, "Bad request")), + }; + + let ptth_resp = file_server::serve_all (&root, ptth_req).await; + + let mut resp = Response::builder () + .status (StatusCode::from (ptth_resp.parts.status_code)); + + use std::str::FromStr; + + for (k, v) in ptth_resp.parts.headers.into_iter () { + resp = resp.header (hyper::header::HeaderName::from_str (&k).unwrap (), v); + } + + let body = ptth_resp.body + .map (|b| Body::wrap_stream (b)) + .unwrap_or_else (|| Body::empty ()) + ; + + let resp = resp.body (body).unwrap (); + + Ok (resp) + } + else { + Ok (status_reply (StatusCode::NOT_FOUND, "404 Not Found\n")) + } +} + +#[tokio::main] +async fn main () -> Result <(), Box > { + let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); + + let state = Arc::new (ServerState::default ()); + + let make_svc = make_service_fn (|_conn| { + let state = state.clone (); + + async { + Ok::<_, Infallible> (service_fn (move |req| { + let state = state.clone (); + + handle_all (req, state) + })) + } + }); + + let server = Server::bind (&addr).serve (make_svc); + + server.await?; + + Ok (()) +} diff --git a/src/http_serde.rs b/src/http_serde.rs index 1927345..bdb40fd 100644 --- a/src/http_serde.rs +++ b/src/http_serde.rs @@ -4,12 +4,13 @@ use std::{ }; use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc::Receiver; // Hyper doesn't seem to make it easy to de/ser requests // and responses and stuff like that, so I do it by hand here. pub enum Error { - Unsupported, + UnsupportedMethod, InvalidHeaderName, } @@ -32,7 +33,7 @@ impl TryFrom for Method { match x { hyper::Method::GET => Ok (Self::Get), hyper::Method::HEAD => Ok (Self::Head), - _ => Err (Error::Unsupported), + _ => Err (Error::UnsupportedMethod), } } } @@ -52,6 +53,30 @@ pub struct RequestParts { pub headers: HashMap >, } +impl RequestParts { + pub fn from_hyper ( + method: hyper::Method, + uri: String, + headers: hyper::HeaderMap + ) -> Result + { + use std::iter::FromIterator; + + let method = Method::try_from (method)?; + let headers = HashMap::from_iter ( + headers.into_iter () + .filter_map (|(k, v)| k.map (|k| (k, v))) + .map (|(k, v)| (String::from (k.as_str ()), v.as_bytes ().to_vec ())) + ); + + Ok (Self { + method, + uri, + headers, + }) + } +} + #[derive (Deserialize, Serialize)] pub struct WrappedRequest { pub id: String, @@ -92,10 +117,15 @@ pub struct ResponseParts { pub headers: HashMap >, } +// reqwest and hyper have different Body types for _some reason_ +// so I have to do everything myself + +type Body = Receiver , std::convert::Infallible>>; + #[derive (Default)] pub struct Response { pub parts: ResponseParts, - pub body: Option , + pub body: Option , } impl Response { @@ -109,8 +139,17 @@ impl Response { self } - pub fn body (&mut self, b: reqwest::Body) -> &mut Self { + pub fn body (&mut self, b: Body) -> &mut Self { self.body = Some (b); self } + + pub fn body_bytes (&mut self, b: Vec ) -> &mut Self { + let (mut tx, rx) = tokio::sync::mpsc::channel (1); + tokio::spawn (async move { + tx.send (Ok (b)).await.unwrap (); + }); + self.body = Some (rx); + self + } } diff --git a/src/relay/mod.rs b/src/relay/mod.rs index b8b8a00..9fb077b 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -1,13 +1,8 @@ pub mod watcher; use std::{ - collections::*, error::Error, - convert::{ - Infallible, - TryFrom, - }, - iter::FromIterator, + convert::Infallible, net::SocketAddr, sync::{ Arc @@ -125,23 +120,14 @@ async fn handle_http_request ( { let parts = { let id = ulid::Ulid::new ().to_string (); - let method = match http_serde::Method::try_from (req.method) { + let req = match http_serde::RequestParts::from_hyper (req.method, uri, req.headers) { Ok (x) => x, - _ => return status_reply (StatusCode::BAD_REQUEST, "Method not supported"), + _ => return status_reply (StatusCode::BAD_REQUEST, "Bad request"), }; - let headers = HashMap::from_iter ( - req.headers.into_iter () - .filter_map (|(k, v)| k.map (|k| (k, v))) - .map (|(k, v)| (String::from (k.as_str ()), v.as_bytes ().to_vec ())) - ); http_serde::WrappedRequest { id, - req: http_serde::RequestParts { - method, - uri, - headers, - }, + req, } }; @@ -211,6 +197,9 @@ async fn handle_all (req: Request , state: Arc ) //println! ("{}", path); if req.method () == Method::POST { + // This is stuff the server can use. Clients can't + // POST right now + return Ok (if let Some (request_code) = prefix_match (path, "/http_response/") { let request_code = request_code.into (); handle_http_response (req, state, request_code).await diff --git a/src/server/file_server.rs b/src/server/file_server.rs index 0748f0b..dd63ef5 100644 --- a/src/server/file_server.rs +++ b/src/server/file_server.rs @@ -7,10 +7,6 @@ use std::{ path::{Path, PathBuf}, }; -// file_server shouldn't depend on reqwest, but for now it -// does depend on their Body struct -use reqwest::Body; - use tokio::{ fs::{ File, @@ -103,7 +99,7 @@ async fn serve_dir (mut dir: ReadDir) -> http_serde::Response { let mut response = http_serde::Response::default (); response.header ("content-type".into (), String::from ("text/html").into_bytes ()); - response.body (Body::wrap_stream (rx)); + response.body (rx); response } @@ -116,7 +112,7 @@ async fn serve_file ( ) -> http_serde::Response { let (tx, rx) = channel (2); let body = if should_send_body { - Some (Body::wrap_stream (rx)) + Some (rx) } else { None @@ -200,7 +196,7 @@ async fn serve_error ( { let mut resp = http_serde::Response::default (); resp.status_code (status_code) - .body (msg.into ()); + .body_bytes (msg.into_bytes ()); resp } @@ -232,7 +228,10 @@ pub async fn serve_all ( use percent_encoding::*; + // TODO: There is totally a dir traversal attack in here somewhere + let encoded_path = &parts.uri [1..]; + let path = percent_decode (encoded_path.as_bytes ()).decode_utf8 ().unwrap (); let mut full_path = PathBuf::from (root); @@ -253,3 +252,33 @@ pub async fn serve_all ( serve_error (http_serde::StatusCode::NotFound, "404 Not Found".into ()).await } } + +#[cfg (test)] +mod tests { + #[test] + fn i_hate_paths () { + use std::{ + ffi::OsStr, + path::{Component, Path} + }; + + let mut components = Path::new ("/home/user").components (); + + assert_eq! (components.next (), Some (Component::RootDir)); + assert_eq! (components.next (), Some (Component::Normal (OsStr::new ("home")))); + assert_eq! (components.next (), Some (Component::Normal (OsStr::new ("user")))); + assert_eq! (components.next (), None); + + let mut components = Path::new ("./home/user").components (); + + assert_eq! (components.next (), Some (Component::CurDir)); + assert_eq! (components.next (), Some (Component::Normal (OsStr::new ("home")))); + assert_eq! (components.next (), Some (Component::Normal (OsStr::new ("user")))); + assert_eq! (components.next (), None); + + let mut components = Path::new (".").components (); + + assert_eq! (components.next (), Some (Component::CurDir)); + assert_eq! (components.next (), None); + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 05cb8f4..7d9b42d 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -15,7 +15,7 @@ use tokio::{ use crate::http_serde; -mod file_server; +pub mod file_server; const SERVER_NAME: &str = "alien_wildlands"; @@ -47,7 +47,7 @@ async fn handle_req_resp ( .header (crate::PTTH_MAGIC_HEADER, base64::encode (rmp_serde::to_vec (&response.parts).unwrap ())); if let Some (body) = response.body { - resp_req = resp_req.body (body); + resp_req = resp_req.body (reqwest::Body::wrap_stream (body)); } //println! ("Step 6");