203 lines
4.3 KiB
Rust
203 lines
4.3 KiB
Rust
use std::{
|
|
net::SocketAddr,
|
|
sync::Arc,
|
|
};
|
|
|
|
use async_trait::async_trait;
|
|
use futures_util::StreamExt;
|
|
use hyper::{
|
|
Body,
|
|
Request,
|
|
Response,
|
|
Server,
|
|
body::Bytes,
|
|
service::{
|
|
make_service_fn,
|
|
service_fn,
|
|
},
|
|
};
|
|
use reqwest::Client;
|
|
use tokio::{
|
|
spawn,
|
|
sync::{
|
|
mpsc,
|
|
oneshot,
|
|
},
|
|
};
|
|
use tokio_stream::wrappers::ReceiverStream;
|
|
|
|
#[async_trait]
|
|
trait ProxyFilter {
|
|
async fn request_body (&self, mut body: Body, tx: mpsc::Sender <Result <Bytes, hyper::Error>>) -> anyhow::Result <()>;
|
|
}
|
|
|
|
struct PassthroughFilter {}
|
|
|
|
#[async_trait]
|
|
impl ProxyFilter for PassthroughFilter {
|
|
async fn request_body (&self, mut body: Body, tx: mpsc::Sender <Result <Bytes, hyper::Error>>) -> anyhow::Result <()> {
|
|
let mut bytes_transferred = 0;
|
|
|
|
loop {
|
|
let item = body.next ().await;
|
|
|
|
if let Some (item) = item {
|
|
if let Ok (item) = &item {
|
|
bytes_transferred += item.len ();
|
|
}
|
|
tx.send (item).await?;
|
|
}
|
|
else {
|
|
// Finished
|
|
break;
|
|
}
|
|
}
|
|
|
|
let _bytes_transferred = bytes_transferred;
|
|
|
|
Ok (())
|
|
}
|
|
}
|
|
|
|
struct RequestBodyDropFilter {}
|
|
|
|
#[async_trait]
|
|
impl ProxyFilter for RequestBodyDropFilter {
|
|
async fn request_body (&self, mut body: Body, _tx: mpsc::Sender <Result <Bytes, hyper::Error>>) -> anyhow::Result <()> {
|
|
let mut bytes_transferred = 0;
|
|
|
|
loop {
|
|
let item = body.next ().await;
|
|
|
|
if let Some (item) = item {
|
|
if let Ok (item) = &item {
|
|
bytes_transferred += item.len ();
|
|
}
|
|
// tx.send (item).await?;
|
|
tracing::debug! ("RequestBodyDropFilter dropping chunk");
|
|
}
|
|
else {
|
|
// Finished
|
|
break;
|
|
}
|
|
}
|
|
|
|
let _bytes_transferred = bytes_transferred;
|
|
|
|
Ok (())
|
|
}
|
|
}
|
|
|
|
struct State <PF> {
|
|
client: Client,
|
|
upstream_authority: String,
|
|
proxy_filter: Arc <PF>,
|
|
}
|
|
|
|
async fn handle_all <PF: 'static + ProxyFilter + Sync + Send> (req: Request <Body>, state: Arc <State <PF>>)
|
|
-> anyhow::Result <Response <Body>>
|
|
{
|
|
let req_id = rusty_ulid::generate_ulid_string ();
|
|
let (head, body) = req.into_parts ();
|
|
|
|
tracing::debug! ("{} Got URI {}", req_id, head.uri);
|
|
|
|
let upstream_authority = state.upstream_authority.clone ();
|
|
|
|
let mut new_uri = head.uri.clone ().into_parts ();
|
|
new_uri.scheme = Some (http::uri::Scheme::HTTP);
|
|
new_uri.authority = Some (http::uri::Authority::from_maybe_shared (upstream_authority)?);
|
|
let new_uri = http::Uri::from_parts (new_uri)?;
|
|
|
|
tracing::trace! ("{} Rebuilt URI as {}", req_id, new_uri);
|
|
|
|
let mut upstream_req = state.client.request (head.method, &new_uri.to_string ());
|
|
for (k, v) in &head.headers {
|
|
upstream_req = upstream_req.header (k, v);
|
|
}
|
|
|
|
let (tx, rx) = mpsc::channel (1);
|
|
spawn ({
|
|
let _req_id = req_id.clone ();
|
|
let proxy_filter = state.proxy_filter.clone ();
|
|
|
|
async move {
|
|
proxy_filter.request_body (body, tx).await
|
|
}
|
|
});
|
|
|
|
let upstream_resp = upstream_req.body (reqwest::Body::wrap_stream (ReceiverStream::new (rx))).send ().await?;
|
|
|
|
let mut resp = Response::builder ()
|
|
.status (upstream_resp.status ());
|
|
|
|
for (k, v) in upstream_resp.headers () {
|
|
resp = resp.header (k, v);
|
|
}
|
|
|
|
let (tx, rx) = mpsc::channel (1);
|
|
spawn (async move {
|
|
let mut body = upstream_resp.bytes_stream ();
|
|
let mut bytes_transferred = 0;
|
|
|
|
loop {
|
|
let item = body.next ().await;
|
|
|
|
if let Some (item) = item {
|
|
if let Ok (item) = &item {
|
|
bytes_transferred += item.len ();
|
|
}
|
|
tx.send (item).await?;
|
|
}
|
|
else {
|
|
// Finished
|
|
break;
|
|
}
|
|
}
|
|
|
|
tracing::trace! ("{} Response body bytes: {}", req_id, bytes_transferred);
|
|
|
|
Ok::<_, anyhow::Error> (())
|
|
});
|
|
|
|
Ok (resp.body (Body::wrap_stream (ReceiverStream::new (rx)))?)
|
|
}
|
|
|
|
pub async fn run_proxy (
|
|
addr: SocketAddr,
|
|
upstream_authority: String,
|
|
shutdown_oneshot: oneshot::Receiver <()>,
|
|
) -> anyhow::Result <()>
|
|
{
|
|
let filter = PassthroughFilter {};
|
|
// let filter = RequestBodyDropFilter {};
|
|
|
|
let state = Arc::new (State {
|
|
client: Client::builder ().build ()?,
|
|
upstream_authority,
|
|
proxy_filter: Arc::new (filter),
|
|
});
|
|
|
|
let make_svc = make_service_fn (|_conn| {
|
|
let state = state.clone ();
|
|
|
|
async {
|
|
Ok::<_, String> (service_fn (move |req| {
|
|
let state = state.clone ();
|
|
|
|
handle_all (req, state)
|
|
}))
|
|
}
|
|
});
|
|
|
|
let server = Server::bind (&addr)
|
|
.serve (make_svc);
|
|
|
|
tracing::info! ("Proxy binding to {}", addr);
|
|
server.with_graceful_shutdown (async {
|
|
shutdown_oneshot.await.ok ();
|
|
}).await?;
|
|
|
|
Ok (())
|
|
}
|