ptth/src/server/mod.rs

167 lines
3.6 KiB
Rust

use std::{
error::Error,
path::PathBuf,
sync::Arc,
time::Duration,
};
use handlebars::Handlebars;
use hyper::{
StatusCode,
};
use reqwest::Client;
use serde::Deserialize;
use tokio::{
time::delay_for,
};
use crate::{
http_serde,
prefix_match,
};
pub mod file_server;
/// State shared (through an `Arc`, without any lock) by the listen loop and
/// every per-request task; it is never mutated after construction.
struct ServerState {
	/// HTTP client used for both listening and posting responses; `main`
	/// configures it to send the `X-ApiKey` header by default.
	client: Client,
	/// Templates handed to `file_server::serve_all`.
	handlebars: Handlebars <'static>,
	/// Relay URL and file server root.
	opt: Opt,
}
/// Builds a simple response carrying status `c` and a copy of `body`'s
/// UTF-8 bytes as the response body. All other fields come from
/// `http_serde::Response::default ()`.
fn status_reply (c: http_serde::StatusCode, body: &str) -> http_serde::Response
{
	let mut resp = http_serde::Response::default ();
	let payload: Vec <u8> = body.as_bytes ().to_owned ();
	resp.status_code (c).body_bytes (payload);
	resp
}
/// Serves one batch of requests returned by the relay's listen endpoint.
///
/// The body of `req_resp` is decoded as a MessagePack
/// `Vec <http_serde::WrappedRequest>`. Each request is handled on its own
/// spawned task. Paths under `/files` go to `file_server::serve_all`, and
/// everything else gets a 404. The result is POSTed back to
/// `{relay_url}/7ZSFUKGV_http_response/{req_id}`: `response.parts` is
/// MessagePack-encoded and base64'd into `crate::PTTH_MAGIC_HEADER`, and the
/// body, if any, is streamed as the POST body.
///
/// A non-200 status, a failed body read, or an undecodable batch is logged
/// to stderr and the batch is dropped. None of these cause a panic.
async fn handle_req_resp (
	state: Arc <ServerState>,
	req_resp: reqwest::Response
) {
	if req_resp.status () != StatusCode::OK {
		eprintln! ("Relay returned {} for listen request, dropping batch", req_resp.status ());
		return;
	}
	
	let body = match req_resp.bytes ().await {
		Ok (b) => b,
		Err (e) => {
			eprintln! ("Err reading request batch from relay: {:?}", e);
			return;
		},
	};
	
	let wrapped_reqs: Vec <http_serde::WrappedRequest> = match rmp_serde::from_read_ref (&body) {
		Ok (x) => x,
		Err (e) => {
			eprintln! ("Err decoding request batch from relay: {:?}", e);
			return;
		},
	};
	
	for wrapped_req in wrapped_reqs {
		let state = state.clone ();
		
		// One task per request, so a slow file response does not hold up
		// the rest of the batch.
		tokio::spawn (async move {
			let (req_id, parts) = (wrapped_req.id, wrapped_req.req);
			
			let response = if let Some (uri) = prefix_match (&parts.uri, "/files") {
				file_server::serve_all (
					&state.handlebars,
					&state.opt.file_server_root,
					parts.method,
					uri,
					&parts.headers
				).await
			}
			else {
				status_reply (http_serde::StatusCode::NotFound, "404 Not Found")
			};
			
			// Encoding our own parts into an in-memory Vec is not expected
			// to fail; if it does, that is a bug in the serde types.
			let parts_bytes = rmp_serde::to_vec (&response.parts)
				.expect ("serializing response parts to MessagePack should not fail");
			
			let mut resp_req = state.client
				.post (&format! ("{}/7ZSFUKGV_http_response/{}", state.opt.relay_url, req_id))
				.header (crate::PTTH_MAGIC_HEADER, base64::encode (parts_bytes));
			
			if let Some (body) = response.body {
				resp_req = resp_req.body (reqwest::Body::wrap_stream (body));
			}
			
			if let Err (e) = resp_req.send ().await {
				eprintln! ("Err: {:?}", e);
			}
		});
	}
}
/// Per-server settings deserialized with serde.
///
/// Keep the field order stable: a derived `Deserialize` also accepts
/// sequence-encoded input, which is matched by position.
#[derive (Deserialize, Default)]
pub struct ConfigFile {
	/// Name used in the relay listen URL (`/7ZSFUKGV_http_listen/{name}`).
	pub name: String,
	/// API key sent to the relay as the `X-ApiKey` header on every request.
	/// Only its base64 BLAKE3 hash (the "tripcode") is ever printed, so
	/// treat the key itself as a secret.
	pub api_key: String,
}
/// Runtime options for the server.
///
/// `Debug` is derived so the options can be logged. Unlike `ConfigFile`,
/// nothing here is secret.
#[derive (Clone, Debug)]
pub struct Opt {
	/// Base URL of the relay. Paths like `/7ZSFUKGV_http_listen/{name}` are
	/// appended directly, so it presumably should have no trailing slash.
	pub relay_url: String,
	/// Root path handed to the file server for requests under `/files`.
	pub file_server_root: PathBuf,
}
pub async fn main (config_file: ConfigFile, opt: Opt)
-> Result <(), Box <dyn Error>>
{
use std::convert::TryInto;
let tripcode = base64::encode (blake3::hash (config_file.api_key.as_bytes ()).as_bytes ());
println! ("Our tripcode is {}", tripcode);
let mut headers = reqwest::header::HeaderMap::new ();
headers.insert ("X-ApiKey", config_file.api_key.try_into ().unwrap ());
let client = Client::builder ()
.default_headers (headers)
.build ().unwrap ();
let handlebars = file_server::load_templates ()?;
let state = Arc::new (ServerState {
opt,
handlebars,
client,
});
let mut backoff_delay = 0;
loop {
if backoff_delay > 0 {
delay_for (Duration::from_millis (backoff_delay)).await;
}
let req_req = state.client.get (&format! ("{}/7ZSFUKGV_http_listen/{}", state.opt.relay_url, config_file.name));
let err_backoff_delay = std::cmp::min (30_000, backoff_delay * 2 + 500);
let req_resp = match req_req.send ().await {
Err (e) => {
eprintln! ("Err: {:?}", e);
backoff_delay = err_backoff_delay;
continue;
},
Ok (r) => {
backoff_delay = 0;
r
},
};
if req_resp.status () != StatusCode::OK {
eprintln! ("{}", req_resp.status ());
eprintln! ("{}", String::from_utf8 (req_resp.bytes ().await.unwrap ().to_vec ()).unwrap ());
backoff_delay = err_backoff_delay;
continue;
}
// Spawn another task for each request so we can
// immediately listen for the next connection
let state = state.clone ();
tokio::spawn (async move {
handle_req_resp (state, req_resp).await;
});
}
}