diff --git a/crates/ptth_core/src/http_serde.rs b/crates/ptth_core/src/http_serde.rs index f6afb0e..9b31372 100644 --- a/crates/ptth_core/src/http_serde.rs +++ b/crates/ptth_core/src/http_serde.rs @@ -28,9 +28,13 @@ impl TryFrom for Method { type Error = Error; fn try_from (x: hyper::Method) -> Result { + use hyper::Method; + match x { - hyper::Method::GET => Ok (Self::Get), - hyper::Method::HEAD => Ok (Self::Head), + Method::GET => Ok (Self::Get), + Method::HEAD => Ok (Self::Head), + Method::POST => Ok (Self::Post), + Method::PUT => Ok (Self::Put), _ => Err (Error::UnsupportedMethod), } } diff --git a/crates/ptth_relay/src/lib.rs b/crates/ptth_relay/src/lib.rs index 46084bd..495a6a1 100644 --- a/crates/ptth_relay/src/lib.rs +++ b/crates/ptth_relay/src/lib.rs @@ -129,10 +129,8 @@ async fn handle_http_request ( let user = get_user_name (&req); - let req = match http_serde::RequestParts::from_hyper (req.method, uri.clone (), req.headers) { - Ok (x) => x, - Err (_) => return Err (BadRequest), - }; + let req = http_serde::RequestParts::from_hyper (req.method, uri.clone (), req.headers) + .map_err (|_| BadRequest)?; let (tx, rx) = oneshot::channel (); @@ -597,12 +595,8 @@ async fn handle_all ( let response = match e { Error::BadUriFormat => error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?, - Error::CantPost => { - error! ("Can't POST {}", path); - error_reply (StatusCode::BAD_REQUEST, "Can't POST this")? - }, Error::MethodNotAllowed => error_reply (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed. Are you POST-ing to a GET-only url, or vice versa?")?, - Error::RoutingFailed => error_reply (StatusCode::OK, "URL routing failed")?, + Error::NotFound => error_reply (StatusCode::OK, "URL routing failed")?, }; return Ok (response); }, @@ -628,13 +622,6 @@ async fn handle_all ( DebugEndlessSource (throttle) => handle_endless_source (1, throttle).await?, DebugGenKey => handle_gen_scraper_key (state).await?, DebugMysteriousError => return Err (RequestError::Mysterious), - ErrorBadUriFormat => error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?, - ErrorCantPost => { - error! ("Can't POST {}", path); - error_reply (StatusCode::BAD_REQUEST, "Can't POST this")? - }, - ErrorMethodNotAllowed => error_reply (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed. Are you POST-ing to a GET-only url, or vice versa?")?, - ErrorRoutingFailed => error_reply (StatusCode::OK, "URL routing failed")?, RegisterServer => { match handle_register_server (req, state).await { Ok (_) => Response::builder () diff --git a/crates/ptth_relay/src/routing.rs b/crates/ptth_relay/src/routing.rs index c478c1e..94c8572 100644 --- a/crates/ptth_relay/src/routing.rs +++ b/crates/ptth_relay/src/routing.rs @@ -15,10 +15,6 @@ pub enum Route <'a> { DebugEndlessSource (Option ), DebugGenKey, DebugMysteriousError, - ErrorBadUriFormat, - ErrorCantPost, - ErrorMethodNotAllowed, - ErrorRoutingFailed, RegisterServer, Root, Scraper { @@ -35,44 +31,33 @@ pub enum Route <'a> { #[derive (Debug, PartialEq)] pub enum Error { BadUriFormat, - CantPost, MethodNotAllowed, - RoutingFailed, + NotFound, } -type Result <'a> = std::result::Result , Error>; - -pub fn route_url <'a> (method: &Method, path: &'a str) -> Result <'a> { +pub fn route_url <'a> (method: &Method, path: &'a str) -> Result , Error> { if let Some (listen_code) = path.strip_prefix ("/7ZSFUKGV/http_listen/") { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::ServerHttpListen { listen_code }) } else if let Some (request_code) = path.strip_prefix ("/7ZSFUKGV/http_response/") { - if method != Method::POST { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::POST)?; Ok (Route::ServerHttpResponse { request_code }) } else if path == "/frontend/register" { - if method != Method::POST { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::POST)?; Ok (Route::RegisterServer) } else if let Some (rest) = path.strip_prefix ("/frontend/servers/") { // DRY T4H76LB3 if rest.is_empty () { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::ClientServerList) } else if let Some (idx) = rest.find ('/') { @@ -84,87 +69,73 @@ pub fn route_url <'a> (method: &Method, path: &'a str) -> Result <'a> { }) } else { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } - Ok (Route::ErrorBadUriFormat) + assert_method (method, Method::GET)?; + Err (Error::BadUriFormat) } } else if path == "/frontend/unregistered_servers" { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::ClientUnregisteredServers) } else if path == "/frontend/audit_log" { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::ClientAuditLog) } else if let Some (rest) = path.strip_prefix ("/frontend/debug/") { if rest.is_empty () { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::Debug) } else if rest == "endless_source" { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::DebugEndlessSource (None)) } else if rest == "endless_source_throttled" { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::DebugEndlessSource (Some (1024 / 64))) } else if rest == "endless_sink" { - if method != Method::POST { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::POST)?; Ok (Route::DebugEndlessSink) } else if rest == "gen_key" { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::DebugGenKey) } else { - Ok (Route::ErrorRoutingFailed) + Err (Error::NotFound) } } else if path == "/" { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::Root) } else if path == "/frontend/relay_up_check" { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::ClientRelayIsUp) } else if path == "/frontend/test_mysterious_error" { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::DebugMysteriousError) } else if let Some (rest) = path.strip_prefix ("/scraper/") { - if method != Method::GET { - return Err (Error::MethodNotAllowed); - } + assert_method (method, Method::GET)?; Ok (Route::Scraper { rest }) } else { - Err (Error::RoutingFailed) + Err (Error::NotFound) + } +} + +fn assert_method > (method: M, expected: Method) -> Result <(), Error> +{ + if method == expected { + Ok (()) + } + else { + Err (Error::MethodNotAllowed) } }