// Static file server that can plug into the PTTH reverse server use std::{ cmp::{min, max}, collections::*, convert::{Infallible, TryInto}, error::Error, io::SeekFrom, path::{Path, PathBuf}, }; use handlebars::Handlebars; use tokio::{ fs::{ File, read_dir, ReadDir, }, io::AsyncReadExt, sync::mpsc::{ channel, }, }; use tracing::{ debug, error, info, instrument, }; use regex::Regex; use crate::http_serde; fn parse_range_header (range_str: &str) -> (Option , Option ) { use lazy_static::*; lazy_static! { static ref RE: Regex = Regex::new (r"^bytes=(\d*)-(\d*)$").expect ("Couldn't compile regex for Range header"); } debug! ("{}", range_str); let caps = match RE.captures (range_str) { Some (x) => x, _ => return (None, None), }; let start = caps.get (1).map (|x| x.as_str ()); let end = caps.get (2).map (|x| x.as_str ()); let start = start.map (|x| u64::from_str_radix (x, 10).ok ()).flatten (); let end = end.map (|x| u64::from_str_radix (x, 10).ok ()).flatten (); (start, end) } use serde::Serialize; use tokio::fs::DirEntry; // This could probably be done with borrows, if I owned the // tokio::fs::DirEntry instead of consuming it #[derive (Serialize)] struct TemplateDirEntry { trailing_slash: &'static str, file_name: String, encoded_file_name: String, error: bool, } async fn read_dir_entry (entry: DirEntry) -> TemplateDirEntry { let trailing_slash = match entry.file_type ().await { Ok (t) => if t.is_dir () { "/" } else { "" }, Err (_) => "", }; let file_name = match entry.file_name ().into_string () { Ok (x) => x, Err (_) => return TemplateDirEntry { trailing_slash: "", file_name: "".into (), encoded_file_name: "".into (), error: true, }, }; use percent_encoding::*; let encoded_file_name = utf8_percent_encode (&file_name, CONTROLS).to_string (); TemplateDirEntry { trailing_slash: &trailing_slash, file_name, encoded_file_name, error: false, } } use std::borrow::Cow; #[instrument (level = "debug", skip (handlebars, dir))] async fn serve_dir ( handlebars: &Handlebars <'static>, path: Cow <'_, str>, mut dir: ReadDir ) -> http_serde::Response { let mut entries = vec! []; while let Ok (Some (entry)) = dir.next_entry ().await { entries.push (read_dir_entry (entry).await); } entries.sort_unstable_by (|a, b| a.file_name.partial_cmp (&b.file_name).unwrap ()); #[derive (Serialize)] struct TemplateDirPage <'a> { path: Cow <'a, str>, entries: Vec , } let s = handlebars.render ("file_server_dir", &TemplateDirPage { path, entries, }).unwrap (); let body = s.into_bytes (); let mut resp = http_serde::Response::default (); resp.content_length = Some (body.len ().try_into ().unwrap ()); resp .header ("content-type".to_string (), "text/html".to_string ().into_bytes ()) .header ("content-length".to_string (), body.len ().to_string ().into_bytes ()) .body_bytes (body) ; resp } #[instrument (level = "debug", skip (f))] async fn serve_file ( mut f: File, should_send_body: bool, range_start: Option , range_end: Option ) -> http_serde::Response { let (tx, rx) = channel (2); let body = if should_send_body { Some (rx) } else { None }; let file_md = f.metadata ().await.unwrap (); let file_len = file_md.len (); let start = range_start.unwrap_or (0); let end = range_end.unwrap_or (file_len); let start = max (0, min (start, file_len)); let end = max (0, min (end, file_len)); f.seek (SeekFrom::Start (start)).await.unwrap (); info! ("Serving range {}-{}", start, end); if should_send_body { tokio::spawn (async move { //println! ("Opening file {:?}", path); let mut tx = tx; let mut bytes_sent = 0; let mut bytes_left = end - start; loop { let mut buffer = vec! [0u8; 65_536]; let bytes_read: u64 = f.read (&mut buffer).await.unwrap ().try_into ().unwrap (); let bytes_read = min (bytes_left, bytes_read); buffer.truncate (bytes_read.try_into ().unwrap ()); if bytes_read == 0 { break; } if tx.send (Ok::<_, Infallible> (buffer)).await.is_err () { error! ("Send failed while streaming file ({} bytes sent)", bytes_sent); break; } bytes_left -= bytes_read; if bytes_left == 0 { info! ("Finished"); break; } bytes_sent += bytes_read; debug! ("Sent {} bytes", bytes_sent); //delay_for (Duration::from_millis (50)).await; } }); } let mut response = http_serde::Response::default (); response.header (String::from ("accept-ranges"), b"bytes".to_vec ()); if should_send_body { if range_start.is_none () && range_end.is_none () { response.status_code (http_serde::StatusCode::Ok); response.header (String::from ("content-length"), end.to_string ().into_bytes ()); } else { response.status_code (http_serde::StatusCode::PartialContent); response.header (String::from ("content-range"), format! ("bytes {}-{}/{}", start, end - 1, end).into_bytes ()); } response.content_length = Some (end - start); } if let Some (body) = body { response.body (body); } response } async fn serve_error ( status_code: http_serde::StatusCode, msg: String ) -> http_serde::Response { let mut resp = http_serde::Response::default (); resp.status_code (status_code) .body_bytes (msg.into_bytes ()); resp } #[instrument (level = "debug", skip (handlebars))] pub async fn serve_all ( handlebars: &Handlebars <'static>, root: &Path, method: http_serde::Method, uri: &str, headers: &HashMap >, ) -> http_serde::Response { info! ("Client requested {}", uri); let mut range_start = None; let mut range_end = None; if let Some (v) = headers.get ("range") { let v = std::str::from_utf8 (v).unwrap (); let (start, end) = parse_range_header (v); range_start = start; range_end = end; } let should_send_body = match &method { http_serde::Method::Get => true, http_serde::Method::Head => false, m => { debug! ("Unsupported method {:?}", m); return serve_error (http_serde::StatusCode::MethodNotAllowed, "Unsupported method".into ()).await; } }; use percent_encoding::*; // TODO: There is totally a dir traversal attack in here somewhere let encoded_path = &uri [1..]; let path_s = percent_decode (encoded_path.as_bytes ()).decode_utf8 ().unwrap (); let path = Path::new (&*path_s); let mut full_path = PathBuf::from (root); full_path.push (path); if let Ok (dir) = read_dir (&full_path).await { serve_dir ( handlebars, full_path.to_string_lossy (), dir ).await } else if let Ok (file) = File::open (&full_path).await { serve_file ( file, should_send_body, range_start, range_end ).await } else { serve_error (http_serde::StatusCode::NotFound, "404 Not Found".into ()).await } } pub fn load_templates () -> Result , Box > { let mut handlebars = Handlebars::new (); handlebars.set_strict_mode (true); for (k, v) in vec! [ ("file_server_dir", "file_server_dir.html"), ].into_iter () { handlebars.register_template_file (k, format! ("ptth_handlebars/{}", v))?; } Ok (handlebars) } #[cfg (test)] mod tests { #[test] fn i_hate_paths () { use std::{ ffi::OsStr, path::{Component, Path} }; let mut components = Path::new ("/home/user").components (); assert_eq! (components.next (), Some (Component::RootDir)); assert_eq! (components.next (), Some (Component::Normal (OsStr::new ("home")))); assert_eq! (components.next (), Some (Component::Normal (OsStr::new ("user")))); assert_eq! (components.next (), None); let mut components = Path::new ("./home/user").components (); assert_eq! (components.next (), Some (Component::CurDir)); assert_eq! (components.next (), Some (Component::Normal (OsStr::new ("home")))); assert_eq! (components.next (), Some (Component::Normal (OsStr::new ("user")))); assert_eq! (components.next (), None); let mut components = Path::new (".").components (); assert_eq! (components.next (), Some (Component::CurDir)); assert_eq! (components.next (), None); } }