♻️ refactor: debug_proxy into a lib

main
_ 2021-03-05 04:03:40 +00:00
parent 27336d8571
commit 33746d9ace
5 changed files with 141 additions and 111 deletions

2
Cargo.lock generated
View File

@ -490,6 +490,7 @@ dependencies = [
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"ulid",
] ]
[[package]] [[package]]
@ -1628,6 +1629,7 @@ dependencies = [
"base64 0.13.0", "base64 0.13.0",
"blake3", "blake3",
"chrono", "chrono",
"debug_proxy",
"ptth_relay", "ptth_relay",
"ptth_server", "ptth_server",
"reqwest", "reqwest",

View File

@ -33,6 +33,7 @@ tokio = { version = "0.2.22", features = ["full"] }
tracing = "0.1.21" tracing = "0.1.21"
tracing-subscriber = "0.2.15" tracing-subscriber = "0.2.15"
debug_proxy = { path = "crates/debug_proxy" }
ptth_relay = { path = "crates/ptth_relay" } ptth_relay = { path = "crates/ptth_relay" }
ptth_server = { path = "crates/ptth_server" } ptth_server = { path = "crates/ptth_server" }

View File

@ -15,3 +15,4 @@ reqwest = { version = "0.10.8", features = ["stream"] }
tokio = { version = "0.2.22", features = ["full"] } tokio = { version = "0.2.22", features = ["full"] }
tracing = "0.1.21" tracing = "0.1.21"
tracing-subscriber = "0.2.15" tracing-subscriber = "0.2.15"
ulid = "0.4.1"

View File

@ -0,0 +1,136 @@
use std::{
net::SocketAddr,
sync::Arc,
};
use hyper::{
Body,
Request,
Response,
Server,
service::{
make_service_fn,
service_fn,
},
};
use reqwest::Client;
use tokio::{
spawn,
stream::StreamExt,
sync::mpsc,
};
use ulid::Ulid;
struct State {
client: Client,
upstream_authority: String,
}
async fn handle_all (req: Request <Body>, state: Arc <State>)
-> anyhow::Result <Response <Body>>
{
let req_id = Ulid::new ().to_string ();
let (head, mut body) = req.into_parts ();
tracing::trace! ("{} Got URI {}", req_id, head.uri);
let upstream_authority = state.upstream_authority.clone ();
let mut new_uri = head.uri.clone ().into_parts ();
new_uri.scheme = Some (http::uri::Scheme::HTTP);
new_uri.authority = Some (http::uri::Authority::from_maybe_shared (upstream_authority)?);
let new_uri = http::Uri::from_parts (new_uri)?;
tracing::trace! ("{} Rebuilt URI as {}", req_id, new_uri);
let mut upstream_req = state.client.request (head.method, &new_uri.to_string ());
for (k, v) in &head.headers {
upstream_req = upstream_req.header (k, v);
}
let (mut tx, rx) = mpsc::channel (1);
spawn ({
let req_id = req_id.clone ();
async move {
let mut bytes_transferred = 0;
loop {
let item = body.next ().await;
if let Some (item) = item {
if let Ok (item) = &item {
bytes_transferred += item.len ();
}
tx.send (item).await?;
}
else {
// Finished
break;
}
}
tracing::trace! ("{} Request body bytes: {}", req_id, bytes_transferred);
Ok::<_, anyhow::Error> (())
}
});
let upstream_resp = upstream_req.body (reqwest::Body::wrap_stream (rx)).send ().await?;
let mut resp = Response::builder ()
.status (upstream_resp.status ());
for (k, v) in upstream_resp.headers () {
resp = resp.header (k, v);
}
let (mut tx, rx) = mpsc::channel (1);
spawn (async move {
let mut body = upstream_resp.bytes_stream ();
let mut bytes_transferred = 0;
loop {
let item = body.next ().await;
if let Some (item) = item {
if let Ok (item) = &item {
bytes_transferred += item.len ();
}
tx.send (item).await?;
}
else {
// Finished
break;
}
}
tracing::trace! ("{} Response body bytes: {}", req_id, bytes_transferred);
Ok::<_, anyhow::Error> (())
});
Ok (resp.body (Body::wrap_stream (rx))?)
}
pub async fn run_proxy (addr: SocketAddr, upstream_authority: String) -> anyhow::Result <()> {
let state = Arc::new (State {
client: Client::builder ().build ()?,
upstream_authority,
});
let make_svc = make_service_fn (|_conn| {
let state = state.clone ();
async {
Ok::<_, String> (service_fn (move |req| {
let state = state.clone ();
handle_all (req, state)
}))
}
});
tracing::info! ("Binding to {}", addr);
Ok (Server::bind (&addr)
.serve (make_svc).await?)
}

View File

@ -1,120 +1,10 @@
use std::{ use std::{
net::SocketAddr, net::SocketAddr,
sync::Arc,
}; };
use hyper::{
Body,
Request,
Response,
Server,
service::{
make_service_fn,
service_fn,
},
StatusCode,
};
use reqwest::Client;
use tokio::{
spawn,
stream::StreamExt,
sync::mpsc,
};
struct State {
client: Client,
}
async fn handle_all (req: Request <Body>, state: Arc <State>)
-> anyhow::Result <Response <Body>>
{
let (head, mut body) = req.into_parts ();
tracing::trace! ("Got URI {}", head.uri);
let mut new_uri = head.uri.clone ().into_parts ();
new_uri.scheme = Some (http::uri::Scheme::HTTPS);
new_uri.authority = Some (http::uri::Authority::from_static ("example.com"));
let new_uri = http::Uri::from_parts (new_uri)?;
tracing::trace! ("Rebuilt URI as {}", new_uri);
let mut upstream_req = state.client.request (head.method, &new_uri.to_string ());
for (k, v) in &head.headers {
// upstream_req = upstream_req.header (k, v);
}
let (mut tx, rx) = mpsc::channel (1);
spawn (async move {
loop {
let item = body.next ().await;
if let Some (item) = item {
tx.send (item).await?;
}
else {
// Finished
break;
}
}
Ok::<_, anyhow::Error> (())
});
let upstream_resp = upstream_req.body (reqwest::Body::wrap_stream (rx)).send ().await?;
let mut resp = Response::builder ()
.status (upstream_resp.status ());
for (k, v) in upstream_resp.headers () {
resp = resp.header (k, v);
}
let (mut tx, rx) = mpsc::channel (1);
spawn (async move {
let mut body = upstream_resp.bytes_stream ();
loop {
let item = body.next ().await;
if let Some (item) = item {
tx.send (item).await?;
}
else {
// Finished
break;
}
}
Ok::<_, anyhow::Error> (())
});
Ok (resp.body (Body::wrap_stream (rx))?)
}
#[tokio::main] #[tokio::main]
async fn main () -> anyhow::Result <()> { async fn main () -> anyhow::Result <()> {
tracing_subscriber::fmt::init (); tracing_subscriber::fmt::init ();
let addr = SocketAddr::from(([0, 0, 0, 0], 11509)); let addr = SocketAddr::from(([0, 0, 0, 0], 11509));
debug_proxy::run_proxy (addr, "127.0.0.1:4000".to_string ()).await
let state = Arc::new (State {
client: Client::builder ().build ()?,
});
let make_svc = make_service_fn (|_conn| {
let state = state.clone ();
async {
Ok::<_, String> (service_fn (move |req| {
let state = state.clone ();
handle_all (req, state)
}))
}
});
tracing::info! ("Binding to {}", addr);
Ok (Server::bind (&addr)
.serve (make_svc).await?)
} }