use std::{ net::SocketAddr, sync::Arc, }; use async_trait::async_trait; use futures_util::StreamExt; use hyper::{ Body, Request, Response, Server, body::Bytes, service::{ make_service_fn, service_fn, }, }; use reqwest::Client; use tokio::{ spawn, sync::{ mpsc, oneshot, }, }; use tokio_stream::wrappers::ReceiverStream; #[async_trait] trait ProxyFilter { async fn request_body (&self, mut body: Body, tx: mpsc::Sender >) -> anyhow::Result <()>; } struct PassthroughFilter {} #[async_trait] impl ProxyFilter for PassthroughFilter { async fn request_body (&self, mut body: Body, tx: mpsc::Sender >) -> anyhow::Result <()> { let mut bytes_transferred = 0; loop { let item = body.next ().await; if let Some (item) = item { if let Ok (item) = &item { bytes_transferred += item.len (); } tx.send (item).await?; } else { // Finished break; } } Ok (()) } } struct RequestBodyDropFilter {} #[async_trait] impl ProxyFilter for RequestBodyDropFilter { async fn request_body (&self, mut body: Body, tx: mpsc::Sender >) -> anyhow::Result <()> { let mut bytes_transferred = 0; loop { let item = body.next ().await; if let Some (item) = item { if let Ok (item) = &item { bytes_transferred += item.len (); } // tx.send (item).await?; tracing::debug! ("RequestBodyDropFilter dropping chunk"); } else { // Finished break; } } Ok (()) } } struct State { client: Client, upstream_authority: String, proxy_filter: Arc , } async fn handle_all (req: Request , state: Arc >) -> anyhow::Result > { let req_id = rusty_ulid::generate_ulid_string (); let (head, body) = req.into_parts (); tracing::debug! ("{} Got URI {}", req_id, head.uri); let upstream_authority = state.upstream_authority.clone (); let mut new_uri = head.uri.clone ().into_parts (); new_uri.scheme = Some (http::uri::Scheme::HTTP); new_uri.authority = Some (http::uri::Authority::from_maybe_shared (upstream_authority)?); let new_uri = http::Uri::from_parts (new_uri)?; tracing::trace! ("{} Rebuilt URI as {}", req_id, new_uri); let mut upstream_req = state.client.request (head.method, &new_uri.to_string ()); for (k, v) in &head.headers { upstream_req = upstream_req.header (k, v); } let (tx, rx) = mpsc::channel (1); spawn ({ let _req_id = req_id.clone (); let proxy_filter = state.proxy_filter.clone (); async move { proxy_filter.request_body (body, tx).await } }); let upstream_resp = upstream_req.body (reqwest::Body::wrap_stream (ReceiverStream::new (rx))).send ().await?; let mut resp = Response::builder () .status (upstream_resp.status ()); for (k, v) in upstream_resp.headers () { resp = resp.header (k, v); } let (tx, rx) = mpsc::channel (1); spawn (async move { let mut body = upstream_resp.bytes_stream (); let mut bytes_transferred = 0; loop { let item = body.next ().await; if let Some (item) = item { if let Ok (item) = &item { bytes_transferred += item.len (); } tx.send (item).await?; } else { // Finished break; } } tracing::trace! ("{} Response body bytes: {}", req_id, bytes_transferred); Ok::<_, anyhow::Error> (()) }); Ok (resp.body (Body::wrap_stream (ReceiverStream::new (rx)))?) } pub async fn run_proxy ( addr: SocketAddr, upstream_authority: String, shutdown_oneshot: oneshot::Receiver <()>, ) -> anyhow::Result <()> { let filter = PassthroughFilter {}; // let filter = RequestBodyDropFilter {}; let state = Arc::new (State { client: Client::builder ().build ()?, upstream_authority, proxy_filter: Arc::new (filter), }); let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { Ok::<_, String> (service_fn (move |req| { let state = state.clone (); handle_all (req, state) })) } }); let server = Server::bind (&addr) .serve (make_svc); tracing::info! ("Proxy binding to {}", addr); server.with_graceful_shutdown (async { shutdown_oneshot.await.ok (); }).await?; Ok (()) }