diff --git a/Cargo.lock b/Cargo.lock index 7ad9b46..073673a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -165,9 +165,9 @@ checksum = "e91831deabf0d6d7ec49552e489aed63b7456a7a3c46cff62adad428110b0af0" [[package]] name = "async-trait" -version = "0.1.42" +version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d3a45e77e34375a7923b1e8febb049bb011f064714a8e17a1a616fef01da13d" +checksum = "d3340571769500ddef1e94b45055fabed6b08a881269b7570c830b8f32ef84ef" dependencies = [ "proc-macro2", "quote", @@ -468,6 +468,7 @@ name = "debug_proxy" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "futures-util", "http", "hyper", diff --git a/crates/debug_proxy/Cargo.toml b/crates/debug_proxy/Cargo.toml index 67914c7..aadca56 100644 --- a/crates/debug_proxy/Cargo.toml +++ b/crates/debug_proxy/Cargo.toml @@ -9,6 +9,7 @@ license = "AGPL-3.0" [dependencies] anyhow = "1.0.34" +async-trait = "0.1.45" futures-util = "0.3.8" http = "0.2.1" hyper = { version = "0.14.4", features = ["server", "stream"] } diff --git a/crates/debug_proxy/src/lib.rs b/crates/debug_proxy/src/lib.rs index d7ed633..7e486a9 100644 --- a/crates/debug_proxy/src/lib.rs +++ b/crates/debug_proxy/src/lib.rs @@ -3,12 +3,14 @@ use std::{ sync::Arc, }; +use async_trait::async_trait; use futures_util::StreamExt; use hyper::{ Body, Request, Response, Server, + body::Bytes, service::{ make_service_fn, service_fn, @@ -25,18 +27,77 @@ use tokio::{ use tokio_stream::wrappers::ReceiverStream; use ulid::Ulid; -struct State { - client: Client, - upstream_authority: String, +#[async_trait] +trait ProxyFilter { + async fn request_body (&self, mut body: Body, tx: mpsc::Sender >) -> anyhow::Result <()>; } -async fn handle_all (req: Request , state: Arc ) +struct PassthroughFilter {} + +#[async_trait] +impl ProxyFilter for PassthroughFilter { + async fn request_body (&self, mut body: Body, tx: mpsc::Sender >) -> anyhow::Result <()> { + let mut bytes_transferred = 0; + + loop { + let item = body.next ().await; + + if let Some (item) = item { + if let Ok (item) = &item { + bytes_transferred += item.len (); + } + tx.send (item).await?; + } + else { + // Finished + break; + } + } + + Ok (()) + } +} + +struct RequestBodyDropFilter {} + +#[async_trait] +impl ProxyFilter for RequestBodyDropFilter { + async fn request_body (&self, mut body: Body, tx: mpsc::Sender >) -> anyhow::Result <()> { + let mut bytes_transferred = 0; + + loop { + let item = body.next ().await; + + if let Some (item) = item { + if let Ok (item) = &item { + bytes_transferred += item.len (); + } + // tx.send (item).await?; + tracing::debug! ("RequestBodyDropFilter dropping chunk"); + } + else { + // Finished + break; + } + } + + Ok (()) + } +} + +struct State { + client: Client, + upstream_authority: String, + proxy_filter: Arc , +} + +async fn handle_all (req: Request , state: Arc >) -> anyhow::Result > { let req_id = Ulid::new ().to_string (); let (head, mut body) = req.into_parts (); - tracing::trace! ("{} Got URI {}", req_id, head.uri); + tracing::debug! ("{} Got URI {}", req_id, head.uri); let upstream_authority = state.upstream_authority.clone (); @@ -55,27 +116,10 @@ async fn handle_all (req: Request , state: Arc ) let (tx, rx) = mpsc::channel (1); spawn ({ let req_id = req_id.clone (); + let proxy_filter = state.proxy_filter.clone (); + async move { - let mut bytes_transferred = 0; - - loop { - let item = body.next ().await; - - if let Some (item) = item { - if let Ok (item) = &item { - bytes_transferred += item.len (); - } - tx.send (item).await?; - } - else { - // Finished - break; - } - } - - tracing::trace! ("{} Request body bytes: {}", req_id, bytes_transferred); - - Ok::<_, anyhow::Error> (()) + proxy_filter.request_body (body, tx).await } }); @@ -122,9 +166,13 @@ pub async fn run_proxy ( shutdown_oneshot: oneshot::Receiver <()>, ) -> anyhow::Result <()> { + let filter = PassthroughFilter {}; + // let filter = RequestBodyDropFilter {}; + let state = Arc::new (State { client: Client::builder ().build ()?, upstream_authority, + proxy_filter: Arc::new (filter), }); let make_svc = make_service_fn (|_conn| { diff --git a/crates/ptth_server/src/lib.rs b/crates/ptth_server/src/lib.rs index 5b61d53..091c7d4 100644 --- a/crates/ptth_server/src/lib.rs +++ b/crates/ptth_server/src/lib.rs @@ -106,9 +106,8 @@ async fn handle_one_req ( if e.is_request () { warn! ("Error while POSTing response. Client probably hung up."); } - else { - error! ("Err: {:?}", e); - } + + error! ("Err: {:?}", e); }, } diff --git a/src/tests.rs b/src/tests.rs index c45b612..3401029 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -9,7 +9,6 @@ use std::{ }; use tokio::{ - runtime::Runtime, spawn, sync::oneshot, }; @@ -42,8 +41,9 @@ async fn testing_client_checks ( assert_eq! (resp, "Relay is up\n"); - let resp = client.get (&format! ("{}/frontend/servers/{}/files/COPYING", relay_url, server_name)) - .send ().await.expect ("Couldn't find license").bytes ().await.expect ("Couldn't find license"); + let req = client.get (&format! ("{}/frontend/servers/{}/files/COPYING", relay_url, server_name)) + .send (); + let resp = tokio::time::timeout (Duration::from_secs (2), req).await.expect ("Request timed out").expect ("Couldn't find license").bytes ().await.expect ("Couldn't find license"); if blake3::hash (&resp) != blake3::Hash::from ([ 0xca, 0x02, 0x92, 0x78, @@ -174,172 +174,163 @@ impl TestingServer { } } -#[test] -fn end_to_end () { +#[tokio::test] +async fn end_to_end () { // Prefer this form for tests, since all tests share one process // and we don't care if another test already installed a subscriber. //tracing_subscriber::fmt ().try_init ().ok (); - let rt = Runtime::new ().expect ("Can't create runtime for testing"); - // Spawn the root task - rt.block_on (async { - let relay_port = 4000; - // No proxy - let proxy_port = relay_port; - let server_name = "aliens_wildland"; + let relay_port = 4000; + // No proxy + let proxy_port = relay_port; + let server_name = "aliens_wildland"; + + let testing_config = TestingConfig { + server_name, + api_key: "AnacondaHardcoverGrannyUnlatchLankinessMutate", - let testing_config = TestingConfig { - server_name, - api_key: "AnacondaHardcoverGrannyUnlatchLankinessMutate", - - proxy_port, - relay_port, - }; - - let testing_relay = TestingRelay::new (&testing_config).await; - let testing_server = TestingServer::new (&testing_config).await; - wait_for_any_server (&testing_relay.state).await; - - assert_eq! (testing_relay.state.list_servers ().await, vec! [ - server_name.to_string (), - ]); - - let client = Client::builder () - .build ().expect ("Couldn't build HTTP client"); - - testing_client_checks (&testing_config, &client).await; - - info! ("Shutting down end-to-end test"); - - testing_server.graceful_shutdown ().await; - testing_relay.graceful_shutdown ().await; - }); + proxy_port, + relay_port, + }; + + let testing_relay = TestingRelay::new (&testing_config).await; + let testing_server = TestingServer::new (&testing_config).await; + wait_for_any_server (&testing_relay.state).await; + + assert_eq! (testing_relay.state.list_servers ().await, vec! [ + server_name.to_string (), + ]); + + let client = Client::builder () + .build ().expect ("Couldn't build HTTP client"); + + testing_client_checks (&testing_config, &client).await; + + info! ("Shutting down end-to-end test"); + + testing_server.graceful_shutdown ().await; + testing_relay.graceful_shutdown ().await; } -#[test] -fn debug_proxy () { - tracing_subscriber::fmt ().try_init ().ok (); - let rt = Runtime::new ().expect ("Can't create runtime for testing"); +#[tokio::test] +async fn debug_proxy () { + tracing_subscriber::fmt () + .with_env_filter (tracing_subscriber::EnvFilter::from_default_env ()) + .try_init ().ok (); - rt.block_on (async { - let relay_port = 4002; - let proxy_port = 11510; + let relay_port = 4002; + let proxy_port = 11510; + + // Start relay + + let server_name = "aliens_wildland"; + + let testing_config = TestingConfig { + server_name, + api_key: "AnacondaHardcoverGrannyUnlatchLankinessMutate", - // Start relay - - let server_name = "aliens_wildland"; - - let testing_config = TestingConfig { - server_name, - api_key: "AnacondaHardcoverGrannyUnlatchLankinessMutate", - - proxy_port, - relay_port, - }; - - let testing_relay = TestingRelay::new (&testing_config).await; - - // Start proxy - - let (stop_proxy_tx, stop_proxy_rx) = oneshot::channel (); - let task_proxy = spawn (async move { - debug_proxy::run_proxy (SocketAddr::from (([0, 0, 0, 0], proxy_port)), format! ("127.0.0.1:{}", relay_port), stop_proxy_rx).await - }); - - // Start server - - let testing_server = TestingServer::new (&testing_config).await; - - wait_for_any_server (&testing_relay.state).await; - - assert_eq! (testing_relay.state.list_servers ().await, vec! [ - server_name.to_string (), - ]); - - let client = Client::builder () - .build ().expect ("Couldn't build HTTP client"); - - testing_client_checks (&testing_config, &client).await; - - info! ("Shutting down end-to-end test"); - - testing_server.graceful_shutdown ().await; - - stop_proxy_tx.send (()).expect ("Couldn't shut down proxy"); - task_proxy.await.expect ("Couldn't join proxy").expect ("Proxy error"); - info! ("Proxy stopped"); - - testing_relay.graceful_shutdown ().await; + proxy_port, + relay_port, + }; + + let testing_relay = TestingRelay::new (&testing_config).await; + + // Start proxy + + let (stop_proxy_tx, stop_proxy_rx) = oneshot::channel (); + let task_proxy = spawn (async move { + debug_proxy::run_proxy (SocketAddr::from (([0, 0, 0, 0], proxy_port)), format! ("127.0.0.1:{}", relay_port), stop_proxy_rx).await }); + + // Start server + + let testing_server = TestingServer::new (&testing_config).await; + + wait_for_any_server (&testing_relay.state).await; + + assert_eq! (testing_relay.state.list_servers ().await, vec! [ + server_name.to_string (), + ]); + + let client = Client::builder () + .build ().expect ("Couldn't build HTTP client"); + + testing_client_checks (&testing_config, &client).await; + + info! ("Shutting down end-to-end test"); + + testing_server.graceful_shutdown ().await; + + stop_proxy_tx.send (()).expect ("Couldn't shut down proxy"); + task_proxy.await.expect ("Couldn't join proxy").expect ("Proxy error"); + info! ("Proxy stopped"); + + testing_relay.graceful_shutdown ().await; } -#[test] -fn scraper_endpoints () { - let rt = Runtime::new ().expect ("Can't create runtime for testing"); +#[tokio::test] +async fn scraper_endpoints () { + use ptth_relay::*; - rt.block_on (async { - use ptth_relay::*; - - let config_file = config::file::Config { - iso: config::file::Isomorphic { - enable_scraper_api: true, - dev_mode: Default::default (), + let config_file = config::file::Config { + iso: config::file::Isomorphic { + enable_scraper_api: true, + dev_mode: Default::default (), + }, + port: Some (4001), + servers: vec! [], + scraper_keys: vec! [ + key_validity::ScraperKey::new_30_day ("automated testing", b"bogus") + ], + }; + + let config = config::Config::try_from (config_file).expect ("Can't load config"); + + let relay_state = Arc::new (RelayState::try_from (config).expect ("Can't create relay state")); + let relay_state_2 = relay_state.clone (); + let (stop_relay_tx, stop_relay_rx) = oneshot::channel (); + let task_relay = spawn (async move { + run_relay ( + relay_state_2, + Arc::new (load_templates (&PathBuf::new ())?), + stop_relay_rx, + None + ).await + }); + + let relay_url = "http://127.0.0.1:4001"; + + let mut headers = reqwest::header::HeaderMap::new (); + headers.insert ("X-ApiKey", "bogus".try_into ().unwrap ()); + + let client = Client::builder () + .default_headers (headers) + .timeout (Duration::from_secs (2)) + .build ().expect ("Couldn't build HTTP client"); + + let mut resp = None; + for _ in 0usize..5 { + let x = client.get (&format! ("{}/scraper/api/test", relay_url)) + .send ().await; + match x { + Err (_) => { + // Probably a reqwest error cause the port is in + // use or something. Try again. + }, + Ok (x) => { + resp = Some (x); + break; }, - port: Some (4001), - servers: vec! [], - scraper_keys: vec! [ - key_validity::ScraperKey::new_30_day ("automated testing", b"bogus") - ], }; - let config = config::Config::try_from (config_file).expect ("Can't load config"); - - let relay_state = Arc::new (RelayState::try_from (config).expect ("Can't create relay state")); - let relay_state_2 = relay_state.clone (); - let (stop_relay_tx, stop_relay_rx) = oneshot::channel (); - let task_relay = spawn (async move { - run_relay ( - relay_state_2, - Arc::new (load_templates (&PathBuf::new ())?), - stop_relay_rx, - None - ).await - }); - - let relay_url = "http://127.0.0.1:4001"; - - let mut headers = reqwest::header::HeaderMap::new (); - headers.insert ("X-ApiKey", "bogus".try_into ().unwrap ()); - - let client = Client::builder () - .default_headers (headers) - .timeout (Duration::from_secs (2)) - .build ().expect ("Couldn't build HTTP client"); - - let mut resp = None; - for _ in 0usize..5 { - let x = client.get (&format! ("{}/scraper/api/test", relay_url)) - .send ().await; - match x { - Err (_) => { - // Probably a reqwest error cause the port is in - // use or something. Try again. - }, - Ok (x) => { - resp = Some (x); - break; - }, - }; - - tokio::time::sleep (Duration::from_millis (200)).await; - } - let resp = resp.expect ("Reqwest repeatedly failed to connect to the relay"); - let resp = resp.bytes ().await.expect ("Couldn't check if relay is up"); - - assert_eq! (resp, "You're valid!\n"); - - stop_relay_tx.send (()).expect ("Couldn't shut down relay"); - task_relay.await.expect ("Couldn't join relay").expect ("Relay error"); - }); + tokio::time::sleep (Duration::from_millis (200)).await; + } + let resp = resp.expect ("Reqwest repeatedly failed to connect to the relay"); + let resp = resp.bytes ().await.expect ("Couldn't check if relay is up"); + + assert_eq! (resp, "You're valid!\n"); + + stop_relay_tx.send (()).expect ("Couldn't shut down relay"); + task_relay.await.expect ("Couldn't join relay").expect ("Relay error"); }