diff --git a/crates/ptth_relay/src/lib.rs b/crates/ptth_relay/src/lib.rs index 9b5ce9d..1ab9c1f 100644 --- a/crates/ptth_relay/src/lib.rs +++ b/crates/ptth_relay/src/lib.rs @@ -317,17 +317,22 @@ async fn handle_server_list_internal (state: &Arc ) } async fn handle_server_list ( - state: Arc + state: Arc , + handlebars: Arc > ) -> Result , RequestError> { let page = handle_server_list_internal (&state).await; - let s = state.handlebars.render ("relay_server_list", &page)?; + let s = handlebars.render ("relay_server_list", &page)?; Ok (ok_reply (s)?) } #[instrument (level = "trace", skip (req, state))] -async fn handle_all (req: Request , state: Arc ) +async fn handle_all ( + req: Request , + state: Arc , + handlebars: Arc > +) -> Result , RequestError> { let path = req.uri ().path ().to_string (); @@ -359,7 +364,7 @@ async fn handle_all (req: Request , state: Arc ) } else if let Some (rest) = prefix_match ("/frontend/servers/", &path) { if rest == "" { - Ok (handle_server_list (state).await?) + Ok (handle_server_list (state, handlebars).await?) } else if let Some (idx) = rest.find ('/') { let listen_code = String::from (&rest [0..idx]); @@ -373,7 +378,7 @@ async fn handle_all (req: Request , state: Arc ) } } else if path == "/" { - let s = state.handlebars.render ("relay_root", &())?; + let s = handlebars.render ("relay_root", &())?; Ok (ok_reply (s)?) } else if path == "/frontend/relay_up_check" { @@ -429,6 +434,7 @@ async fn reload_config ( pub async fn run_relay ( state: Arc , + handlebars: Arc >, shutdown_oneshot: oneshot::Receiver <()>, config_reload_path: Option ) @@ -448,12 +454,14 @@ pub async fn run_relay ( let make_svc = make_service_fn (|_conn| { let state = state.clone (); + let handlebars = handlebars.clone (); async { Ok::<_, RequestError> (service_fn (move |req| { let state = state.clone (); + let handlebars = handlebars.clone (); - handle_all (req, state) + handle_all (req, state, handlebars) })) } }); diff --git a/crates/ptth_relay/src/main.rs b/crates/ptth_relay/src/main.rs index d00332d..abe446d 100644 --- a/crates/ptth_relay/src/main.rs +++ b/crates/ptth_relay/src/main.rs @@ -42,6 +42,7 @@ async fn main () -> Result <(), Box > { forced_shutdown.wrap_server ( run_relay ( Arc::new (RelayState::try_from (config)?), + Arc::new (ptth_relay::load_templates (&PathBuf::new ())?), shutdown_rx, Some (config_path) ) diff --git a/crates/ptth_relay/src/relay_state.rs b/crates/ptth_relay/src/relay_state.rs index 6b7f0c1..f91d76a 100644 --- a/crates/ptth_relay/src/relay_state.rs +++ b/crates/ptth_relay/src/relay_state.rs @@ -76,7 +76,6 @@ impl Default for ServerStatus { pub struct RelayState { pub config: RwLock , - pub handlebars: Arc >, // Key: Server ID pub request_rendezvous: Mutex >, @@ -97,7 +96,6 @@ impl TryFrom for RelayState { Ok (Self { config: config.into (), - handlebars: Arc::new (load_templates (&PathBuf::new ())?), request_rendezvous: Default::default (), server_status: Default::default (), response_rendezvous: Default::default (), diff --git a/crates/ptth_relay/src/scraper_api.rs b/crates/ptth_relay/src/scraper_api.rs index b63bbf7..ff58a32 100644 --- a/crates/ptth_relay/src/scraper_api.rs +++ b/crates/ptth_relay/src/scraper_api.rs @@ -98,3 +98,49 @@ pub async fn handle_scraper_api ( Ok (error_reply (StatusCode::NOT_FOUND, "Unknown scraper API version")?) } } + +#[cfg (test)] +mod tests { + use std::{ + convert::TryFrom, + }; + + use tokio::runtime::Runtime; + use crate::{ + config, + key_validity, + }; + use super::*; + + #[test] + fn auth () { + let input = Request::builder () + .method ("GET") + .uri ("http://127.0.0.1:4000/scraper/v1/test") + .header ("X-ApiKey", "bogus") + .body (Body::empty ()) + .unwrap (); + + let config_file = config::file::Config { + port: Some (4000), + servers: vec! [], + iso: config::file::Isomorphic { + enable_scraper_auth: true, + dev_mode: Some (config::file::DevMode { + scraper_key: Some (key_validity::ScraperKey::new (b"bogus")), + }), + }, + }; + + let config = config::Config::try_from (config_file).expect ("Can't load config"); + + let relay_state = Arc::new (RelayState::try_from (config).expect ("Can't create relay state")); + + let mut rt = Runtime::new ().expect ("Can't create runtime for testing"); + + rt.block_on (async move { + let actual = handle_scraper_api (input, relay_state, "").await; + let actual = actual.expect ("Relay didn't respond"); + }); + } +} diff --git a/src/tests.rs b/src/tests.rs index 55856bb..6a4c28f 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,5 +1,6 @@ use std::{ convert::{TryFrom, TryInto}, + path::PathBuf, sync::{ Arc, }, @@ -16,6 +17,8 @@ use tokio::{ use reqwest::Client; use tracing::{debug, info}; +use ptth_relay::load_templates; + #[test] fn end_to_end () { use ptth_relay::key_validity::BlakeHashWrapper; @@ -51,7 +54,12 @@ fn end_to_end () { let relay_state_2 = relay_state.clone (); let (stop_relay_tx, stop_relay_rx) = oneshot::channel (); let task_relay = spawn (async move { - ptth_relay::run_relay (relay_state_2, stop_relay_rx, None).await + ptth_relay::run_relay ( + relay_state_2, + Arc::new (load_templates (&PathBuf::new ())?), + stop_relay_rx, + None + ).await }); assert! (relay_state.list_servers ().await.is_empty ()); @@ -153,7 +161,12 @@ fn scraper_endpoints () { let relay_state_2 = relay_state.clone (); let (stop_relay_tx, stop_relay_rx) = oneshot::channel (); let task_relay = spawn (async move { - run_relay (relay_state_2, stop_relay_rx, None).await + run_relay ( + relay_state_2, + Arc::new (load_templates (&PathBuf::new ())?), + stop_relay_rx, + None + ).await }); let relay_url = "http://127.0.0.1:4001";