diff --git a/crates/ptth_relay/src/scraper_api.rs b/crates/ptth_relay/src/scraper_api.rs index 299ecde..e06234d 100644 --- a/crates/ptth_relay/src/scraper_api.rs +++ b/crates/ptth_relay/src/scraper_api.rs @@ -119,6 +119,21 @@ pub async fn v1_server_list (state: &Relay) } } +fn get_api_key (headers: &hyper::HeaderMap) -> Option <&str> +{ + if let Some (key) = headers.get ("X-ApiKey").and_then (|v| v.to_str ().ok ()) { + return Some (key); + } + + if let Some (s) = headers.get ("Authorization").and_then (|v| v.to_str ().ok ()) { + if let Some (key) = s.strip_prefix ("Bearer ") { + return Some (key); + } + } + + None +} + #[instrument (level = "trace", skip (req, state))] async fn api_v1 ( req: Request , @@ -132,7 +147,7 @@ async fn api_v1 ( AuditEvent, }; - let api_key = req.headers ().get ("X-ApiKey"); + let api_key = get_api_key (req.headers ()); let api_key = match api_key { None => return Ok (error_reply (StatusCode::FORBIDDEN, strings::NO_API_KEY)?), @@ -351,7 +366,7 @@ mod tests { .expected_body (format! ("{}\n", body)) } - async fn test (&self) { + async fn test (&self, name: &str) { let mut input = Request::builder () .method ("GET") .uri (format! ("http://127.0.0.1:4000/scraper/{}", self.path_rest)); @@ -387,15 +402,15 @@ mod tests { expected_headers.insert (*key, (*value).try_into ().expect ("Couldn't convert header value")); } - assert_eq! (actual_head.status, self.expected_status); - assert_eq! (actual_head.headers, expected_headers); + assert_eq! (actual_head.status, self.expected_status, "{}", name); + assert_eq! (actual_head.headers, expected_headers, "{}", name); let actual_body = hyper::body::to_bytes (actual_body).await; let actual_body = actual_body.expect ("Body should be convertible to bytes"); let actual_body = actual_body.to_vec (); let actual_body = String::from_utf8 (actual_body).expect ("Body should be UTF-8"); - assert_eq! (actual_body, self.expected_body); + assert_eq! (actual_body, self.expected_body, "{}", name); } } @@ -417,38 +432,38 @@ mod tests { }; base_case - .test ().await; + .test ("00").await; base_case .path_rest ("v9999/test") .expected (StatusCode::NOT_FOUND, strings::UNKNOWN_API_VERSION) - .test ().await; + .test ("01").await; base_case .valid_key (None) .expected (StatusCode::FORBIDDEN, strings::FORBIDDEN) - .test ().await; + .test ("02").await; base_case .x_api_key (Some ("borgus")) .expected (StatusCode::FORBIDDEN, strings::FORBIDDEN) - .test ().await; + .test ("03").await; base_case .path_rest ("v1/toast") .expected (StatusCode::NOT_FOUND, strings::UNKNOWN_API_ENDPOINT) - .test ().await; + .test ("04").await; base_case .x_api_key (None) .expected (StatusCode::FORBIDDEN, strings::NO_API_KEY) - .test ().await; + .test ("05").await; base_case .x_api_key (None) - .auth_header (Some ("Bearer: bogus")) - .expected (StatusCode::FORBIDDEN, strings::NO_API_KEY) - .test ().await; + .auth_header (Some ("Bearer bogus")) + .expected (StatusCode::OK, "You're valid!") + .test ("06").await; }); }