♻️ Get rid of more unwraps and panics

main
_ 2020-11-29 21:38:23 +00:00
parent 7bd2450698
commit d6430e39a9
12 changed files with 251 additions and 138 deletions

View File

@ -121,10 +121,11 @@ Client Relay Server
O <----- O O <----- O
| P5 | P5
O <------ O O <------ O
P6/H3 P6/H3 | P7
O -----> O
``` ```
We'll call these steps "P1" through "P6". We'll call these steps "P1" through "P7".
1. The server makes a "listen" request to the relay, 1. The server makes a "listen" request to the relay,
punching out through the server's firewall. punching out through the server's firewall.
@ -138,6 +139,8 @@ to respond.
4. The server processes the request. (P4 == H2) 4. The server processes the request. (P4 == H2)
5. The server packages its response in another request to the relay. 5. The server packages its response in another request to the relay.
6. The relay unwraps the request and forwards it to the client. (P6 == H3) 6. The relay unwraps the request and forwards it to the client. (P6 == H3)
7. When the full response body has been streamed through the relay and to the
client, the relay will respond to the server.
Every step of the normal HTTP process is inverted for the server: Every step of the normal HTTP process is inverted for the server:

View File

@ -4,7 +4,7 @@ use super::*;
fn test_pretty_print_last_seen () { fn test_pretty_print_last_seen () {
use LastSeen::*; use LastSeen::*;
let last_seen = DateTime::parse_from_rfc3339 ("2019-05-29T00:00:00+00:00").unwrap ().with_timezone (&Utc); let last_seen = DateTime::parse_from_rfc3339 ("2019-05-29T00:00:00+00:00").expect ("Test case should be RFC3339").with_timezone (&Utc);
for (input, expected) in vec! [ for (input, expected) in vec! [
("2019-05-28T23:59:59+00:00", Negative), ("2019-05-28T23:59:59+00:00", Negative),
@ -18,7 +18,7 @@ fn test_pretty_print_last_seen () {
("2019-05-30T10:00:00+00:00", Description ("2019-05-29T00:00:00Z".into ())), ("2019-05-30T10:00:00+00:00", Description ("2019-05-29T00:00:00Z".into ())),
("2019-05-31T00:00:00+00:00", Description ("2019-05-29T00:00:00Z".into ())), ("2019-05-31T00:00:00+00:00", Description ("2019-05-29T00:00:00Z".into ())),
].into_iter () { ].into_iter () {
let now = DateTime::parse_from_rfc3339 (input).unwrap ().with_timezone (&Utc); let now = DateTime::parse_from_rfc3339 (input).expect ("Test case should be RFC3339").with_timezone (&Utc);
let actual = pretty_print_last_seen (now, last_seen); let actual = pretty_print_last_seen (now, last_seen);
assert_eq! (actual, expected); assert_eq! (actual, expected);
} }

View File

@ -9,6 +9,7 @@ license = "AGPL-3.0"
[dependencies] [dependencies]
aho-corasick = "0.7.14" aho-corasick = "0.7.14"
anyhow = "1.0.34"
base64 = "0.12.3" base64 = "0.12.3"
blake3 = "0.3.7" blake3 = "0.3.7"
futures = "0.3.7" futures = "0.3.7"

View File

@ -1,7 +1,6 @@
#![warn (clippy::pedantic)] #![warn (clippy::pedantic)]
use std::{ use std::{
error::Error,
net::SocketAddr, net::SocketAddr,
path::PathBuf, path::PathBuf,
sync::Arc, sync::Arc,
@ -91,11 +90,11 @@ pub struct ConfigFile {
} }
#[tokio::main] #[tokio::main]
async fn main () -> Result <(), Box <dyn Error>> { async fn main () -> Result <(), anyhow::Error> {
tracing_subscriber::fmt::init (); tracing_subscriber::fmt::init ();
let path = PathBuf::from ("./config/ptth_server.toml"); let path = PathBuf::from ("./config/ptth_server.toml");
let config_file: ConfigFile = load_toml::load (&path); let config_file: ConfigFile = load_toml::load (&path)?;
info! ("file_server_root: {:?}", config_file.file_server_root); info! ("file_server_root: {:?}", config_file.file_server_root);
let addr = SocketAddr::from(([0, 0, 0, 0], 4000)); let addr = SocketAddr::from(([0, 0, 0, 0], 4000));

View File

@ -1,7 +1,6 @@
#![warn (clippy::pedantic)] #![warn (clippy::pedantic)]
use std::{ use std::{
error::Error,
path::PathBuf, path::PathBuf,
}; };
@ -29,12 +28,12 @@ struct Opt {
print_tripcode: bool, print_tripcode: bool,
} }
fn main () -> Result <(), Box <dyn Error>> { fn main () -> Result <(), anyhow::Error> {
let opt = Opt::from_args (); let opt = Opt::from_args ();
tracing_subscriber::fmt::init (); tracing_subscriber::fmt::init ();
let path = opt.config_path.clone ().unwrap_or_else (|| PathBuf::from ("./config/ptth_server.toml")); let path = opt.config_path.clone ().unwrap_or_else (|| PathBuf::from ("./config/ptth_server.toml"));
let config_file: ConfigFile = load_toml::load (&path); let config_file: ConfigFile = load_toml::load (&path)?;
if opt.print_tripcode { if opt.print_tripcode {
println! (r#""{}" = "{}""#, config_file.name, config_file.tripcode ()); println! (r#""{}" = "{}""#, config_file.name, config_file.tripcode ());

View File

@ -1,19 +1,75 @@
use thiserror::Error; use thiserror::Error;
#[derive (Debug, Error)]
pub enum LoadTomlError {
#[error ("Config file has bad permissions mode, it should be octal 0600")]
ConfigBadPermissions,
#[error ("I/O")]
Io (#[from] std::io::Error),
#[error ("UTF-8")]
Utf8 (#[from] std::string::FromUtf8Error),
#[error ("TOML")]
Toml (#[from] toml::de::Error),
}
#[derive (Debug, Error)] #[derive (Debug, Error)]
pub enum ServerError { pub enum ServerError {
#[error ("Loading TOML")]
LoadToml (#[from] LoadTomlError),
#[error ("Loading Handlebars template file")]
LoadHandlebars (#[from] handlebars::TemplateFileError),
#[error ("API key is too weak, server can't use it")]
WeakApiKey,
#[error ("File server error")] #[error ("File server error")]
FileServer (#[from] super::file_server::errors::FileServerError), FileServer (#[from] super::file_server::errors::FileServerError),
// Hyper stuff
#[error ("Hyper HTTP error")] #[error ("Hyper HTTP error")]
Http (#[from] hyper::http::Error), Http (#[from] hyper::http::Error),
#[error ("Hyper invalid header name")] #[error ("Hyper invalid header name")]
InvalidHeaderName (#[from] hyper::header::InvalidHeaderName), InvalidHeaderName (#[from] hyper::header::InvalidHeaderName),
#[error ("Can't parse wrapped requests")] #[error ("API key invalid")]
ApiKeyInvalid (hyper::header::InvalidHeaderValue),
// MessagePack stuff
#[error ("Can't parse wrapped requests in Step 3")]
CantParseWrappedRequests (rmp_serde::decode::Error), CantParseWrappedRequests (rmp_serde::decode::Error),
#[error ("Can't encode PTTH response as MsgPack in Step 5")]
MessagePackEncodeResponse (rmp_serde::encode::Error),
#[error ("Can't convert Hyper request to PTTH request")] #[error ("Can't convert Hyper request to PTTH request")]
CantConvertHyperToPtth (#[from] ptth_core::http_serde::Error), CantConvertHyperToPtth (#[from] ptth_core::http_serde::Error),
// Reqwest stuff
#[error ("Can't build HTTP client")]
CantBuildHttpClient (reqwest::Error),
#[error ("Can't collect non-200 error response body in Step 3")]
Step3CollectBody (reqwest::Error),
#[error ("Can't collect wrapped requests in Step 3")]
CantCollectWrappedRequests (reqwest::Error),
#[error ("Error in Step 5, sending response to client through relay")]
Step5Responding (reqwest::Error),
#[error ("Error in Step 7, getting response from relay after sending response to client")]
Step7AfterResponse (reqwest::Error),
// UTF-8
#[error ("Step 3 relay response (non-200 OK) was not valid UTF-8")]
Step3ErrorResponseNotUtf8 (std::string::FromUtf8Error),
} }

View File

@ -31,4 +31,7 @@ pub enum FileServerError {
#[error ("Markdown error")] #[error ("Markdown error")]
Markdown (#[from] MarkdownError), Markdown (#[from] MarkdownError),
#[error ("Invalid URI")]
InvalidUri (#[from] hyper::http::uri::InvalidUri),
} }

View File

@ -8,7 +8,6 @@ use std::{
cmp::min, cmp::min,
collections::HashMap, collections::HashMap,
convert::{Infallible, TryFrom, TryInto}, convert::{Infallible, TryFrom, TryInto},
error::Error,
fmt::Debug, fmt::Debug,
io::SeekFrom, io::SeekFrom,
path::{Path, PathBuf}, path::{Path, PathBuf},
@ -451,7 +450,6 @@ struct ServeFileParams {
enum InternalResponse { enum InternalResponse {
Favicon, Favicon,
Forbidden, Forbidden,
InvalidUri,
InvalidQuery, InvalidQuery,
MethodNotAllowed, MethodNotAllowed,
NotFound, NotFound,
@ -465,6 +463,100 @@ enum InternalResponse {
MarkdownPreview (String), MarkdownPreview (String),
} }
fn internal_serve_dir (
path_s: &str,
path: &Path,
dir: tokio::fs::ReadDir,
full_path: PathBuf,
uri: &hyper::Uri
)
-> Result <InternalResponse, FileServerError>
{
let has_trailing_slash = path_s.is_empty () || path_s.ends_with ('/');
if ! has_trailing_slash {
let file_name = path.file_name ().ok_or (FileServerError::NoFileNameRequested)?;
let file_name = file_name.to_str ().ok_or (FileServerError::FilePathNotUtf8)?;
return Ok (InternalResponse::Redirect (format! ("{}/", file_name)));
}
if uri.query ().is_some () {
return Ok (InternalResponse::InvalidQuery);
}
let dir = dir.into ();
Ok (InternalResponse::ServeDir (ServeDirParams {
dir,
path: full_path,
}))
}
async fn internal_serve_file (
mut file: tokio::fs::File,
uri: &hyper::Uri,
send_body: bool,
headers: &HashMap <String, Vec <u8>>
)
-> Result <InternalResponse, FileServerError>
{
use std::os::unix::fs::PermissionsExt;
let file_md = file.metadata ().await.map_err (FileServerError::CantGetFileMetadata)?;
if file_md.permissions ().mode () == super::load_toml::CONFIG_PERMISSIONS_MODE
{
return Ok (InternalResponse::Forbidden);
}
let file_len = file_md.len ();
let range_header = headers.get ("range").and_then (|v| std::str::from_utf8 (v).ok ());
Ok (match check_range (range_header, file_len) {
ParsedRange::RangeNotSatisfiable (file_len) => InternalResponse::RangeNotSatisfiable (file_len),
ParsedRange::Ok (range) => {
if uri.query () == Some ("as_markdown") {
const MAX_BUF_SIZE: u32 = 1_000_000;
if file_len > MAX_BUF_SIZE.into () {
InternalResponse::MarkdownErr (MarkdownError::TooBig)
}
else {
let mut buffer = vec! [0_u8; MAX_BUF_SIZE.try_into ().expect ("Couldn't fit u32 into usize")];
let bytes_read = file.read (&mut buffer).await?;
buffer.truncate (bytes_read);
InternalResponse::MarkdownPreview (render_markdown_styled (&buffer)?)
}
}
else {
let file = file.into ();
InternalResponse::ServeFile (ServeFileParams {
file,
send_body,
range,
range_requested: false,
})
}
},
ParsedRange::PartialContent (range) => {
if uri.query ().is_some () {
InternalResponse::InvalidQuery
}
else {
let file = file.into ();
InternalResponse::ServeFile (ServeFileParams {
file,
send_body,
range,
range_requested: true,
})
}
},
})
}
async fn internal_serve_all ( async fn internal_serve_all (
root: &Path, root: &Path,
method: Method, method: Method,
@ -479,10 +571,7 @@ async fn internal_serve_all (
info! ("Client requested {}", uri); info! ("Client requested {}", uri);
let uri = match hyper::Uri::from_str (uri) { let uri = hyper::Uri::from_str (uri).map_err (FileServerError::InvalidUri)?;
Err (_) => return Ok (InvalidUri),
Ok (x) => x,
};
let send_body = match &method { let send_body = match &method {
Method::Get => true, Method::Get => true,
@ -523,85 +612,26 @@ async fn internal_serve_all (
} }
} }
let has_trailing_slash = path_s.is_empty () || path_s.ends_with ('/'); if let Ok (dir) = read_dir (&full_path).await {
internal_serve_dir (
Ok (if let Ok (dir) = read_dir (&full_path).await { &path_s,
if ! has_trailing_slash { path,
let file_name = path.file_name ().ok_or (FileServerError::NoFileNameRequested)?;
return Ok (Redirect (format! ("{}/", file_name.to_str ().ok_or (FileServerError::FilePathNotUtf8)?)));
}
if uri.query ().is_some () {
return Ok (InvalidQuery);
}
let dir = dir.into ();
ServeDir (ServeDirParams {
dir, dir,
path: full_path, full_path,
}) &uri
)
} }
else if let Ok (mut file) = File::open (&full_path).await { else if let Ok (file) = File::open (&full_path).await {
use std::os::unix::fs::PermissionsExt; internal_serve_file (
file,
let file_md = file.metadata ().await.map_err (FileServerError::CantGetFileMetadata)?; &uri,
if file_md.permissions ().mode () == super::load_toml::CONFIG_PERMISSIONS_MODE send_body,
{ headers
return Ok (Forbidden); ).await
}
let file_len = file_md.len ();
let range_header = headers.get ("range").and_then (|v| std::str::from_utf8 (v).ok ());
match check_range (range_header, file_len) {
ParsedRange::RangeNotSatisfiable (file_len) => RangeNotSatisfiable (file_len),
ParsedRange::Ok (range) => {
if uri.query () == Some ("as_markdown") {
const MAX_BUF_SIZE: u32 = 1_000_000;
if file_len > MAX_BUF_SIZE.into () {
MarkdownErr (MarkdownError::TooBig)
}
else {
let mut buffer = vec! [0_u8; MAX_BUF_SIZE.try_into ().expect ("Couldn't fit u32 into usize")];
let bytes_read = file.read (&mut buffer).await?;
buffer.truncate (bytes_read);
MarkdownPreview (render_markdown_styled (&buffer)?)
}
}
else {
let file = file.into ();
ServeFile (ServeFileParams {
file,
send_body,
range,
range_requested: false,
})
}
},
ParsedRange::PartialContent (range) => {
if uri.query ().is_some () {
InvalidQuery
}
else {
let file = file.into ();
ServeFile (ServeFileParams {
file,
send_body,
range,
range_requested: true,
})
}
},
}
} }
else { else {
NotFound Ok (NotFound)
}) }
} }
#[instrument (level = "debug", skip (handlebars, headers))] #[instrument (level = "debug", skip (handlebars, headers))]
@ -621,7 +651,6 @@ pub async fn serve_all (
Ok (match internal_serve_all (root, method, uri, headers, hidden_path).await? { Ok (match internal_serve_all (root, method, uri, headers, hidden_path).await? {
Favicon => serve_error (StatusCode::NotFound, ""), Favicon => serve_error (StatusCode::NotFound, ""),
Forbidden => serve_error (StatusCode::Forbidden, "403 Forbidden"), Forbidden => serve_error (StatusCode::Forbidden, "403 Forbidden"),
InvalidUri => serve_error (StatusCode::BadRequest, "Invalid URI"),
InvalidQuery => serve_error (StatusCode::BadRequest, "Query is invalid for this object"), InvalidQuery => serve_error (StatusCode::BadRequest, "Query is invalid for this object"),
MethodNotAllowed => serve_error (StatusCode::MethodNotAllowed, "Unsupported method"), MethodNotAllowed => serve_error (StatusCode::MethodNotAllowed, "Unsupported method"),
NotFound => serve_error (StatusCode::NotFound, "404 Not Found"), NotFound => serve_error (StatusCode::NotFound, "404 Not Found"),
@ -656,7 +685,7 @@ pub async fn serve_all (
pub fn load_templates ( pub fn load_templates (
asset_root: &Path asset_root: &Path
) )
-> Result <Handlebars <'static>, Box <dyn Error>> -> Result <Handlebars <'static>, handlebars::TemplateFileError>
{ {
let mut handlebars = Handlebars::new (); let mut handlebars = Handlebars::new ();
handlebars.set_strict_mode (true); handlebars.set_strict_mode (true);

View File

@ -119,7 +119,7 @@ fn file_server () {
use super::*; use super::*;
tracing_subscriber::fmt ().try_init ().ok (); tracing_subscriber::fmt ().try_init ().ok ();
let mut rt = Runtime::new ().unwrap (); let mut rt = Runtime::new ().expect ("Can't create runtime");
rt.block_on (async { rt.block_on (async {
let file_server_root = PathBuf::from ("./"); let file_server_root = PathBuf::from ("./");
@ -127,6 +127,7 @@ fn file_server () {
{ {
use InternalResponse::*; use InternalResponse::*;
use crate::file_server::FileServerError;
let bad_passwords_path = "/files/src/bad_passwords.txt"; let bad_passwords_path = "/files/src/bad_passwords.txt";
@ -148,8 +149,7 @@ fn file_server () {
range_requested: false, range_requested: false,
file: AlwaysEqual::testing_blank (), file: AlwaysEqual::testing_blank (),
})), })),
("/ ", InvalidUri), ] {
].into_iter () {
let resp = internal_serve_all ( let resp = internal_serve_all (
&file_server_root, &file_server_root,
Method::Get, Method::Get,
@ -158,7 +158,24 @@ fn file_server () {
None None
).await; ).await;
assert_eq! (resp.unwrap (), expected); assert_eq! (resp.expect ("This block only tests Ok (_) responses"), expected);
}
for (uri_path, checker) in vec! [
("/ ", |e| match e {
FileServerError::InvalidUri (_) => (),
e => panic! ("Expected InvalidUri, got {:?}", e),
}),
] {
let resp = internal_serve_all (
&file_server_root,
Method::Get,
uri_path,
&headers,
None
).await;
checker (resp.unwrap_err ());
} }
let resp = internal_serve_all ( let resp = internal_serve_all (
@ -171,7 +188,7 @@ fn file_server () {
None None
).await; ).await;
assert_eq! (resp.unwrap (), RangeNotSatisfiable (1_048_576)); assert_eq! (resp.expect ("Should be Ok (_)"), RangeNotSatisfiable (1_048_576));
let resp = internal_serve_all ( let resp = internal_serve_all (
&file_server_root, &file_server_root,
@ -181,7 +198,7 @@ fn file_server () {
None None
).await; ).await;
assert_eq! (resp.unwrap (), ServeFile (ServeFileParams { assert_eq! (resp.expect ("Should be Ok (_)"), ServeFile (ServeFileParams {
send_body: false, send_body: false,
range: 0..1_048_576, range: 0..1_048_576,
range_requested: false, range_requested: false,
@ -210,7 +227,7 @@ fn markdown () {
), ),
].into_iter () { ].into_iter () {
let mut out = String::default (); let mut out = String::default ();
render_markdown (input.as_bytes (), &mut out).unwrap (); render_markdown (input.as_bytes (), &mut out).expect ("Markdown sample failed");
assert_eq! (expected, &out); assert_eq! (expected, &out);
} }
} }

View File

@ -8,7 +8,6 @@
#![allow (clippy::mut_mut)] #![allow (clippy::mut_mut)]
use std::{ use std::{
error::Error,
path::PathBuf, path::PathBuf,
sync::Arc, sync::Arc,
time::Duration, time::Duration,
@ -66,7 +65,7 @@ async fn handle_req_resp <'a> (
) -> Result <(), ServerError> { ) -> Result <(), ServerError> {
//println! ("Step 1"); //println! ("Step 1");
let body = req_resp.bytes ().await.unwrap (); let body = req_resp.bytes ().await.map_err (ServerError::CantCollectWrappedRequests)?;
let wrapped_reqs: Vec <http_serde::WrappedRequest> = match rmp_serde::from_read_ref (&body) let wrapped_reqs: Vec <http_serde::WrappedRequest> = match rmp_serde::from_read_ref (&body)
{ {
Ok (x) => x, Ok (x) => x,
@ -81,6 +80,8 @@ async fn handle_req_resp <'a> (
for wrapped_req in wrapped_reqs { for wrapped_req in wrapped_reqs {
let state = state.clone (); let state = state.clone ();
// These have to detach, so we won't be able to catch the join errors.
tokio::spawn (async move { tokio::spawn (async move {
let (req_id, parts) = (wrapped_req.id, wrapped_req.req); let (req_id, parts) = (wrapped_req.id, wrapped_req.req);
@ -99,11 +100,11 @@ async fn handle_req_resp <'a> (
&parts.uri, &parts.uri,
&parts.headers, &parts.headers,
state.hidden_path.as_deref () state.hidden_path.as_deref ()
).await.unwrap (); ).await?;
let mut resp_req = state.client let mut resp_req = state.client
.post (&format! ("{}/http_response/{}", state.config.relay_url, req_id)) .post (&format! ("{}/http_response/{}", state.config.relay_url, req_id))
.header (ptth_core::PTTH_MAGIC_HEADER, base64::encode (rmp_serde::to_vec (&response.parts).unwrap ())); .header (ptth_core::PTTH_MAGIC_HEADER, base64::encode (rmp_serde::to_vec (&response.parts).map_err (ServerError::MessagePackEncodeResponse)?));
if let Some (length) = response.content_length { if let Some (length) = response.content_length {
resp_req = resp_req.header ("Content-Length", length.to_string ()); resp_req = resp_req.header ("Content-Length", length.to_string ());
@ -112,7 +113,7 @@ async fn handle_req_resp <'a> (
resp_req = resp_req.body (reqwest::Body::wrap_stream (body)); resp_req = resp_req.body (reqwest::Body::wrap_stream (body));
} }
let req = resp_req.build ().unwrap (); let req = resp_req.build ().map_err (ServerError::Step5Responding)?;
debug! ("{:?}", req.headers ()); debug! ("{:?}", req.headers ());
@ -120,7 +121,7 @@ async fn handle_req_resp <'a> (
match state.client.execute (req).await { match state.client.execute (req).await {
Ok (r) => { Ok (r) => {
let status = r.status (); let status = r.status ();
let text = r.text ().await.unwrap (); let text = r.text ().await.map_err (ServerError::Step7AfterResponse)?;
debug! ("{:?} {:?}", status, text); debug! ("{:?} {:?}", status, text);
}, },
Err (e) => { Err (e) => {
@ -133,6 +134,7 @@ async fn handle_req_resp <'a> (
}, },
} }
Ok::<(), ServerError> (())
}); });
} }
@ -166,14 +168,14 @@ pub async fn run_server (
hidden_path: Option <PathBuf>, hidden_path: Option <PathBuf>,
asset_root: Option <PathBuf> asset_root: Option <PathBuf>
) )
-> Result <(), Box <dyn Error>> -> Result <(), ServerError>
{ {
use std::convert::TryInto; use std::convert::TryInto;
let asset_root = asset_root.unwrap_or_else (PathBuf::new); let asset_root = asset_root.unwrap_or_else (PathBuf::new);
if password_is_bad (config_file.api_key.clone ()) { if password_is_bad (config_file.api_key.clone ()) {
panic! ("API key is too weak, server can't use it"); return Err (ServerError::WeakApiKey);
} }
let server_info = file_server::ServerInfo { let server_info = file_server::ServerInfo {
@ -184,13 +186,13 @@ pub async fn run_server (
info! ("Tripcode is {}", config_file.tripcode ()); info! ("Tripcode is {}", config_file.tripcode ());
let mut headers = reqwest::header::HeaderMap::new (); let mut headers = reqwest::header::HeaderMap::new ();
headers.insert ("X-ApiKey", config_file.api_key.try_into ().unwrap ()); headers.insert ("X-ApiKey", config_file.api_key.try_into ().map_err (ServerError::ApiKeyInvalid)?);
let client = Client::builder () let client = Client::builder ()
.default_headers (headers) .default_headers (headers)
.timeout (Duration::from_secs (40)) .timeout (Duration::from_secs (40))
.build ().unwrap (); .build ().map_err (ServerError::CantBuildHttpClient)?;
let handlebars = file_server::load_templates (&asset_root).expect ("Can't load Handlebars templates"); let handlebars = file_server::load_templates (&asset_root)?;
let state = Arc::new (ServerState { let state = Arc::new (ServerState {
config: Config { config: Config {
@ -260,8 +262,8 @@ pub async fn run_server (
} }
else if req_resp.status () != StatusCode::OK { else if req_resp.status () != StatusCode::OK {
error! ("{}", req_resp.status ()); error! ("{}", req_resp.status ());
let body = req_resp.bytes ().await.unwrap (); let body = req_resp.bytes ().await.map_err (ServerError::Step3CollectBody)?;
let body = String::from_utf8 (body.to_vec ()).unwrap (); let body = String::from_utf8 (body.to_vec ()).map_err (ServerError::Step3ErrorResponseNotUtf8)?;
error! ("{}", body); error! ("{}", body);
if backoff_delay != err_backoff_delay { if backoff_delay != err_backoff_delay {
error! ("Non-timeout issue, increasing backoff_delay"); error! ("Non-timeout issue, increasing backoff_delay");

View File

@ -7,19 +7,21 @@ use std::{
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use crate::errors::LoadTomlError;
pub const CONFIG_PERMISSIONS_MODE: u32 = 33152; pub const CONFIG_PERMISSIONS_MODE: u32 = 33152;
fn load_inner < fn load_inner <
T: DeserializeOwned T: DeserializeOwned
> ( > (
mut f: File mut f: File
) -> T { ) -> Result <T, LoadTomlError> {
let mut buffer = vec! [0_u8; 4096]; let mut buffer = vec! [0_u8; 4096];
let bytes_read = f.read (&mut buffer).unwrap_or_else (|_| panic! ("Can't read config")); let bytes_read = f.read (&mut buffer)?;
buffer.truncate (bytes_read); buffer.truncate (bytes_read);
let config_s = String::from_utf8 (buffer).unwrap_or_else (|_| panic! ("Can't parse config as UTF-8")); let config_s = String::from_utf8 (buffer)?;
toml::from_str (&config_s).unwrap_or_else (|e| panic! ("Can't parse config as TOML: {}", e)) Ok (toml::from_str (&config_s)?)
} }
/// For files that contain public-viewable information /// For files that contain public-viewable information
@ -29,8 +31,8 @@ pub fn load_public <
P: AsRef <Path> + Debug P: AsRef <Path> + Debug
> ( > (
config_file_path: P config_file_path: P
) -> T { ) -> Result <T, LoadTomlError> {
let f = File::open (&config_file_path).unwrap_or_else (|_| panic! ("Can't open {:?}", config_file_path)); let f = File::open (&config_file_path)?;
load_inner (f) load_inner (f)
} }
@ -42,13 +44,15 @@ pub fn load <
P: AsRef <Path> + Debug P: AsRef <Path> + Debug
> ( > (
config_file_path: P config_file_path: P
) -> T { ) -> Result <T, LoadTomlError> {
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
let f = File::open (&config_file_path).unwrap_or_else (|_| panic! ("Can't open {:?}", config_file_path)); let f = File::open (&config_file_path)?;
let mode = f.metadata ().unwrap ().permissions ().mode (); let mode = f.metadata ()?.permissions ().mode ();
assert_eq! (mode, CONFIG_PERMISSIONS_MODE, "Config file has bad permissions mode, it should be octal 0600"); if mode != CONFIG_PERMISSIONS_MODE {
return Err (LoadTomlError::ConfigBadPermissions);
}
load_inner (f) load_inner (f)
} }

View File

@ -23,7 +23,7 @@ fn end_to_end () {
// and we don't care if another test already installed a subscriber. // and we don't care if another test already installed a subscriber.
tracing_subscriber::fmt ().try_init ().ok (); tracing_subscriber::fmt ().try_init ().ok ();
let mut rt = Runtime::new ().unwrap (); let mut rt = Runtime::new ().expect ("Can't create runtime for testing");
// Spawn the root task // Spawn the root task
rt.block_on (async { rt.block_on (async {
@ -41,14 +41,14 @@ fn end_to_end () {
}, },
}; };
let config = ptth_relay::config::Config::try_from (config_file).unwrap (); let config = ptth_relay::config::Config::try_from (config_file).expect ("Can't load config");
let relay_state = Arc::new (ptth_relay::RelayState::try_from (config).unwrap ()); let relay_state = Arc::new (ptth_relay::RelayState::try_from (config).expect ("Can't create relay state"));
let relay_state_2 = relay_state.clone (); let relay_state_2 = relay_state.clone ();
let (stop_relay_tx, stop_relay_rx) = oneshot::channel (); let (stop_relay_tx, stop_relay_rx) = oneshot::channel ();
let task_relay = spawn (async move { let task_relay = spawn (async move {
ptth_relay::run_relay (relay_state_2, stop_relay_rx, None).await.unwrap (); ptth_relay::run_relay (relay_state_2, stop_relay_rx, None).await
}); });
assert! (relay_state.list_servers ().await.is_empty ()); assert! (relay_state.list_servers ().await.is_empty ());
@ -65,7 +65,7 @@ fn end_to_end () {
let (stop_server_tx, stop_server_rx) = oneshot::channel (); let (stop_server_tx, stop_server_rx) = oneshot::channel ();
let task_server = { let task_server = {
spawn (async move { spawn (async move {
ptth_server::run_server (config_file, stop_server_rx, None, None).await.unwrap (); ptth_server::run_server (config_file, stop_server_rx, None, None).await
}) })
}; };
@ -77,15 +77,15 @@ fn end_to_end () {
let client = Client::builder () let client = Client::builder ()
.timeout (Duration::from_secs (2)) .timeout (Duration::from_secs (2))
.build ().unwrap (); .build ().expect ("Couldn't build HTTP client");
let resp = client.get (&format! ("{}/frontend/relay_up_check", relay_url)) let resp = client.get (&format! ("{}/frontend/relay_up_check", relay_url))
.send ().await.unwrap ().bytes ().await.unwrap (); .send ().await.expect ("Couldn't check if relay is up").bytes ().await.expect ("Couldn't check if relay is up");
assert_eq! (resp, "Relay is up\n"); assert_eq! (resp, "Relay is up\n");
let resp = client.get (&format! ("{}/frontend/servers/{}/files/COPYING", relay_url, server_name)) let resp = client.get (&format! ("{}/frontend/servers/{}/files/COPYING", relay_url, server_name))
.send ().await.unwrap ().bytes ().await.unwrap (); .send ().await.expect ("Couldn't find license").bytes ().await.expect ("Couldn't find license");
if blake3::hash (&resp) != blake3::Hash::from ([ if blake3::hash (&resp) != blake3::Hash::from ([
0xca, 0x02, 0x92, 0x78, 0xca, 0x02, 0x92, 0x78,
@ -98,28 +98,28 @@ fn end_to_end () {
0x2c, 0x4a, 0xac, 0x1f, 0x2c, 0x4a, 0xac, 0x1f,
0x1a, 0xbb, 0xa8, 0xef, 0x1a, 0xbb, 0xa8, 0xef,
]) { ]) {
panic! ("{}", String::from_utf8 (resp.to_vec ()).unwrap ()); panic! ("{}", String::from_utf8 (resp.to_vec ()).expect ("???"));
} }
// Requesting a file from a server that isn't registered // Requesting a file from a server that isn't registered
// will error out // will error out
let resp = client.get (&format! ("{}/frontend/servers/obviously_this_server_does_not_exist/files/COPYING", relay_url)) let resp = client.get (&format! ("{}/frontend/servers/obviously_this_server_does_not_exist/files/COPYING", relay_url))
.send ().await.unwrap (); .send ().await.expect ("Couldn't send request to bogus server");
assert_eq! (resp.status (), reqwest::StatusCode::NOT_FOUND); assert_eq! (resp.status (), reqwest::StatusCode::NOT_FOUND);
info! ("Shutting down end-to-end test"); info! ("Shutting down end-to-end test");
stop_server_tx.send (()).unwrap (); stop_server_tx.send (()).expect ("Couldn't shut down server");
stop_relay_tx.send (()).unwrap (); stop_relay_tx.send (()).expect ("Couldn't shut down relay");
info! ("Sent stop messages"); info! ("Sent stop messages");
task_relay.await.unwrap (); task_relay.await.expect ("Couldn't join relay").expect ("Relay error");
info! ("Relay stopped"); info! ("Relay stopped");
task_server.await.unwrap (); task_server.await.expect ("Couldn't join server").expect ("Server error");
info! ("Server stopped"); info! ("Server stopped");
}); });
} }