ptth/crates/debug_proxy/src/lib.rs

203 lines
4.3 KiB
Rust
Raw Normal View History

use std::{
net::SocketAddr,
sync::Arc,
};
use async_trait::async_trait;
use futures_util::StreamExt;
use hyper::{
Body,
Request,
Response,
Server,
body::Bytes,
service::{
make_service_fn,
service_fn,
},
};
use reqwest::Client;
use tokio::{
spawn,
sync::{
mpsc,
oneshot,
},
};
use tokio_stream::wrappers::ReceiverStream;
#[async_trait]
trait ProxyFilter {
async fn request_body (&self, mut body: Body, tx: mpsc::Sender <Result <Bytes, hyper::Error>>) -> anyhow::Result <()>;
}
struct PassthroughFilter {}
#[async_trait]
impl ProxyFilter for PassthroughFilter {
async fn request_body (&self, mut body: Body, tx: mpsc::Sender <Result <Bytes, hyper::Error>>) -> anyhow::Result <()> {
let mut bytes_transferred = 0;
loop {
let item = body.next ().await;
if let Some (item) = item {
if let Ok (item) = &item {
bytes_transferred += item.len ();
}
tx.send (item).await?;
}
else {
// Finished
break;
}
}
let _bytes_transferred = bytes_transferred;
Ok (())
}
}
/// Debugging filter that drains the whole request body but forwards none of it.
///
/// Nothing is ever sent on `_tx`, yet the sender is kept alive until the
/// downstream body has been fully read. The upstream request body therefore
/// stays open, and empty, until this method returns.
struct RequestBodyDropFilter {}

#[async_trait]
impl ProxyFilter for RequestBodyDropFilter {
	async fn request_body (&self, mut body: Body, _tx: mpsc::Sender <Result <Bytes, hyper::Error>>) -> anyhow::Result <()> {
		let mut bytes_dropped = 0;
		
		while let Some (item) = body.next ().await {
			if let Ok (chunk) = &item {
				bytes_dropped += chunk.len ();
			}
			tracing::debug! ("RequestBodyDropFilter dropping chunk");
		}
		
		tracing::trace! ("RequestBodyDropFilter dropped {} request body bytes", bytes_dropped);
		Ok (())
	}
}
/// Server-wide state, shared behind an `Arc` by every request handler.
struct State <PF> {
	/// HTTP client used for every upstream request.
	client: Client,
	
	/// URI authority of the upstream server, parsed with
	/// `http::uri::Authority`. Requests are always sent to it over plain HTTP.
	upstream_authority: String,
	
	/// Request-body filter, cloned into each request's body-forwarding task.
	proxy_filter: Arc <PF>,
}
/// Proxies a single downstream request to the configured upstream.
///
/// The request URI keeps its path and query but is re-targeted at
/// `state.upstream_authority` over plain HTTP, and all request headers are
/// copied as-is. Two spawned tasks stream the bodies, so neither is buffered
/// in full:
/// - the request body passes through `state.proxy_filter` on the first task;
/// - the upstream response body is relayed back on the second.
///
/// Failures inside those background tasks cannot be returned to hyper, so
/// they are logged against a per-request ULID instead of being discarded.
///
/// # Errors
///
/// Returns an error if the upstream URI cannot be built, if the upstream
/// request fails before a response head arrives, or if the downstream
/// response cannot be assembled.
async fn handle_all <PF: 'static + ProxyFilter + Sync + Send> (req: Request <Body>, state: Arc <State <PF>>)
-> anyhow::Result <Response <Body>>
{
	let req_id = rusty_ulid::generate_ulid_string ();
	let (head, body) = req.into_parts ();
	
	tracing::debug! ("{} Got URI {}", req_id, head.uri);
	
	let upstream_authority = state.upstream_authority.clone ();
	
	// Keep the incoming path and query, but point the URI at the upstream.
	let mut new_uri = head.uri.into_parts ();
	new_uri.scheme = Some (http::uri::Scheme::HTTP);
	new_uri.authority = Some (http::uri::Authority::from_maybe_shared (upstream_authority)?);
	let new_uri = http::Uri::from_parts (new_uri)?;
	
	tracing::trace! ("{} Rebuilt URI as {}", req_id, new_uri);
	
	// NOTE(review): headers are forwarded verbatim, including `Host` and any
	// hop-by-hop headers. Confirm that is intended for this debug proxy.
	let mut upstream_req = state.client.request (head.method, &new_uri.to_string ());
	for (k, v) in &head.headers {
		upstream_req = upstream_req.header (k, v);
	}
	
	// Request body: downstream -> proxy filter -> channel -> upstream.
	let (tx, rx) = mpsc::channel (1);
	spawn ({
		let req_id = req_id.clone ();
		let proxy_filter = state.proxy_filter.clone ();
		async move {
			if let Err (e) = proxy_filter.request_body (body, tx).await {
				tracing::debug! ("{} Request body filter stopped early: {}", req_id, e);
			}
		}
	});
	
	let upstream_resp = upstream_req.body (reqwest::Body::wrap_stream (ReceiverStream::new (rx))).send ().await?;
	
	let mut resp = Response::builder ()
	.status (upstream_resp.status ());
	for (k, v) in upstream_resp.headers () {
		resp = resp.header (k, v);
	}
	
	// Response body: upstream -> channel -> downstream.
	let (tx, rx) = mpsc::channel (1);
	spawn (async move {
		let mut body = upstream_resp.bytes_stream ();
		let mut bytes_transferred = 0;
		
		while let Some (item) = body.next ().await {
			if let Ok (chunk) = &item {
				bytes_transferred += chunk.len ();
			}
			if tx.send (item).await.is_err () {
				// The downstream response body was dropped, so nothing is
				// left to relay to.
				tracing::debug! ("{} Downstream went away after {} response body bytes", req_id, bytes_transferred);
				return;
			}
		}
		
		tracing::trace! ("{} Response body bytes: {}", req_id, bytes_transferred);
	});
	
	Ok (resp.body (Body::wrap_stream (ReceiverStream::new (rx)))?)
}
/// Runs the debug proxy until `shutdown_oneshot` resolves.
///
/// Listens on `addr` and forwards every request to `upstream_authority` over
/// plain HTTP. Request bodies currently go through `PassthroughFilter`; the
/// commented-out line below swaps in `RequestBodyDropFilter` to forward none
/// of the request body instead.
///
/// Graceful shutdown begins when `shutdown_oneshot` receives a value or its
/// sender is dropped.
///
/// # Errors
///
/// Returns an error if the HTTP client cannot be built, if `addr` cannot be
/// bound, or if the server fails while running.
pub async fn run_proxy (
	addr: SocketAddr,
	upstream_authority: String,
	shutdown_oneshot: oneshot::Receiver <()>,
) -> anyhow::Result <()>
{
	let filter = PassthroughFilter {};
	// Debugging toggle: drop the request body instead of forwarding it.
	// let filter = RequestBodyDropFilter {};
	
	let state = Arc::new (State {
		client: Client::builder ().build ()?,
		upstream_authority,
		proxy_filter: Arc::new (filter),
	});
	
	let make_svc = make_service_fn (|_conn| {
		let state = state.clone ();
		
		async {
			Ok::<_, std::convert::Infallible> (service_fn (move |req| {
				let state = state.clone ();
				handle_all (req, state)
			}))
		}
	});
	
	// `Server::bind` panics if the address is unavailable. `try_bind` reports
	// that as an error instead, which matches this function's `Result`.
	let server = Server::try_bind (&addr)?
	.serve (make_svc);
	
	// Log the address actually bound; this differs from `addr` when port 0 is requested.
	tracing::info! ("Proxy binding to {}", server.local_addr ());
	
	server.with_graceful_shutdown (async {
		shutdown_oneshot.await.ok ();
	}).await?;
	
	Ok (())
}