diff --git a/src/client.rs b/src/client.rs index d59f075..6bebbcc 100644 --- a/src/client.rs +++ b/src/client.rs @@ -90,9 +90,9 @@ struct Client { impl Client { fn handle_frame(&self, frame: Bytes) -> Result<()> { match rmp_serde::from_slice(&frame)? { - ToClient::ChatLine { id, line } => { - tracing::info!(?id, ?line); - } + ToClient::ChatLine { id, line } => tracing::info!(?id, ?line), + ToClient::ClientConnected { id } => tracing::info!(?id, "Connected"), + ToClient::ClientDisconnected { id } => tracing::info!(?id, "Disconnected"), } Ok(()) } diff --git a/src/messages.rs b/src/messages.rs index 4370de4..15846a3 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,8 +1,11 @@ +use crate::prelude::*; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Serialize)] pub(crate) enum ToClient { - ChatLine { id: u64, line: String }, + ChatLine { id: Id, line: String }, + ClientConnected { id: Id }, + ClientDisconnected { id: Id }, } #[derive(Deserialize, Serialize)] diff --git a/src/prelude.rs b/src/prelude.rs index 4be8dc7..183fb0d 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -19,3 +19,5 @@ pub use tokio::{ }; // Don't use BytesCodec, it is _nonsense_ pub use tokio_util::codec::{Framed, LengthDelimitedCodec}; + +pub type Id = u64; diff --git a/src/server.rs b/src/server.rs index 65ecf78..5a7110f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,7 +1,5 @@ use crate::prelude::*; -type Id = u64; - pub(crate) struct Args { pub(crate) port: u16, } @@ -59,7 +57,7 @@ impl App { } } - // Try to read into the client's inbox + // Try to read data in match stream.as_mut().poll_next(cx) { Poll::Pending => {} Poll::Ready(None) => clients_to_remove.push(*id), @@ -73,16 +71,17 @@ impl App { // Close out disconnected clients for id in clients_to_remove { + cx.waker().wake_by_ref(); tracing::info!(?id, "Closing client"); self.client_streams.remove(&id); - self.server.handle_client_disconnected(id); + self.server.handle_client_disconnected(id)?; } if let Poll::Ready(result) = self.listener.poll_accept(cx) { let (stream, _addr) = result.context("listener.poll_accept")?; cx.waker().wake_by_ref(); let stream = Framed::new(stream, LengthDelimitedCodec::new()); - let id = self.server.handle_new_client(); + let id = self.server.handle_new_client()?; self.client_streams.insert(id, stream); } @@ -97,8 +96,17 @@ struct Server { } impl Server { - fn handle_client_disconnected(&mut self, id: Id) { + fn broadcast(&mut self, msg: &ToClient) -> Result<()> { + for client in &mut self.clients.values_mut() { + client.handle_outgoing(msg)?; + } + Ok(()) + } + + fn handle_client_disconnected(&mut self, id: Id) -> Result<()> { self.clients.remove(&id); + self.broadcast(&ToClient::ClientDisconnected { id })?; + Ok(()) } fn handle_client_frame(&mut self, id: Id, frame: Bytes) -> Result<()> { @@ -112,13 +120,11 @@ impl Server { }; msg }; - for client in &mut self.clients.values_mut() { - client.handle_outgoing(&msg)?; - } + self.broadcast(&msg)?; Ok(()) } - fn handle_new_client(&mut self) -> Id { + fn handle_new_client(&mut self) -> Result { let id = self.next_client_id; self.next_client_id += 1; let client = Client { @@ -132,7 +138,8 @@ impl Server { "Accepted client" ); self.clients.insert(id, client); - id + self.broadcast(&ToClient::ClientConnected { id })?; + Ok(id) } fn poll_send(&mut self, id: Id) -> Result> {