use std::{ collections::*, net::{ Ipv4Addr, SocketAddrV4, }, sync::Arc, }; use hyper::{ Body, Method, Request, Response, Server, StatusCode, service::{ make_service_fn, service_fn, }, }; use tokio::{ net::UdpSocket, sync::RwLock, }; mod ip; fn main () -> Result <(), Error> { let mut args = std::env::args (); let mut bail_unknown = true; let mut last_unknown = None; let mut name = ptth_diceware::passphrase ("_", 3); let mut subcommand_count = 0; let mut subcommand = None; args.next (); while let Some (arg) = args.next () { if arg == "--ignore-unknown" { bail_unknown = false; } if arg == "--name" { name = args.next ().unwrap ().to_string (); } else if arg == "peer" { subcommand = Some (Subcommand::Peer); subcommand_count += 1; } else if arg == "receiver" { subcommand = Some (Subcommand::Receiver); subcommand_count += 1; } else if arg == "sender" { subcommand = Some (Subcommand::Sender); subcommand_count += 1; } else if arg == "spy" { subcommand = Some (Subcommand::Spy); subcommand_count += 1; } else { last_unknown = Some (arg); } } if bail_unknown { if let Some (last_unknown) = last_unknown { eprintln! ("Unknown argument `{}`", last_unknown); return Err (Error::Args); } } if subcommand_count >= 2 { eprintln! ("Detected {} subcommands in arguments", subcommand_count); return Err (Error::Args) } let rt = tokio::runtime::Runtime::new ()?; let params = Params::default (); rt.block_on (async { if let Some (cmd) = subcommand { return match cmd { Subcommand::Peer => peer (params).await, Subcommand::Receiver => receiver (params).await, Subcommand::Sender => sender (params).await, Subcommand::Spy => spy (params), }; } println! ("Name is `{}`", name); Ok::<_, Error> (()) })?; Ok (()) } enum Subcommand { Peer, Receiver, Sender, Spy, } struct Params { multicast_group: (Ipv4Addr, u16), } impl Default for Params { fn default () -> Self { let multicast_group = (Ipv4Addr::new (225, 100, 99, 98), 9041); Self { multicast_group, } } } async fn peer (params: Params) -> Result <(), Error> { let (multicast_addr, multicast_port) = params.multicast_group; let socket = tokio::net::UdpSocket::bind (SocketAddrV4::new (Ipv4Addr::UNSPECIFIED, multicast_port)).await?; socket.join_multicast_v4 (multicast_addr, Ipv4Addr::UNSPECIFIED)?; eprintln! ("Multicast group is {:?}", params.multicast_group); eprintln! ("Local addr is {}", socket.local_addr ()?); let peer = Peer { outbox: Outbox { index: 1000, messages: Default::default (), }.into (), params, socket, }; let state = Arc::new (peer); let make_svc = make_service_fn (|_conn| { let state = state.clone (); async { Ok::<_, String> (service_fn (move |req| { let state = state.clone (); peer_handle_all (req, state) })) } }); let addr = std::net::SocketAddr::from (([127, 0, 0, 1], multicast_port)); let server = Server::bind (&addr) .serve (make_svc); eprintln! ("Local UI on {}", addr); server.await?; Ok (()) } struct Peer { outbox: RwLock , params: Params, socket: UdpSocket, } struct Outbox { index: u32, messages: VecDeque , } struct SentMessage { index: u32, body: Vec , } async fn peer_handle_all (req: Request , state: Arc ) -> Result , Error> { if req.method () == Method::POST { if req.uri () == "/paste" { let body = hyper::body::to_bytes (req.into_body ()).await?; if body.len () > 1024 { let resp = Response::builder () .status (StatusCode::BAD_REQUEST) .body (Body::from ("Message body must be <= 1024 bytes"))?; return Ok (resp); } let body = body.to_vec (); let msg_index; { let mut outbox = state.outbox.write ().await; let msg = SentMessage { index: outbox.index, body: body.clone (), }; msg_index = msg.index; outbox.messages.push_back (msg); if outbox.messages.len () > 10 { outbox.messages.pop_front (); } outbox.index += 1; } state.socket.send_to (&body, state.params.multicast_group).await?; return Ok (Response::new (format! ("Pasted message {}\n", msg_index).into ())); } } Ok (Response::new (":V\n".into ())) } async fn receiver (params: Params) -> Result <(), Error> { let (multicast_addr, multicast_port) = params.multicast_group; let socket = tokio::net::UdpSocket::bind (SocketAddrV4::new (Ipv4Addr::UNSPECIFIED, multicast_port)).await?; socket.join_multicast_v4 (multicast_addr, Ipv4Addr::UNSPECIFIED)?; eprintln! ("Multicast group is {:?}", params.multicast_group); eprintln! ("Local addr is {}", socket.local_addr ()?); loop { let mut buf = vec! [0u8; 2048]; let (bytes_recved, remote_addr) = socket.recv_from (&mut buf).await?; buf.truncate (bytes_recved); println! ("Received {} bytes from {}", bytes_recved, remote_addr); } } async fn sender (params: Params) -> Result <(), Error> { let (multicast_addr, multicast_port) = params.multicast_group; let socket = tokio::net::UdpSocket::bind (SocketAddrV4::new (Ipv4Addr::UNSPECIFIED, 0)).await?; socket.join_multicast_v4 (multicast_addr, Ipv4Addr::UNSPECIFIED)?; eprintln! ("Multicast group is {:?}", params.multicast_group); eprintln! ("Local addr is {}", socket.local_addr ()?); socket.send_to (&[], params.multicast_group).await?; Ok (()) } fn spy (params: Params) -> Result <(), Error> { let (multicast_addr, multicast_port) = params.multicast_group; let socket = match std::net::UdpSocket::bind (SocketAddrV4::new (Ipv4Addr::UNSPECIFIED, multicast_port)) { Ok (x) => x, Err (e) => if e.kind () == std::io::ErrorKind::AddrInUse { eprintln! ("Address in use. You can only run 1 instance of Insecure Chat at a time, even in spy mode."); return Err (Error::AddrInUse); } else { return Err (e.into ()); } }; for bind_addr in ip::get_ips ()? { socket.join_multicast_v4 (&multicast_addr, &bind_addr)?; // eprintln! ("Joined multicast with {}", bind_addr); } eprintln! ("Multicast addr is {}", multicast_addr); eprintln! ("Local addr is {}", socket.local_addr ()?); loop { let mut buf = vec! [0u8; 2048]; eprintln! ("Listening for UDP packets..."); let (bytes_recved, remote_addr) = socket.recv_from (&mut buf)?; buf.truncate (bytes_recved); println! ("Received {} bytes from {}", bytes_recved, remote_addr); } } #[derive (Debug, thiserror::Error)] enum Error { #[error ("Address in use")] AddrInUse, #[error ("CLI args")] Args, #[error (transparent)] Hyper (#[from] hyper::Error), #[error (transparent)] HyperHttp (#[from] hyper::http::Error), #[error (transparent)] Io (#[from] std::io::Error), #[error (transparent)] Ip (#[from] ip::Error), }