diff --git a/src/bin/server.rs b/src/bin/server.rs index 5be1ed0..6ee5ae4 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -1,6 +1,7 @@ use std::{ + cmp::{min, max}, collections::*, - convert::Infallible, + convert::{Infallible, TryInto}, error::Error, io::SeekFrom, path::PathBuf, @@ -31,9 +32,11 @@ use ptth::http_serde; fn parse_range_header (range_str: &str) -> (Option , Option ) { lazy_static! { - static ref RE: Regex = Regex::new (r"^(\d+)-(\d+)$").expect ("Couldn't compile regex for Range header"); + static ref RE: Regex = Regex::new (r"^bytes=(\d*)-(\d*)$").expect ("Couldn't compile regex for Range header"); } + println! ("{}", range_str); + let caps = match RE.captures (range_str) { Some (x) => x, _ => return (None, None), @@ -100,7 +103,7 @@ async fn main () -> Result <(), Box > { for (k, v) in parts.headers.iter () { let v = std::str::from_utf8 (v).unwrap (); - println! ("{}: {}", k, v); + //println! ("{}: {}", k, v); if k == "range" { let (start, end) = parse_range_header (v); @@ -114,7 +117,7 @@ async fn main () -> Result <(), Box > { _ => false, }; - println! ("Step 6"); + //println! ("Step 6"); let client = client.clone (); tokio::spawn (async move { let (tx, rx) = channel (2); @@ -129,11 +132,18 @@ async fn main () -> Result <(), Box > { path.push (&uri [1..]); let mut f = File::open (path).await.unwrap (); - if let Some (start) = range_start { - f.seek (SeekFrom::Start (start)).await.unwrap (); - } - let file_md = f.metadata ().await.unwrap (); + let file_len = file_md.len (); + + let start = range_start.unwrap_or (0); + let end = range_end.unwrap_or (file_len); + + let start = max (0, min (start, file_len)); + let end = max (0, min (end, file_len)); + + f.seek (SeekFrom::Start (start)).await.unwrap (); + + println! ("Serving range {}-{}", start, end); if should_send_body { tokio::spawn (async move { @@ -141,12 +151,15 @@ async fn main () -> Result <(), Box > { let mut tx = tx; //let mut bytes_sent = 0; + let mut bytes_left = end - start; loop { let mut buffer = vec! [0u8; 65_536]; - let bytes_read = f.read (&mut buffer).await.unwrap (); + let bytes_read: u64 = f.read (&mut buffer).await.unwrap ().try_into ().unwrap (); - buffer.truncate (bytes_read); + let bytes_read = min (bytes_left, bytes_read); + + buffer.truncate (bytes_read.try_into ().unwrap ()); if bytes_read == 0 { break; @@ -156,6 +169,11 @@ async fn main () -> Result <(), Box > { break; } + bytes_left -= bytes_read; + if bytes_left == 0 { + break; + } + //bytes_sent += bytes_read; //println! ("Sent {} bytes", bytes_sent); @@ -165,14 +183,21 @@ async fn main () -> Result <(), Box > { } let mut headers: HashMap > = Default::default (); - //headers.insert (String::from ("x-its-a-header"), Vec::from (&b"wow"[..])); + headers.insert (String::from ("accept-ranges"), b"bytes".to_vec ()); + + let status_code; if range_start.is_none () && range_end.is_none () { - headers.insert (String::from ("content-length"), file_md.len ().to_string ().into_bytes ()); + headers.insert (String::from ("content-length"), end.to_string ().into_bytes ()); + status_code = http_serde::StatusCode::Ok; + } + else { + headers.insert (String::from ("content-range"), format! ("bytes {}-{}/{}", start, end - 1, end).into_bytes ()); + status_code = http_serde::StatusCode::PartialContent; } let resp_parts = http_serde::ResponseParts { - status_code: http_serde::StatusCode::Ok, + status_code, headers, }; @@ -184,7 +209,7 @@ async fn main () -> Result <(), Box > { resp_req = resp_req.body (body); } - println! ("Step 6"); + //println! ("Step 6"); if let Err (e) = resp_req.send ().await { println! ("Err: {:?}", e); } diff --git a/todo.md b/todo.md index 4f963e2..6e00e07 100644 --- a/todo.md +++ b/todo.md @@ -1 +1,2 @@ -- Byte range request header +- Set up tokens or something we clients can't trivially +impersonate servers