ptth/crates/ptth_relay/src/scraper_api.rs

380 lines
8.8 KiB
Rust

use std::{
collections::HashMap,
};
use chrono::{DateTime, Utc};
use hyper::{
Body,
Request,
Response,
StatusCode,
};
use serde::{
Serialize,
Serializer,
};
use tracing::{
error,
instrument,
};
use crate::{
RequestError,
error_reply,
key_validity::{
BlakeHashWrapper,
KeyValidity,
},
relay_state::Relay,
};
// Not sure if this is the best way to do a hard-coded string table, but
// it'll keep the tests up to date
mod strings {
    // Bodies of the 403 replies.

    /// Sent when a scraper key was supplied but rejected.
    pub const FORBIDDEN: &str = "403 Forbidden";
    /// Sent when the request carries no `X-ApiKey` header at all.
    pub const NO_API_KEY: &str = "Can't auth as scraper without API key";

    // Bodies of the 404 replies.

    /// Sent when the path after the version prefix matches no endpoint.
    pub const UNKNOWN_API_ENDPOINT: &str = "Unknown scraper API endpoint";
    /// Sent when the path does not start with a known version prefix.
    pub const UNKNOWN_API_VERSION: &str = "Unknown scraper API version";
}
// JSON is probably Good Enough For Now, so I'll just make everything
// a struct and lazily serialize it right before leaving the
// top-level handle () fn.
/// Serde helper for `Server::last_seen`.
///
/// A present timestamp is written as an RFC 3339 string; an absent one is
/// written through `serialize_none`.
fn serialize_last_seen <S: Serializer> (x: &Option <DateTime <Utc>>, s: S)
-> Result <S::Ok, S::Error>
{
    if let Some (timestamp) = x {
        let formatted = timestamp.to_rfc3339 ();
        s.serialize_str (&formatted)
    }
    else {
        s.serialize_none ()
    }
}
/// One server entry in the scraper API's server list.
#[derive (Serialize)]
pub struct Server {
/// Key under which the server is listed in the relay's `servers` config map.
pub name: String,
/// Configured display name; `v1_server_list` falls back to `name` when the
/// config has none.
pub display_name: String,
/// Time taken from the server's status entry, or `None` when the relay
/// holds no status entry for this name. Serialized as an RFC 3339 string.
#[serde (serialize_with = "serialize_last_seen")]
pub last_seen: Option <DateTime <Utc>>,
}
/// Response payload of the `server_list` endpoint.
#[derive (Serialize)]
pub struct ServerList {
/// All configured servers, in the order produced by `v1_server_list`.
pub servers: Vec <Server>,
}
/// Builds the server list served by the v1 `server_list` endpoint.
///
/// Server names and display names are collected from `state.config` first
/// and `state.me_config` second, so an entry in `me_config` replaces an entry
/// with the same name from `config`. A server without a configured display
/// name uses its name instead. `last_seen` is filled from `state.server_status`
/// and is `None` for servers that have no status entry.
///
/// The result is sorted by display name, with the (unique) server name as a
/// tie-breaker. Without the tie-breaker, servers sharing a display name would
/// come out in whatever order the intermediate `HashMap` happened to yield,
/// which can differ from one call to the next.
pub async fn v1_server_list (state: &Relay)
-> ServerList
{
    // name --> display_name
    let mut display_names = HashMap::new ();

    {
        let guard = state.config.read ().await;
        for (k, v) in &guard.servers {
            let display_name = v.display_name
                .clone ()
                .unwrap_or_else (|| k.clone ());

            display_names.insert (k.clone (), display_name);
        }
    }

    {
        // Inserted second, so these entries win on a name collision.
        let guard = state.me_config.read ().await;
        for (k, v) in &guard.servers {
            let display_name = v.display_name
                .clone ()
                .unwrap_or_else (|| k.clone ());

            display_names.insert (k.clone (), display_name);
        }
    }

    // name --> status
    // Work on a snapshot so the status lock is released before the list is
    // assembled.
    let server_statuses = {
        let guard = state.server_status.lock ().await;
        (*guard).clone ()
    };

    let mut servers: Vec <_> = display_names.into_iter ()
        .map (|(name, display_name)| {
            let last_seen = server_statuses.get (&name).map (|x| x.last_seen);

            Server {
                display_name,
                name,
                last_seen,
            }
        })
        .collect ();

    // Names are map keys and therefore unique, so this is a total order and
    // the output no longer depends on HashMap iteration order.
    servers.sort_by (|a, b| {
        a.display_name.cmp (&b.display_name)
            .then_with (|| a.name.cmp (&b.name))
    });

    ServerList {
        servers,
    }
}
/// Serves one request to version 1 of the scraper API.
///
/// `path_rest` is the request path with the version prefix already removed.
/// The request must carry an `X-ApiKey` header whose value matches a
/// currently valid scraper key in the relay config; otherwise a 403 is
/// returned. Every authenticated request is recorded in the audit log before
/// it is routed.
///
/// Endpoints:
/// - `test` — confirms the key is valid.
/// - `server_list` — JSON produced from `v1_server_list`.
/// - `server/<listen_code>/<path>` — forwarded to
///   `crate::handle_http_request`; `<path>` keeps its leading `/`.
///
/// Anything else yields a 404.
#[instrument (level = "trace", skip (req, state))]
async fn api_v1 (
    req: Request <Body>,
    state: &Relay,
    path_rest: &str
)
-> Result <Response <Body>, RequestError>
{
    use crate::{
        AuditData,
        AuditEvent,
    };

    // Every rejected key gets the same generic reply.
    let reject = || error_reply (StatusCode::FORBIDDEN, strings::FORBIDDEN);

    let presented_key = match req.headers ().get ("X-ApiKey") {
        Some (value) => value,
        None => return Ok (error_reply (StatusCode::FORBIDDEN, strings::NO_API_KEY)?),
    };
    let presented_hash = BlakeHashWrapper::from_key (presented_key.as_bytes ());

    // The config read guard only lives for this block, so it is released
    // before the audit log is awaited.
    let key_name = {
        let config = state.config.read ().await;

        let known_key = match config.scraper_keys.get (&presented_hash.encode_base64 ()) {
            None => return Ok (reject ()?),
            Some (found) => found,
        };

        let now = chrono::Utc::now ();

        // The hash in is_valid is redundant
        match known_key.is_valid (now, presented_key.as_bytes ()) {
            KeyValidity::Valid => (),
            KeyValidity::WrongKey (bad_hash) => {
                error! ("Bad scraper key with hash {:?}", bad_hash);
                return Ok (reject ()?);
            },
            err => {
                error! ("Bad scraper key {:?}", err);
                return Ok (reject ()?);
            },
        }

        known_key.name.to_string ()
    };

    let audit_event = AuditEvent::new (AuditData::ScraperGet {
        key_name,
        path: path_rest.to_string (),
    });
    state.audit_log.push (audit_event).await;

    match path_rest {
        "test" => Ok (error_reply (StatusCode::OK, "You're valid!")?),
        "server_list" => {
            let list = v1_server_list (state).await;
            let json = serde_json::to_string (&list).unwrap ();
            Ok (error_reply (StatusCode::OK, &json)?)
        },
        other => match other.strip_prefix ("server/") {
            None => Ok (error_reply (StatusCode::NOT_FOUND, strings::UNKNOWN_API_ENDPOINT)?),
            // DRY T4H76LB3
            Some (rest) => match rest.find ('/') {
                None => Ok (error_reply (StatusCode::BAD_REQUEST, "Bad URI format")?),
                Some (idx) => {
                    // Split at the first slash; the slash stays with the path.
                    let (code_part, path_part) = rest.split_at (idx);
                    let listen_code = String::from (code_part);
                    let path = String::from (path_part);

                    let (parts, _) = req.into_parts ();

                    // This is ugly. I don't like having scraper_api know about the
                    // crate root.
                    Ok (crate::handle_http_request (parts, path, state, &listen_code).await?)
                },
            },
        },
    }
}
/// Entry point for scraper API requests.
///
/// Replies 403 when `iso.enable_scraper_api` is off in the relay config.
/// Otherwise the version prefix is removed from `path_rest` and the request
/// goes to `api_v1`; both `v1/` and `api/` select that handler. Any other
/// prefix yields a 404.
#[instrument (level = "trace", skip (req, state))]
pub async fn handle (
    req: Request <Body>,
    state: &Relay,
    path_rest: &str
)
-> Result <Response <Body>, RequestError>
{
    // Read the flag into a local so the config guard is gone before any
    // further work happens.
    let scraper_enabled = state.config.read ().await.iso.enable_scraper_api;
    if ! scraper_enabled {
        return Ok (error_reply (StatusCode::FORBIDDEN, "Scraper API disabled")?);
    }

    // "v1/" is tried first, then "api/"; both map to the v1 handler.
    let v1_rest = path_rest.strip_prefix ("v1/")
        .or_else (|| path_rest.strip_prefix ("api/"));

    match v1_rest {
        Some (rest) => api_v1 (req, state, rest).await,
        None => Ok (error_reply (StatusCode::NOT_FOUND, strings::UNKNOWN_API_VERSION)?),
    }
}
#[cfg (test)]
mod tests {
    use std::{
        convert::{TryInto},
    };

    use tokio::runtime::Runtime;

    use crate::{
        key_validity,
    };
    use super::*;

    /// One request/response scenario run against `handle`.
    #[derive (Clone)]
    struct TestCase {
        // Inputs

        /// Path after `/scraper/`, also passed to `handle` as `path_rest`.
        path_rest: &'static str,
        /// Key registered with the relay under test, if any.
        valid_key: Option <&'static str>,
        /// Key sent in the `X-ApiKey` header, if any.
        input_key: Option <&'static str>,

        // Expected

        expected_status: StatusCode,
        expected_headers: Vec <(&'static str, &'static str)>,
        expected_body: String,
    }

    // Each setter returns a modified copy, so a base case can be reused
    // for several variations.
    impl TestCase {
        fn path_rest (&self, v: &'static str) -> Self {
            Self {
                path_rest: v,
                .. self.clone ()
            }
        }

        fn valid_key (&self, v: Option <&'static str>) -> Self {
            Self {
                valid_key: v,
                .. self.clone ()
            }
        }

        fn input_key (&self, v: Option <&'static str>) -> Self {
            Self {
                input_key: v,
                .. self.clone ()
            }
        }

        fn expected_status (&self, v: StatusCode) -> Self {
            Self {
                expected_status: v,
                .. self.clone ()
            }
        }

        fn expected_headers (&self, v: Vec <(&'static str, &'static str)>) -> Self {
            Self {
                expected_headers: v,
                .. self.clone ()
            }
        }

        fn expected_body (&self, v: String) -> Self {
            Self {
                expected_body: v,
                .. self.clone ()
            }
        }

        /// Sets status and body together; the expected body is `body`
        /// followed by a newline.
        fn expected (&self, sc: StatusCode, body: &str) -> Self {
            Self {
                expected_status: sc,
                expected_body: format! ("{}\n", body),
                .. self.clone ()
            }
        }

        /// Builds a relay and a request from the inputs, runs `handle`,
        /// and checks status, headers and body against the expectations.
        async fn test (&self) {
            let uri = format! ("http://127.0.0.1:4000/scraper/{}", self.path_rest);
            let request = Request::builder ()
                .method ("GET")
                .uri (uri);
            let request = match self.input_key {
                Some (key) => request.header ("X-ApiKey", key),
                None => request,
            };
            let request = request.body (Body::empty ()).unwrap ();

            let builder = Relay::build ()
                .port (4000)
                .enable_scraper_api (true);
            let builder = match self.valid_key {
                Some (key) => builder.scraper_key (
                    key_validity::ScraperKey::new_30_day ("automated test", key.as_bytes ())
                ),
                None => builder,
            };
            let relay_state = builder.build ().expect ("Can't create relay state");

            let response = super::handle (request, &relay_state, self.path_rest).await
                .expect ("Relay didn't respond");
            let (actual_head, actual_body) = response.into_parts ();

            let mut expected_headers = hyper::header::HeaderMap::new ();
            for &(name, value) in &self.expected_headers {
                expected_headers.insert (name, value.try_into ().expect ("Couldn't convert header value"));
            }

            assert_eq! (actual_head.status, self.expected_status);
            assert_eq! (actual_head.headers, expected_headers);

            let body_bytes = hyper::body::to_bytes (actual_body).await
                .expect ("Body should be convertible to bytes");
            let actual_body = String::from_utf8 (body_bytes.to_vec ())
                .expect ("Body should be UTF-8");

            assert_eq! (actual_body, self.expected_body);
        }
    }

    #[test]
    fn auth () {
        let rt = Runtime::new ().expect ("Can't create runtime for testing");

        rt.block_on (async move {
            // Happy path: the key the relay knows is the key that is sent.
            let base_case = TestCase {
                path_rest: "v1/test",
                valid_key: Some ("bogus"),
                input_key: Some ("bogus"),
                expected_status: StatusCode::OK,
                expected_headers: vec! [
                    ("content-type", "text/plain"),
                ],
                expected_body: "You're valid!\n".to_string (),
            };

            let cases = vec! [
                base_case.clone (),
                base_case.path_rest ("v9999/test")
                    .expected (StatusCode::NOT_FOUND, strings::UNKNOWN_API_VERSION),
                base_case.valid_key (None)
                    .expected (StatusCode::FORBIDDEN, strings::FORBIDDEN),
                base_case.input_key (Some ("borgus"))
                    .expected (StatusCode::FORBIDDEN, strings::FORBIDDEN),
                base_case.path_rest ("v1/toast")
                    .expected (StatusCode::NOT_FOUND, strings::UNKNOWN_API_ENDPOINT),
                base_case.input_key (None)
                    .expected (StatusCode::FORBIDDEN, strings::NO_API_KEY),
            ];

            for case in &cases {
                case.test ().await;
            }
        });
    }
}