From 27336d85717c5fded7ddd1d3a6e03428c13470ac Mon Sep 17 00:00:00 2001 From: _ <> Date: Fri, 5 Mar 2021 03:17:56 +0000 Subject: [PATCH] add debug_proxy which I can probably use to inject network problems during tests --- Cargo.lock | 13 +++ crates/debug_proxy/Cargo.toml | 17 ++++ crates/debug_proxy/src/main.rs | 120 ++++++++++++++++++++++++ crates/ptth_file_server_bin/src/main.rs | 4 +- 4 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 crates/debug_proxy/Cargo.toml create mode 100644 crates/debug_proxy/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 6226349..ab7eb16 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -479,6 +479,19 @@ dependencies = [ "num_cpus", ] +[[package]] +name = "debug_proxy" +version = "0.1.0" +dependencies = [ + "anyhow", + "http", + "hyper", + "reqwest", + "tokio", + "tracing", + "tracing-subscriber", +] + [[package]] name = "digest" version = "0.8.1" diff --git a/crates/debug_proxy/Cargo.toml b/crates/debug_proxy/Cargo.toml new file mode 100644 index 0000000..86b8791 --- /dev/null +++ b/crates/debug_proxy/Cargo.toml @@ -0,0 +1,17 @@ +[package] + +name = "debug_proxy" +version = "0.1.0" +authors = ["Trish"] +edition = "2018" +license = "AGPL-3.0" + +[dependencies] + +anyhow = "1.0.34" +http = "0.2.1" +hyper = "0.13.8" +reqwest = { version = "0.10.8", features = ["stream"] } +tokio = { version = "0.2.22", features = ["full"] } +tracing = "0.1.21" +tracing-subscriber = "0.2.15" diff --git a/crates/debug_proxy/src/main.rs b/crates/debug_proxy/src/main.rs new file mode 100644 index 0000000..173a877 --- /dev/null +++ b/crates/debug_proxy/src/main.rs @@ -0,0 +1,120 @@ +use std::{ + net::SocketAddr, + sync::Arc, +}; + +use hyper::{ + Body, + Request, + Response, + Server, + service::{ + make_service_fn, + service_fn, + }, + StatusCode, +}; +use reqwest::Client; +use tokio::{ + spawn, + stream::StreamExt, + sync::mpsc, +}; + +struct State { + client: Client, +} + +async fn handle_all (req: Request , state: Arc ) +-> anyhow::Result > +{ + let (head, mut body) = req.into_parts (); + + tracing::trace! ("Got URI {}", head.uri); + + let mut new_uri = head.uri.clone ().into_parts (); + new_uri.scheme = Some (http::uri::Scheme::HTTPS); + new_uri.authority = Some (http::uri::Authority::from_static ("example.com")); + let new_uri = http::Uri::from_parts (new_uri)?; + + tracing::trace! ("Rebuilt URI as {}", new_uri); + + let mut upstream_req = state.client.request (head.method, &new_uri.to_string ()); + for (k, v) in &head.headers { + // upstream_req = upstream_req.header (k, v); + } + + let (mut tx, rx) = mpsc::channel (1); + spawn (async move { + loop { + let item = body.next ().await; + + if let Some (item) = item { + tx.send (item).await?; + } + else { + // Finished + break; + } + } + + Ok::<_, anyhow::Error> (()) + }); + + let upstream_resp = upstream_req.body (reqwest::Body::wrap_stream (rx)).send ().await?; + + let mut resp = Response::builder () + .status (upstream_resp.status ()); + + for (k, v) in upstream_resp.headers () { + resp = resp.header (k, v); + } + + let (mut tx, rx) = mpsc::channel (1); + spawn (async move { + let mut body = upstream_resp.bytes_stream (); + + loop { + let item = body.next ().await; + + if let Some (item) = item { + tx.send (item).await?; + } + else { + // Finished + break; + } + } + + Ok::<_, anyhow::Error> (()) + }); + + Ok (resp.body (Body::wrap_stream (rx))?) +} + +#[tokio::main] +async fn main () -> anyhow::Result <()> { + tracing_subscriber::fmt::init (); + + let addr = SocketAddr::from(([0, 0, 0, 0], 11509)); + + let state = Arc::new (State { + client: Client::builder ().build ()?, + }); + + let make_svc = make_service_fn (|_conn| { + let state = state.clone (); + + async { + Ok::<_, String> (service_fn (move |req| { + let state = state.clone (); + + handle_all (req, state) + })) + } + }); + + tracing::info! ("Binding to {}", addr); + Ok (Server::bind (&addr) + .serve (make_svc).await?) +} diff --git a/crates/ptth_file_server_bin/src/main.rs b/crates/ptth_file_server_bin/src/main.rs index cd8f82c..fda4d15 100644 --- a/crates/ptth_file_server_bin/src/main.rs +++ b/crates/ptth_file_server_bin/src/main.rs @@ -35,7 +35,7 @@ use ptth_server::{ }; async fn handle_all (req: Request , state: Arc ) --> Result , anyhow::Error> +-> anyhow::Result > { use std::str::FromStr; use hyper::header::HeaderName; @@ -82,7 +82,7 @@ pub struct ConfigFile { } #[tokio::main] -async fn main () -> Result <(), anyhow::Error> { +async fn main () -> anyhow::Result <()> { tracing_subscriber::fmt::init (); let path = PathBuf::from ("./config/ptth_server.toml");