use std::{ collections::HashMap, iter::FromIterator, sync::Arc, }; use chrono::{DateTime, Utc}; use hyper::{ Body, Request, Response, StatusCode, }; use serde::{ Serialize, Serializer, }; use tracing::{ error, instrument, }; use crate::{ RequestError, error_reply, key_validity::KeyValidity, prefix_match, relay_state::RelayState, }; // JSON is probably Good Enough For Now, so I'll just make everything // a struct and lazily serialize it right before leaving the // top-level handle () fn. fn serialize_last_seen (x: &Option >, s: S) -> Result { match x { None => s.serialize_none (), Some (x) => s.serialize_str (&x.to_rfc3339 ()), } } #[derive (Serialize)] pub struct Server { pub name: String, pub display_name: String, #[serde (serialize_with = "serialize_last_seen")] pub last_seen: Option >, } #[derive (Serialize)] pub struct ServerList { pub servers: Vec , } pub async fn v1_server_list (state: &Arc ) -> ServerList { // name --> display_name let display_names: HashMap = { let guard = state.config.read ().await; let servers = (*guard).servers.iter () .map (|(k, v)| { let display_name = v.display_name .clone () .unwrap_or_else (|| k.clone ()); (k.clone (), display_name) }); HashMap::from_iter (servers) }; // name --> status let server_statuses = { let guard = state.server_status.lock ().await; (*guard).clone () }; let mut servers: Vec <_> = display_names.into_iter () .map (|(name, display_name)| { let last_seen = server_statuses.get (&name).map (|x| x.last_seen); Server { display_name, name, last_seen, } }) .collect (); servers.sort_by (|a, b| a.display_name.cmp (&b.display_name)); ServerList { servers, } } #[instrument (level = "trace", skip (req, state))] async fn api_v1 ( req: Request , state: Arc , path_rest: &str ) -> Result , RequestError> { let api_key = req.headers ().get ("X-ApiKey"); let api_key = match api_key { None => return Ok (error_reply (StatusCode::FORBIDDEN, "Can't run scraper without an API key")?), Some (x) => x, }; let bad_key = || error_reply (StatusCode::FORBIDDEN, "403 Forbidden"); { let config = state.config.read ().await; let dev_mode = match &config.iso.dev_mode { None => return Ok (bad_key ()?), Some (x) => x, }; let expected_key = match &dev_mode.scraper_key { None => return Ok (bad_key ()?), Some (x) => x, }; let now = chrono::Utc::now (); match expected_key.is_valid (now, api_key.as_bytes ()) { KeyValidity::Valid => (), KeyValidity::WrongKey (bad_hash) => { error! ("Bad scraper key with hash {:?}", bad_hash); return Ok (bad_key ()?); } err => { error! ("Bad scraper key {:?}", err); return Ok (bad_key ()?); }, } } if path_rest == "test" { Ok (error_reply (StatusCode::OK, "You're valid!")?) } else if path_rest == "server_list" { let x = v1_server_list (&state).await; Ok (error_reply (StatusCode::OK, &serde_json::to_string (&x).unwrap ())?) } else { Ok (error_reply (StatusCode::NOT_FOUND, "Unknown API endpoint")?) } } #[instrument (level = "trace", skip (req, state))] pub async fn handle ( req: Request , state: Arc , path_rest: &str ) -> Result , RequestError> { { if ! state.config.read ().await.iso.enable_scraper_api { return Ok (error_reply (StatusCode::FORBIDDEN, "Scraper API disabled")?); } } if let Some (rest) = prefix_match ("v1/", path_rest) { api_v1 (req, state, rest).await } else if let Some (rest) = prefix_match ("api/", path_rest) { api_v1 (req, state, rest).await } else { Ok (error_reply (StatusCode::NOT_FOUND, "Unknown scraper API version")?) } } #[cfg (test)] mod tests { use std::{ convert::{TryFrom, TryInto}, }; use tokio::runtime::Runtime; use crate::{ config, key_validity, }; use super::*; #[derive (Clone)] struct TestCase { // Inputs path_rest: &'static str, valid_key: Option <&'static str>, input_key: Option <&'static str>, // Expected expected_status: StatusCode, expected_headers: Vec <(&'static str, &'static str)>, expected_body: &'static str, } impl TestCase { fn path_rest (&self, v: &'static str) -> Self { let mut x = self.clone (); x.path_rest = v; x } fn valid_key (&self, v: Option <&'static str>) -> Self { let mut x = self.clone (); x.valid_key = v; x } fn input_key (&self, v: Option <&'static str>) -> Self { let mut x = self.clone (); x.input_key = v; x } fn expected_status (&self, v: StatusCode) -> Self { let mut x = self.clone (); x.expected_status = v; x } fn expected_headers (&self, v: Vec <(&'static str, &'static str)>) -> Self { let mut x = self.clone (); x.expected_headers = v; x } fn expected_body (&self, v: &'static str) -> Self { let mut x = self.clone (); x.expected_body = v; x } fn expected (&self, sc: StatusCode, body: &'static str) -> Self { self .expected_status (sc) .expected_body (body) } async fn test (&self) { let mut input = Request::builder () .method ("GET") .uri (format! ("http://127.0.0.1:4000/scraper/{}", self.path_rest)); if let Some (input_key) = self.input_key { input = input.header ("X-ApiKey", input_key); } let input = input.body (Body::empty ()).unwrap (); let config_file = config::file::Config { port: Some (4000), servers: vec! [], iso: config::file::Isomorphic { enable_scraper_auth: true, dev_mode: self.valid_key.map (|key| config::file::DevMode { scraper_key: Some (key_validity::ScraperKey::new (key.as_bytes ())), }), }, }; let config = config::Config::try_from (config_file).expect ("Can't load config"); let relay_state = Arc::new (RelayState::try_from (config).expect ("Can't create relay state")); let actual = handle_scraper_api (input, relay_state, self.path_rest).await; let actual = actual.expect ("Relay didn't respond"); let (actual_head, actual_body) = actual.into_parts (); let mut expected_headers = hyper::header::HeaderMap::new (); for (key, value) in &self.expected_headers { expected_headers.insert (*key, (*value).try_into ().expect ("Couldn't convert header value")); } assert_eq! (actual_head.status, self.expected_status); assert_eq! (actual_head.headers, expected_headers); let actual_body = hyper::body::to_bytes (actual_body).await; let actual_body = actual_body.expect ("Body should be convertible to bytes"); let actual_body = actual_body.to_vec (); let actual_body = String::from_utf8 (actual_body).expect ("Body should be UTF-8"); assert_eq! (&actual_body, self.expected_body); } } #[test] fn auth () { let mut rt = Runtime::new ().expect ("Can't create runtime for testing"); rt.block_on (async move { let base_case = TestCase { path_rest: "v1/test", valid_key: Some ("bogus"), input_key: Some ("bogus"), expected_status: StatusCode::OK, expected_headers: vec! [ ("content-type", "text/plain"), ], expected_body: "You're valid!\n", }; for case in &[ base_case.clone (), base_case.path_rest ("v9999/test") .expected (StatusCode::NOT_FOUND, "Unknown scraper API version\n"), base_case.valid_key (None) .expected (StatusCode::FORBIDDEN, "403 Forbidden\n"), base_case.input_key (Some ("borgus")) .expected (StatusCode::FORBIDDEN, "403 Forbidden\n"), base_case.path_rest ("v1/toast") .expected (StatusCode::NOT_FOUND, "Unknown API endpoint\n"), base_case.input_key (None) .expected (StatusCode::FORBIDDEN, "Can't run scraper without an API key\n"), ] { case.test ().await; } }); } }