// ptth/crates/ptth_forwarding_relay/src/main.rs
//
// (source listing: 215 lines, 4.4 KiB, Rust)

use std::{
collections::HashMap,
sync::Arc,
time::Duration,
};
use futures_util::StreamExt;
use hyper::{
Body,
Method,
Request,
Response,
StatusCode,
};
use tokio::{
spawn,
sync::{
RwLock,
mpsc,
},
time::interval,
};
use tokio_stream::wrappers::ReceiverStream;
use tracing::{
info, trace,
};
use tracing_subscriber::{
fmt,
fmt::format::FmtSpan,
EnvFilter,
};
use ulid::Ulid;
/// Bookkeeping shared by every request the relay handles.
///
/// NOTE(review): no handler reads or writes these maps yet. What the keys
/// and values mean (presumably connection IDs and opaque tokens) is not
/// established anywhere in this file — confirm once the handlers use them.
#[derive (Default)]
struct RelayState {
	/// Connections, keyed by a string ID (TODO confirm the key format).
	connections: HashMap <String, ConnectionState>,
	
	/// Opaque string-to-string mapping for the client side (semantics TBD).
	client_opaques: HashMap <String, String>,
	
	/// Opaque string-to-string mapping for the server side (semantics TBD).
	server_opaques: HashMap <String, String>,
}
/*
HTTP has two good pause points:
- The client has uploaded the request body, and the server has not responded yet
- The server has sent the status code and response headers
Because we want to stream everything, there is no point in giving a single
HTTP request-response pair both a streaming request body and a streaming
response body.
To advance the state machine, the first request from each of the client and
the server must not be streaming.
TODO: the rest of this note was cut off mid-sentence ("With all that in mind, the r...").
*/
/// Where a relayed connection is in its setup handshake.
///
/// The `String` payloads are unnamed. Presumably they identify the HTTP
/// streams collected so far, with one more added at each step — TODO confirm
/// once the transitions are implemented.
enum ConnectionState {
	/// The first client connection has arrived. A second one is still
	/// needed to form the upstream.
	WaitForUpstream (String, String),
	
	/// Both client connections have arrived. The server must accept by
	/// sending its downstream.
	WaitForAccept (String, String, String),
	
	/// Handshake finished. Presumably this holds the four streams that
	/// make up an established connection (TODO confirm).
	Connected (String, String, String, String),
}
/// The four individual HTTP streams that make up an established connection.
struct EstablishedConnection {
	/// Request body of the client's `upstream` call.
	client_up: String,
	
	/// Response body of the client's `connect` call.
	client_down: String,
	
	/// Response body of the server's `listen` call.
	server_up: String,
	
	/// Request body of the server's `accept` call.
	server_down: String,
}
/// HTTP front end of the relay.
///
/// Owns a reference-counted handle to the relay state. `serve` clones that
/// handle into every connection and request it handles.
pub struct HttpService {
	state: Arc <RelayState>,
}
impl HttpService {
pub fn new () -> Self {
Self {
state: Arc::new (RelayState::default ()),
}
}
pub async fn serve (&self, port: u16) -> Result <(), hyper::Error> {
use std::net::SocketAddr;
use hyper::{
server::Server,
service::{
make_service_fn,
service_fn,
},
};
let make_svc = make_service_fn (|_conn| {
let state = self.state.clone ();
async {
Ok::<_, String> (service_fn (move |req| {
let state = state.clone ();
Self::handle_all (req, state)
}))
}
});
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let server = Server::bind (&addr)
.serve (make_svc)
;
server.await
}
async fn handle_all (req: Request <Body>, state: Arc <RelayState>)
-> Result <Response <Body>, anyhow::Error>
{
if req.method () == Method::GET {
return Self::handle_gets (req, &*state).await;
}
if req.method () == Method::POST {
return Self::handle_posts (req, &*state).await;
}
Ok::<_, anyhow::Error> (Response::builder ()
.body (Body::from ("hello\n"))?)
}
async fn handle_gets (req: Request <Body>, state: &RelayState)
-> Result <Response <Body>, anyhow::Error>
{
let (mut tx, rx) = mpsc::channel (1);
spawn (async move {
let id = Ulid::new ().to_string ();
trace! ("Downstream {} started", id);
Self::handle_downstream (tx).await.ok ();
trace! ("Downstream {} ended", id);
});
Ok::<_, anyhow::Error> (Response::builder ()
.body (Body::wrap_stream (ReceiverStream::new (rx)))?)
}
async fn handle_posts (req: Request <Body>, state: &RelayState)
-> Result <Response <Body>, anyhow::Error>
{
let id = Ulid::new ().to_string ();
trace! ("Upstream {} started", id);
let mut body = req.into_body ();
while let Some (Ok (item)) = body.next ().await {
println! ("Chunk: {:?}", item);
}
trace! ("Upstream {} ended", id);
Ok::<_, anyhow::Error> (Response::builder ()
.body (Body::from ("hello\n"))?)
}
async fn handle_downstream (tx: mpsc::Sender <anyhow::Result <String>>) -> Result <(), anyhow::Error> {
let mut int = interval (Duration::from_secs (1));
let mut counter = 0u64;
loop {
int.tick ().await;
tx.send (Ok::<_, anyhow::Error> (format! ("Counter: {}\n", counter))).await?;
counter += 1;
}
}
}
#[tokio::main]
async fn main () -> Result <(), anyhow::Error> {
use std::time::Duration;
use tokio::{
spawn,
time::interval,
};
fmt ()
.with_env_filter (EnvFilter::from_default_env ())
.with_span_events (FmtSpan::CLOSE)
.init ()
;
let service = HttpService::new ();
info! ("Starting relay");
Ok (service.serve (4003).await?)
}
#[cfg (test)]
mod tests {
	use super::*;
	
	// The connection state machine has no transitions yet. Until it does,
	// pin down the starting point: a fresh `RelayState` tracks nothing.
	#[test]
	fn state_machine () {
		let state = RelayState::default ();
		assert! (state.connections.is_empty ());
		assert! (state.client_opaques.is_empty ());
		assert! (state.server_opaques.is_empty ());
	}
}