use std::{ collections::HashMap, }; use chrono::{DateTime, Utc}; use hyper::{ Body, Request, Response, StatusCode, }; use serde::{ Serialize, Serializer, }; use tracing::{ error, instrument, }; use crate::{ RequestError, error_reply, key_validity::{ BlakeHashWrapper, KeyValidity, }, relay_state::Relay, }; // Not sure if this is the best way to do a hard-coded string table, but // it'll keep the tests up to date mod strings { pub const FORBIDDEN: &str = "403 Forbidden"; pub const NO_API_KEY: &str = "Can't auth as scraper without API key"; pub const UNKNOWN_API_VERSION: &str = "Unknown scraper API version"; pub const UNKNOWN_API_ENDPOINT: &str = "Unknown scraper API endpoint"; } // JSON is probably Good Enough For Now, so I'll just make everything // a struct and lazily serialize it right before leaving the // top-level handle () fn. fn serialize_last_seen (x: &Option >, s: S) -> Result { match x { None => s.serialize_none (), Some (x) => s.serialize_str (&x.to_rfc3339 ()), } } #[derive (Serialize)] pub struct Server { pub name: String, pub display_name: String, #[serde (serialize_with = "serialize_last_seen")] pub last_seen: Option >, } #[derive (Serialize)] pub struct ServerList { pub servers: Vec , } pub async fn v1_server_list (state: &Relay) -> ServerList { // name --> display_name let mut display_names = HashMap::new (); { let guard = state.config.read ().await; for (k, v) in &guard.servers { let display_name = v.display_name .clone () .unwrap_or_else (|| k.clone ()); display_names.insert (k.clone (), display_name); } } { let guard = state.me_config.read ().await; for (k, v) in &guard.servers { let display_name = v.display_name .clone () .unwrap_or_else (|| k.clone ()); display_names.insert (k.clone (), display_name); } } // name --> status let server_statuses = { let guard = state.server_status.lock ().await; (*guard).clone () }; let mut servers: Vec <_> = display_names.into_iter () .map (|(name, display_name)| { let last_seen = server_statuses.get (&name).map (|x| x.last_seen); Server { display_name, name, last_seen, } }) .collect (); servers.sort_by (|a, b| a.display_name.cmp (&b.display_name)); ServerList { servers, } } fn get_api_key (headers: &hyper::HeaderMap) -> Option <&str> { if let Some (key) = headers.get ("X-ApiKey").and_then (|v| v.to_str ().ok ()) { return Some (key); } if let Some (s) = headers.get ("Authorization").and_then (|v| v.to_str ().ok ()) { if let Some (key) = s.strip_prefix ("Bearer ") { return Some (key); } } None } #[instrument (level = "trace", skip (req, state))] async fn api_v1 ( req: Request , state: &Relay, path_rest: &str ) -> Result , RequestError> { use crate::{ AuditData, AuditEvent, }; let api_key = get_api_key (req.headers ()); let api_key = match api_key { None => return Ok (error_reply (StatusCode::FORBIDDEN, strings::NO_API_KEY)?), Some (x) => x, }; let actual_hash = BlakeHashWrapper::from_key (api_key.as_bytes ()); let bad_key = || error_reply (StatusCode::FORBIDDEN, strings::FORBIDDEN); let key_name; { let config = state.config.read ().await; let expected_key = match config.scraper_keys.get (&actual_hash.encode_base64 ()) { Some (x) => x, None => return Ok (bad_key ()?), }; let now = chrono::Utc::now (); // The hash in is_valid is redundant match expected_key.is_valid (now, api_key.as_bytes ()) { KeyValidity::Valid => (), KeyValidity::WrongKey (bad_hash) => { error! ("Bad scraper key with hash {:?}", bad_hash); return Ok (bad_key ()?); } err => { error! ("Bad scraper key {:?}", err); return Ok (bad_key ()?); }, } key_name = expected_key.name.to_string (); } state.audit_log.push (AuditEvent::new (AuditData::ScraperGet { key_name, path: path_rest.to_string (), })).await; if path_rest == "metrics" { Ok (metrics (state).await?) } else if path_rest == "test" { Ok (error_reply (StatusCode::OK, "You're valid!")?) } else if path_rest == "server_list" { let x = v1_server_list (&state).await; Ok (error_reply (StatusCode::OK, &serde_json::to_string (&x).unwrap ())?) } else if let Some (rest) = path_rest.strip_prefix ("server/") { // DRY T4H76LB3 if let Some (idx) = rest.find ('/') { let listen_code = String::from (&rest [0..idx]); let path = String::from (&rest [idx..]); let (parts, _) = req.into_parts (); // This is ugly. I don't like having scraper_api know about the // crate root. Ok (crate::handle_http_request (parts, path, &state, &listen_code).await?) } else { Ok (error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?) } } else { Ok (error_reply (StatusCode::NOT_FOUND, strings::UNKNOWN_API_ENDPOINT)?) } } #[instrument (level = "trace", skip (state))] async fn metrics ( state: &Relay, ) -> Result , RequestError> { let mut s = String::with_capacity (4 * 1_024); let mut push_metric = |name, help: Option<&str>, kind, value| { if let Some (help) = help { s.push_str (format! ("# HELP {} {}\n", name, help).as_str ()); } s.push_str (format! ("# TYPE {} {}\n", name, kind).as_str ()); s.push_str (format! ("{} {}\n", name, value).as_str ()); }; let request_rendezvous_count = { let g = state.request_rendezvous.lock ().await; g.len () }; let server_status_count; let connected_server_count; let now = Utc::now (); { let g = state.server_status.lock ().await; server_status_count = g.len (); connected_server_count = g.iter () .filter (|(_, s)| now - s.last_seen < chrono::Duration::seconds (60)) .count (); } let response_rendezvous_count = { let g = state.response_rendezvous.read ().await; g.len () }; push_metric ("request_rendezvous_count", None, "gauge", request_rendezvous_count.to_string ()); push_metric ("server_status_count", None, "gauge", server_status_count.to_string ()); push_metric ("connected_server_count", None, "gauge", connected_server_count.to_string ()); push_metric ("response_rendezvous_count", None, "gauge", response_rendezvous_count.to_string ()); #[cfg (target_os = "linux")] { if let Some (rss) = tokio::fs::read_to_string ("/proc/self/status").await .ok () .and_then (|s| get_rss_from_status (s.as_str ())) { push_metric ("relay_vm_rss", Some ("VmRSS of the relay process, in kB"), "gauge", rss.to_string ()); } } Ok (Response::builder () .body (Body::from (s))?) } #[instrument (level = "trace", skip (req, state))] pub async fn handle ( req: Request , state: &Relay, path_rest: &str ) -> Result , RequestError> { { if ! state.config.read ().await.iso.enable_scraper_api { return Ok (error_reply (StatusCode::FORBIDDEN, "Scraper API disabled")?); } } if let Some (rest) = path_rest.strip_prefix ("v1/") { api_v1 (req, state, rest).await } else if let Some (rest) = path_rest.strip_prefix ("api/") { api_v1 (req, state, rest).await } else { Ok (error_reply (StatusCode::NOT_FOUND, strings::UNKNOWN_API_VERSION)?) } } fn get_rss_from_status (proc_status: &str) -> Option { use std::str::FromStr; for line in proc_status.lines () { if let Some (rest) = line.strip_prefix ("VmRSS:\t").and_then (|s| s.strip_suffix (" kB")) { return u64::from_str (rest.trim_start ()).ok (); } } None } #[cfg (test)] mod tests { use std::{ convert::{TryInto}, }; use tokio::runtime::Runtime; use crate::{ key_validity, }; use super::*; #[derive (Clone)] struct TestCase { // Inputs path_rest: &'static str, auth_header: Option <&'static str>, valid_key: Option <&'static str>, x_api_key: Option <&'static str>, // Expected expected_status: StatusCode, expected_headers: Vec <(&'static str, &'static str)>, expected_body: String, } impl TestCase { fn path_rest (&self, v: &'static str) -> Self { let mut x = self.clone (); x.path_rest = v; x } fn valid_key (&self, v: Option <&'static str>) -> Self { let mut x = self.clone (); x.valid_key = v; x } fn auth_header (&self, v: Option <&'static str>) -> Self { let mut x = self.clone (); x.auth_header = v; x } fn x_api_key (&self, v: Option <&'static str>) -> Self { let mut x = self.clone (); x.x_api_key = v; x } fn expected_status (&self, v: StatusCode) -> Self { let mut x = self.clone (); x.expected_status = v; x } fn expected_headers (&self, v: Vec <(&'static str, &'static str)>) -> Self { let mut x = self.clone (); x.expected_headers = v; x } fn expected_body (&self, v: String) -> Self { let mut x = self.clone (); x.expected_body = v; x } fn expected (&self, sc: StatusCode, body: &str) -> Self { self .expected_status (sc) .expected_body (format! ("{}\n", body)) } async fn test (&self, name: &str) { let mut input = Request::builder () .method ("GET") .uri (format! ("http://127.0.0.1:4000/scraper/{}", self.path_rest)); if let Some (auth_header) = self.auth_header { input = input.header ("Authorization", auth_header); } if let Some (x_api_key) = self.x_api_key { input = input.header ("X-ApiKey", x_api_key); } let input = input.body (Body::empty ()).unwrap (); let builder = Relay::build () .port (4000) .enable_scraper_api (true); let builder = if let Some (key) = self.valid_key.map (|x| key_validity::ScraperKey::new_30_day ("automated test", x.as_bytes ())) { builder.scraper_key (key) } else { builder }; let relay_state = builder.build ().expect ("Can't create relay state"); let actual = super::handle (input, &relay_state, self.path_rest).await; let actual = actual.expect ("Relay didn't respond"); let (actual_head, actual_body) = actual.into_parts (); let mut expected_headers = hyper::header::HeaderMap::new (); for (key, value) in &self.expected_headers { expected_headers.insert (*key, (*value).try_into ().expect ("Couldn't convert header value")); } assert_eq! (actual_head.status, self.expected_status, "{}", name); assert_eq! (actual_head.headers, expected_headers, "{}", name); let actual_body = hyper::body::to_bytes (actual_body).await; let actual_body = actual_body.expect ("Body should be convertible to bytes"); let actual_body = actual_body.to_vec (); let actual_body = String::from_utf8 (actual_body).expect ("Body should be UTF-8"); assert_eq! (actual_body, self.expected_body, "{}", name); } } #[test] fn auth () { let rt = Runtime::new ().expect ("Can't create runtime for testing"); rt.block_on (async move { let base_case = TestCase { path_rest: "v1/test", valid_key: Some ("bogus"), auth_header: None, x_api_key: Some ("bogus"), expected_status: StatusCode::OK, expected_headers: vec! [ ("content-type", "text/plain"), ], expected_body: "You're valid!\n".to_string (), }; base_case .test ("00").await; base_case .path_rest ("v9999/test") .expected (StatusCode::NOT_FOUND, strings::UNKNOWN_API_VERSION) .test ("01").await; base_case .valid_key (None) .expected (StatusCode::FORBIDDEN, strings::FORBIDDEN) .test ("02").await; base_case .x_api_key (Some ("borgus")) .expected (StatusCode::FORBIDDEN, strings::FORBIDDEN) .test ("03").await; base_case .path_rest ("v1/toast") .expected (StatusCode::NOT_FOUND, strings::UNKNOWN_API_ENDPOINT) .test ("04").await; base_case .x_api_key (None) .expected (StatusCode::FORBIDDEN, strings::NO_API_KEY) .test ("05").await; base_case .x_api_key (None) .auth_header (Some ("Bearer bogus")) .expected (StatusCode::OK, "You're valid!") .test ("06").await; }); } #[test] fn rss () { let input = "VmHWM: 584 kB\nVmRSS: 584 kB\nRssAnon: 68 kB\n"; assert_eq! (get_rss_from_status (input), Some (584)); } }