ptth/crates/ptth_relay/src/key_validity.rs

251 lines
5.7 KiB
Rust

use std::{
convert::TryInto,
fmt::{self, Debug, Formatter},
ops::Deref,
};
use chrono::{DateTime, Duration, Utc};
use serde::{
de::{
self,
Visitor,
},
Deserialize,
Deserializer,
};
/// Newtype around a `blake3::Hash` so that it can be deserialized from a
/// base64 string and shown as base64 in `Debug` output.
#[derive (Clone, Copy, Eq, PartialEq)]
pub struct BlakeHashWrapper (blake3::Hash);

impl BlakeHashWrapper {
	/// Hashes `bytes` with BLAKE3 and wraps the resulting digest.
	pub fn from_key (bytes: &[u8]) -> Self {
		let digest = blake3::hash (bytes);
		BlakeHashWrapper (digest)
	}
	
	/// Returns the raw 32 digest bytes encoded as base64.
	pub fn encode_base64 (&self) -> String {
		let raw = self.0.as_bytes ();
		base64::encode (raw)
	}
}

impl Debug for BlakeHashWrapper {
	/// Prints the digest as base64 instead of the inner hash's own debug form.
	fn fmt (&self, f: &mut Formatter <'_>) -> Result <(), fmt::Error> {
		let encoded = self.encode_base64 ();
		write! (f, "{}", encoded)
	}
}

impl Deref for BlakeHashWrapper {
	type Target = blake3::Hash;
	
	/// Exposes the wrapped hash so `blake3::Hash` methods can be called directly.
	fn deref (&self) -> &blake3::Hash {
		&self.0
	}
}
/// Serde visitor that turns a base64 string into a 32-byte `blake3::Hash`.
struct BlakeHashVisitor;

impl <'de> Visitor <'de> for BlakeHashVisitor {
	type Value = blake3::Hash;
	
	fn expecting (&self, formatter: &mut fmt::Formatter) -> fmt::Result {
		formatter.write_str ("a 32-byte blake3 hash, encoded as base64")
	}
	
	/// Decodes `value` as base64 and requires exactly 32 decoded bytes.
	///
	/// Both failure messages echo the offending input string.
	fn visit_str <E: de::Error> (self, value: &str)
	-> Result <Self::Value, E>
	{
		let decoded = match base64::decode (value) {
			Ok (decoded) => decoded,
			Err (_) => return Err (E::custom (format! ("str is not base64: {}", value))),
		};
		
		let digest_bytes: [u8; 32] = match decoded.as_slice ().try_into () {
			Ok (array) => array,
			Err (_) => return Err (E::custom (format! ("decode base64 is not 32 bytes long: {}", value))),
		};
		
		Ok (blake3::Hash::from (digest_bytes))
	}
}
impl <'de> Deserialize <'de> for BlakeHashWrapper {
fn deserialize <D: Deserializer <'de>> (deserializer: D) -> Result <Self, D::Error> {
Ok (BlakeHashWrapper (deserializer.deserialize_str (BlakeHashVisitor)?))
}
}
/// Marker type for `MaxValidDuration`: a key's validity window may be at
/// most 7 days long.
///
/// Marker types for 30- and 90-day limits are not defined yet; add them
/// next to this one (with a `MaxValidDuration` impl) if they are needed.
pub struct Valid7Days;
/// Upper bound on the length of a key's validity window
/// (`not_after - not_before`), checked by `ScraperKey::is_valid`.
pub trait MaxValidDuration {
	/// The longest permitted validity window.
	fn dur () -> Duration;
}

impl MaxValidDuration for Valid7Days {
	fn dur () -> Duration {
		const MAX_DAYS: i64 = 7;
		Duration::days (MAX_DAYS)
	}
}
/// A deserializable scraper key: the BLAKE3 hash of the secret key plus the
/// time window in which it is accepted.
///
/// `V` caps how long that window may be (see `MaxValidDuration`); it carries
/// no runtime data.
#[derive (Deserialize)]
pub struct ScraperKey <V: MaxValidDuration> {
/// Start of the validity window; `is_valid` reports `ClockIsBehind` before this.
not_before: DateTime <Utc>,
/// End of the validity window; `is_valid` reports `Expired` at or after this.
not_after: DateTime <Utc>,
/// BLAKE3 hash of the expected key; the plaintext key itself is not kept.
hash: BlakeHashWrapper,
/// Type marker for `V`. Defaulted, so it does not have to appear in the
/// serialized input.
#[serde (default)]
_phantom: std::marker::PhantomData <V>,
}
/// Result of `ScraperKey::is_valid`.
#[derive (Copy, Clone, Debug, PartialEq)]
pub enum KeyValidity {
/// The key matches and `now` is inside the validity window.
Valid,
/// The presented key does not match; carries the hash of what was presented.
WrongKey (BlakeHashWrapper),
/// `now` is earlier than `not_before` (presumably a clock skew between the
/// key issuer and the checker — confirm).
ClockIsBehind,
/// `now` is at or after `not_after`.
Expired,
/// The window is longer than allowed; carries the maximum allowed
/// duration, not the key's actual duration.
DurationTooLong (Duration),
/// `not_after` is earlier than `not_before`.
DurationNegative,
}
impl ScraperKey <Valid7Days> {
pub fn new (input: &[u8]) -> Self {
let now = Utc::now ();
Self {
not_before: now,
not_after: now + Duration::days (7),
hash: BlakeHashWrapper::from_key (input),
_phantom: Default::default (),
}
}
}
impl <V: MaxValidDuration> ScraperKey <V> {
	/// Checks the presented secret `input` against this key at time `now`.
	///
	/// Checks run in this order, and the first failure is returned:
	/// 1. the hash of `input` must equal the stored hash (`WrongKey`),
	/// 2. the window must not be inverted (`DurationNegative`),
	/// 3. the window must not be longer than `V::dur ()` (`DurationTooLong`),
	/// 4. `now` must be before `not_after` (`Expired`),
	/// 5. `now` must be at or after `not_before` (`ClockIsBehind`).
	///
	/// Otherwise the key is `Valid`.
	pub fn is_valid (&self, now: DateTime <Utc>, input: &[u8]) -> KeyValidity {
		use KeyValidity::*;
		
		// The key comparison runs first, before anything about the key's
		// window is revealed to the caller. `!=` goes through the derived
		// `PartialEq`, which delegates to `blake3::Hash`'s `PartialEq`; the
		// blake3 crate documents that as a constant-time comparison.
		let input_hash = BlakeHashWrapper::from_key (input);
		if input_hash != self.hash {
			return WrongKey (input_hash);
		}
		
		if self.not_after < self.not_before {
			return DurationNegative;
		}
		
		let max_dur = V::dur ();
		let actual_dur = self.not_after - self.not_before;
		if actual_dur > max_dur {
			// Report the allowed maximum, not the key's actual duration.
			return DurationTooLong (max_dur);
		}
		
		if now >= self.not_after {
			Expired
		}
		else if now < self.not_before {
			ClockIsBehind
		}
		else {
			Valid
		}
	}
}
#[cfg (test)]
mod tests {
	use chrono::{Utc};
	use super::*;
	use KeyValidity::*;
	
	/// The secret every test key is built from.
	const PASSWORD: &[u8] = b"bad_password";
	
	/// Builds a 7-day-capped key for `PASSWORD` whose window runs from
	/// `zero_time + start_days` to `zero_time + end_days`.
	fn key_between (zero_time: DateTime <Utc>, start_days: i64, end_days: i64)
	-> ScraperKey <Valid7Days>
	{
		ScraperKey {
			not_before: zero_time + Duration::days (start_days),
			not_after: zero_time + Duration::days (end_days),
			hash: BlakeHashWrapper::from_key (PASSWORD),
			_phantom: Default::default (),
		}
	}
	
	/// Checks `key` with `PASSWORD` at `zero_time + days` for every
	/// `(days, expected)` pair.
	fn check_all (
		key: &ScraperKey <Valid7Days>,
		zero_time: DateTime <Utc>,
		cases: &[(i64, KeyValidity)],
	) {
		for &(days, expected) in cases {
			let now = zero_time + Duration::days (days);
			assert_eq! (key.is_valid (now, PASSWORD), expected);
		}
	}
	
	#[test]
	fn duration_negative () {
		let zero_time = Utc::now ();
		let key = key_between (zero_time, 1 + 2, 1);
		check_all (&key, zero_time, &[
			(0, DurationNegative),
			(2, DurationNegative),
			(100, DurationNegative),
		]);
	}
	
	#[test]
	fn key_valid_too_long () {
		let zero_time = Utc::now ();
		let key = key_between (zero_time, 1, 1 + 8);
		let too_long = DurationTooLong (Duration::days (7));
		check_all (&key, zero_time, &[
			(0, too_long),
			(2, too_long),
			(100, too_long),
		]);
	}
	
	#[test]
	fn normal_key () {
		let zero_time = Utc::now ();
		let key = key_between (zero_time, 1, 1 + 7);
		check_all (&key, zero_time, &[
			(0, ClockIsBehind),
			(2, Valid),
			(1 + 7, Expired),
			(100, Expired),
		]);
	}
	
	#[test]
	fn wrong_key () {
		let zero_time = Utc::now ();
		let key = key_between (zero_time, 1, 1 + 7);
		for days in &[0i64, 2, 1 + 7, 100] {
			let now = zero_time + Duration::days (*days);
			match key.is_valid (now, b"badder_password") {
				WrongKey (_) => (),
				_ => panic! ("Expected WrongKey here"),
			}
		}
	}
}