ptth/crates/insecure_chat/src/main.rs

376 lines
8.3 KiB
Rust

use std::{
collections::*,
net::{
Ipv4Addr,
SocketAddrV4,
},
sync::Arc,
time::Duration,
};
use hyper::{
Body,
Method,
Request,
Response,
Server,
StatusCode,
service::{
make_service_fn,
service_fn,
},
};
use tokio::{
net::UdpSocket,
sync::RwLock,
};
mod ip;
mod tlv;
/// Entry point: parses command-line arguments and dispatches to a subcommand.
///
/// Recognized arguments:
/// - `--name <NAME>` overrides the randomly generated name
/// - `--ignore-unknown` suppresses the error for unrecognized arguments
/// - exactly one of `peer`, `receiver`, `sender`, `spy` selects a subcommand
///
/// With no subcommand, the name is printed and the program exits.
///
/// # Errors
/// Returns `Error::Args` (after printing a message to stderr) for an unknown
/// argument (unless `--ignore-unknown` is given), a `--name` with no value,
/// or more than one subcommand. Otherwise propagates runtime-creation errors
/// and whatever the chosen subcommand returns.
fn main () -> Result <(), Error>
{
	// Skip the program name
	let mut args = std::env::args ().skip (1);
	let mut bail_unknown = true;
	let mut last_unknown = None;
	let mut name = ptth_diceware::passphrase ("_", 3);
	let mut subcommand_count = 0;
	let mut subcommand = None;
	
	while let Some (arg) = args.next () {
		let selected = match arg.as_str () {
			"--ignore-unknown" => {
				bail_unknown = false;
				None
			},
			"--name" => {
				// `--name` consumes the following argument as its value
				match args.next () {
					Some (value) => name = value,
					None => {
						eprintln! ("`--name` requires a value");
						return Err (Error::Args);
					},
				}
				None
			},
			"peer" => Some (Subcommand::Peer),
			"receiver" => Some (Subcommand::Receiver),
			"sender" => Some (Subcommand::Sender),
			"spy" => Some (Subcommand::Spy),
			_ => {
				last_unknown = Some (arg);
				None
			},
		};
		if selected.is_some () {
			subcommand = selected;
			subcommand_count += 1;
		}
	}
	
	if bail_unknown {
		if let Some (last_unknown) = last_unknown {
			eprintln! ("Unknown argument `{}`", last_unknown);
			return Err (Error::Args);
		}
	}
	if subcommand_count >= 2 {
		eprintln! ("Detected {} subcommands in arguments", subcommand_count);
		return Err (Error::Args);
	}
	
	let params = Params::default ();
	
	// Only the async subcommands get a Tokio runtime. `spy` uses blocking
	// std sockets, so it runs directly on this thread instead of blocking
	// inside `block_on`.
	let runtime = || tokio::runtime::Runtime::new ();
	
	match subcommand {
		Some (Subcommand::Peer) => runtime ()?.block_on (peer (params)),
		Some (Subcommand::Receiver) => runtime ()?.block_on (receiver (params)),
		Some (Subcommand::Sender) => runtime ()?.block_on (sender (params)),
		Some (Subcommand::Spy) => spy (params),
		None => {
			println! ("Name is `{}`", name);
			Ok (())
		},
	}
}
/// Mode selected by the positional subcommand argument on the command line.
#[derive (Clone, Copy, Debug, PartialEq, Eq)]
enum Subcommand {
	/// `peer`: run a chat peer with a local HTTP UI.
	Peer,
	/// `receiver`: print the size and sender of each multicast datagram (async).
	Receiver,
	/// `sender`: send one empty datagram to the multicast group.
	Sender,
	/// `spy`: like `receiver`, but using blocking sockets joined on every address from `ip::get_ips`.
	Spy,
}
/// Network settings shared by all subcommands.
#[derive (Clone, Debug, PartialEq, Eq)]
struct Params {
	/// Multicast group IPv4 address and UDP port.
	///
	/// The `peer` subcommand also serves its local HTTP UI on this same port
	/// number (TCP, bound to 127.0.0.1).
	multicast_group: (Ipv4Addr, u16),
}

impl Default for Params {
	/// Multicast group `225.100.99.98`, port `9041`.
	fn default () -> Self {
		Self {
			multicast_group: (Ipv4Addr::new (225, 100, 99, 98), 9041),
		}
	}
}
async fn peer (params: Params) -> Result <(), Error>
{
use rand::Rng;
let mut id = [0];
rand::thread_rng ().try_fill (&mut id).or (Err (Error::Rand))?;
let (multicast_addr, multicast_port) = params.multicast_group;
let socket = tokio::net::UdpSocket::bind (SocketAddrV4::new (Ipv4Addr::UNSPECIFIED, multicast_port)).await?;
socket.join_multicast_v4 (multicast_addr, Ipv4Addr::UNSPECIFIED)?;
eprintln! ("Multicast group is {:?}", params.multicast_group);
eprintln! ("Local addr is {}", socket.local_addr ()?);
let peer = Peer {
id: id [0],
outbox: Outbox {
index: 1000,
messages: Default::default (),
}.into (),
params,
socket,
};
eprintln! ("Random peer ID is {}", peer.id);
let state = Arc::new (peer);
{
let state = Arc::clone (&state);
tokio::spawn (async move {
let mut interval = tokio::time::interval (Duration::from_secs (25));
interval.set_missed_tick_behavior (tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick ().await;
state.send_multicast (&tlv::Message::IAmOnline { peer_id: state.id }).await.ok ();
}
});
}
{
let state = Arc::clone (&state);
tokio::spawn (async move {
loop {
let mut buf = vec! [0u8; 2048];
let (bytes_recved, addr) = match state.socket.recv_from (&mut buf).await
{
Err (_) => {
tokio::time::sleep (Duration::from_secs (10)).await;
continue;
},
Ok (x) => x,
};
let buf = &buf [0..bytes_recved];
let msg = match tlv::decode (buf) {
Err (_) => {
eprintln! ("ZAT4ERXR Couldn't decode message");
continue;
},
Ok (x) => x,
};
println! ("Received {:?}", msg);
}
});
}
let make_svc = make_service_fn (|_conn| {
let state = Arc::clone (&state);
async {
Ok::<_, String> (service_fn (move |req| {
let state = state.clone ();
peer_handle_all (req, state)
}))
}
});
let addr = std::net::SocketAddr::from (([127, 0, 0, 1], multicast_port));
let server = Server::bind (&addr)
.serve (make_svc);
eprintln! ("Local UI on {}", addr);
server.await?;
Ok (())
}
/// State of a running `peer`, shared through an `Arc` between the
/// background tasks and the HTTP handler.
struct Peer {
	/// ID picked at random on startup; sent as `peer_id` in outgoing messages.
	id: u32,
	
	/// Messages pasted through the local UI; the lock allows concurrent
	/// access from HTTP handler tasks.
	outbox: RwLock <Outbox>,
	
	/// Network settings; `params.multicast_group` is the send destination.
	params: Params,
	
	/// Socket bound to the multicast port, used both to send to and to
	/// receive from the group.
	socket: UdpSocket,
}
impl Peer {
	/// Encodes `msg` with `tlv::encode` and sends the result as a single
	/// datagram to `self.params.multicast_group`.
	///
	/// # Errors
	/// Returns an error if encoding fails or if the socket send fails.
	async fn send_multicast (&self, msg: &tlv::Message) -> Result <(), Error>
	{
		// `msg` is already a reference; no extra borrow needed
		let datagram = tlv::encode (msg)?;
		self.socket.send_to (&datagram, self.params.multicast_group).await?;
		Ok (())
	}
}
/// Recent messages this peer has pasted, plus the index that the next
/// pasted message will receive.
struct Outbox {
	/// Index to assign to the next message; incremented after each paste.
	index: u32,
	/// Oldest-first history; `peer_handle_all` trims it to the newest 10.
	messages: VecDeque <SentMessage>,
}

/// One message pasted through the local UI.
struct SentMessage {
	/// Index assigned from `Outbox::index` when the message was pasted.
	index: u32,
	/// Raw message bytes as received in the HTTP request body.
	body: Vec <u8>,
}
async fn peer_handle_all (req: Request <Body>, state: Arc <Peer>)
-> Result <Response <Body>, Error>
{
if req.method () == Method::POST {
if req.uri () == "/paste" {
let body = hyper::body::to_bytes (req.into_body ()).await?;
if body.len () > 1024 {
let resp = Response::builder ()
.status (StatusCode::BAD_REQUEST)
.body (Body::from ("Message body must be <= 1024 bytes"))?;
return Ok (resp);
}
let body = body.to_vec ();
let msg_index;
{
let mut outbox = state.outbox.write ().await;
let msg = SentMessage {
index: outbox.index,
body: body.clone (),
};
msg_index = msg.index;
outbox.messages.push_back (msg);
if outbox.messages.len () > 10 {
outbox.messages.pop_front ();
}
outbox.index += 1;
}
match state.send_multicast (&tlv::Message::IHaveMessage {
peer_id: state.id,
msg_index,
body,
}).await {
Ok (_) => (),
Err (_) => return Ok (
Response::builder ()
.status (StatusCode::BAD_REQUEST)
.body (Body::from ("Can't encode message"))?
),
}
return Ok (Response::new (format! ("Pasted message {}\n", msg_index).into ()));
}
}
Ok (Response::new (":V\n".into ()))
}
/// Runs the `receiver` subcommand: binds the multicast port, joins the
/// group, and prints the size and sender of every datagram received, forever.
///
/// # Errors
/// Returns on the first bind, join, or receive error.
async fn receiver (params: Params) -> Result <(), Error>
{
	let (multicast_addr, multicast_port) = params.multicast_group;
	let socket = tokio::net::UdpSocket::bind (SocketAddrV4::new (Ipv4Addr::UNSPECIFIED, multicast_port)).await?;
	socket.join_multicast_v4 (multicast_addr, Ipv4Addr::UNSPECIFIED)?;
	eprintln! ("Multicast group is {:?}", params.multicast_group);
	eprintln! ("Local addr is {}", socket.local_addr ()?);
	
	// Only the datagram length is reported, so one buffer serves every receive
	let mut buf = vec! [0u8; 2048];
	loop {
		let (bytes_recved, remote_addr) = socket.recv_from (&mut buf).await?;
		println! ("Received {} bytes from {}", bytes_recved, remote_addr);
	}
}
/// Runs the `sender` subcommand: binds an ephemeral local port, joins the
/// multicast group, sends one empty datagram to the group, and returns.
///
/// # Errors
/// Returns any bind, join, or send error.
async fn sender (params: Params) -> Result <(), Error>
{
	// The port is unused here: we bind port 0 (ephemeral), and `send_to`
	// takes the whole group tuple.
	let (multicast_addr, _) = params.multicast_group;
	let socket = tokio::net::UdpSocket::bind (SocketAddrV4::new (Ipv4Addr::UNSPECIFIED, 0)).await?;
	socket.join_multicast_v4 (multicast_addr, Ipv4Addr::UNSPECIFIED)?;
	eprintln! ("Multicast group is {:?}", params.multicast_group);
	eprintln! ("Local addr is {}", socket.local_addr ()?);
	socket.send_to (&[], params.multicast_group).await?;
	Ok (())
}
/// Runs the `spy` subcommand.
///
/// Binds the multicast port with a blocking `std::net::UdpSocket`, joins
/// the group once per address returned by `ip::get_ips`, and then prints
/// the size and sender of every datagram received, forever. It blocks the
/// calling thread.
///
/// # Errors
/// Returns `Error::AddrInUse` (after printing a hint) if the port is already
/// bound. Otherwise propagates socket and `ip::get_ips` errors.
fn spy (params: Params) -> Result <(), Error>
{
	let (multicast_addr, multicast_port) = params.multicast_group;
	let socket = match std::net::UdpSocket::bind (SocketAddrV4::new (Ipv4Addr::UNSPECIFIED, multicast_port)) {
		Ok (x) => x,
		Err (e) if e.kind () == std::io::ErrorKind::AddrInUse => {
			eprintln! ("Address in use. You can only run 1 instance of Insecure Chat at a time, even in spy mode.");
			return Err (Error::AddrInUse);
		},
		Err (e) => return Err (e.into ()),
	};
	
	// Join via each address from `ip::get_ips` (presumably the local
	// interface addresses; see that module)
	for bind_addr in ip::get_ips ()? {
		socket.join_multicast_v4 (&multicast_addr, &bind_addr)?;
	}
	eprintln! ("Multicast addr is {}", multicast_addr);
	eprintln! ("Local addr is {}", socket.local_addr ()?);
	
	// Only the datagram length is reported, so one buffer serves every receive
	let mut buf = vec! [0u8; 2048];
	loop {
		eprintln! ("Listening for UDP packets...");
		let (bytes_recved, remote_addr) = socket.recv_from (&mut buf)?;
		println! ("Received {} bytes from {}", bytes_recved, remote_addr);
	}
}
/// Every error that `main` and the subcommands can return.
#[derive (Debug, thiserror::Error)]
enum Error {
	/// The multicast port is already bound (detected by `spy`).
	#[error ("Address in use")]
	AddrInUse,
	
	/// Invalid command-line arguments; the specific problem is printed to
	/// stderr before this is returned.
	#[error ("CLI args")]
	Args,
	
	/// Error from the hyper HTTP server or request body.
	#[error (transparent)]
	Hyper (#[from] hyper::Error),
	
	/// Error building an HTTP response.
	#[error (transparent)]
	HyperHttp (#[from] hyper::http::Error),
	
	/// Socket or runtime I/O error.
	#[error (transparent)]
	Io (#[from] std::io::Error),
	
	/// Error from the local `ip` module.
	#[error (transparent)]
	Ip (#[from] ip::Error),
	
	/// The random number generator failed to produce a peer ID.
	#[error ("Randomness")]
	Rand,
	
	/// Error from the local `tlv` message codec.
	#[error (transparent)]
	Tlv (#[from] tlv::Error),
}