diff --git a/prototypes/quic_demo/src/bin/quic_demo_relay_server.rs b/prototypes/quic_demo/src/bin/quic_demo_relay_server.rs index a64c31e..0c27c2b 100644 --- a/prototypes/quic_demo/src/bin/quic_demo_relay_server.rs +++ b/prototypes/quic_demo/src/bin/quic_demo_relay_server.rs @@ -16,190 +16,10 @@ async fn main () -> anyhow::Result <()> { // Each new peer QUIC connection gets its own task tokio::spawn (async move { - let quinn::NewConnection { - connection, - mut bi_streams, - .. - } = conn.await?; - - // Everyone who connects must identify themselves with the first - // bi stream - // TODO: Timeout - - let (mut send, mut recv) = bi_streams.next ().await.ok_or_else (|| anyhow::anyhow! ("QUIC client didn't identify itself"))??; - - let mut req_buf = [0u8; 4]; - recv.read_exact (&mut req_buf).await?; - - let peer_type = req_buf [0]; - let peer_id = req_buf [1]; - - match peer_type { - 4 => debug! ("Server-side proxy (P4) connected, ID {}", peer_id), - 2 => debug! ("Client-side proxy (P2) connected, ID {}", peer_id), - _ => bail! ("Unknown QUIC client type"), - } - - let resp_buf = [20u8, 0, 0, 0]; - send.write_all (&resp_buf).await?; - - match peer_type { - 2 => { - let client_id = peer_id; - while let Some (bi_stream) = bi_streams.next ().await { - let (mut client_send, mut client_recv) = bi_stream?; - let relay_state = Arc::clone (&relay_state); - - tokio::spawn (async move { - let mut req_buf = [0u8; 4]; - client_recv.read_exact (&mut req_buf).await?; - - let cmd_type = req_buf [0]; - match cmd_type { - 1 => { - let server_id = req_buf [1]; - - debug! ("P2 {} wants to connect to P4 {}", peer_id, server_id); - - // TODO: Auth checks - - let resp_buf = [0, 0, 0, 0]; - client_send.write_all (&resp_buf).await?; - - { - let relay_state = relay_state.lock ().await; - match relay_state.p4_server_proxies.get (&server_id) { - Some (p4_state) => { - p4_state.req_channel.send (RequestP2ToP4 { - client_send, - client_recv, - client_id, - }).await.map_err (|_| anyhow::anyhow! ("Can't send request to P4 server"))?; - }, - None => warn! ("That server isn't connected"), - } - } - }, - _ => bail! ("Unknown command type from P2"), - } - - debug! ("Request ended for P2"); - - Ok::<_, anyhow::Error> (()) - }); - } - - debug! ("P2 {} disconnected", peer_id); - }, - 4 => { - let (tx, mut rx) = mpsc::channel (2); - - let p4_state = P4State { - req_channel: tx, - }; - - { - let mut relay_state = relay_state.lock ().await; - relay_state.p4_server_proxies.insert (peer_id, p4_state); - } - - while let Some (req) = rx.recv ().await { - let connection = connection.clone (); - - tokio::spawn (async move { - let RequestP2ToP4 { - client_send, - client_recv, - client_id, - } = req; - - debug! ("P4 {} got a request from P2 {}", peer_id, req.client_id); - - let (mut server_send, mut server_recv) = connection.open_bi ().await?; - - let req_buf = [2u8, client_id, 0, 0]; - server_send.write_all (&req_buf).await?; - - let mut resp_buf = [0u8, 0, 0, 0]; - server_recv.read_exact (&mut resp_buf).await?; - - let status_code = resp_buf [0]; - if status_code != 20 { - bail! ("P4 rejected request from {}", client_id); - } - - debug! ("Relaying bytes..."); - - let ptth_conn = PtthNewConnection { - client_send, - client_recv, - server_send, - server_recv, - }.build (); - - ptth_conn.uplink_task.await??; - ptth_conn.downlink_task.await??; - - debug! ("Request ended for P4"); - Ok::<_, anyhow::Error> (()) - }); - } - - debug! ("P4 {} disconnected", peer_id); - }, - _ => bail! ("Unknown QUIC client type"), - } - - debug! ("Peer {} disconnected", peer_id); - Ok::<_, anyhow::Error> (()) + handle_quic_connection (relay_state, conn).await }); } - if false { - debug! ("Waiting for end server to connect"); - - let end_server_conn = incoming.next ().await.ok_or_else (|| anyhow::anyhow! ("No end server connection"))?; - - let end_server_conn = end_server_conn.await?; - - let quinn::NewConnection { - connection: end_server_conn, - .. - } = end_server_conn; - - debug! ("Waiting for client to connect"); - - let client_conn = incoming.next ().await.ok_or_else (|| anyhow::anyhow! ("No client connection"))?; - - let client_conn = client_conn.await?; - - let quinn::NewConnection { - connection: _client_conn, - bi_streams: mut client_incoming_bi_streams, - .. - } = client_conn; - - debug! ("Waiting for client to open bi stream"); - - let (client_send, client_recv) = client_incoming_bi_streams.next ().await.ok_or_else (|| anyhow::anyhow! ("Client didn't open a bi stream"))??; - - debug! ("Opening bi stream to the end server"); - - let (server_send, server_recv) = end_server_conn.open_bi ().await?; - - debug! ("Relaying bytes..."); - - let ptth_conn = PtthNewConnection { - client_send, - client_recv, - server_send, - server_recv, - }.build (); - - ptth_conn.uplink_task.await??; - ptth_conn.downlink_task.await??; - } - Ok (()) } @@ -282,3 +102,145 @@ impl PtthNewConnection { } } } + +async fn handle_quic_connection ( + relay_state: Arc >, + conn: quinn::Connecting, +) -> anyhow::Result <()> { + let quinn::NewConnection { + connection, + mut bi_streams, + .. + } = conn.await?; + + // Everyone who connects must identify themselves with the first + // bi stream + // TODO: Timeout + + let (mut send, mut recv) = bi_streams.next ().await.ok_or_else (|| anyhow::anyhow! ("QUIC client didn't identify itself"))??; + + let mut req_buf = [0u8; 4]; + recv.read_exact (&mut req_buf).await?; + + let peer_type = req_buf [0]; + let peer_id = req_buf [1]; + + match peer_type { + 4 => debug! ("Server-side proxy (P4) connected, ID {}", peer_id), + 2 => debug! ("Client-side proxy (P2) connected, ID {}", peer_id), + _ => bail! ("Unknown QUIC client type"), + } + + let resp_buf = [20u8, 0, 0, 0]; + send.write_all (&resp_buf).await?; + + match peer_type { + 2 => { + let client_id = peer_id; + while let Some (bi_stream) = bi_streams.next ().await { + let (mut client_send, mut client_recv) = bi_stream?; + let relay_state = Arc::clone (&relay_state); + + tokio::spawn (async move { + let mut req_buf = [0u8; 4]; + client_recv.read_exact (&mut req_buf).await?; + + let cmd_type = req_buf [0]; + match cmd_type { + 1 => { + let server_id = req_buf [1]; + + debug! ("P2 {} wants to connect to P4 {}", peer_id, server_id); + + // TODO: Auth checks + + let resp_buf = [0, 0, 0, 0]; + client_send.write_all (&resp_buf).await?; + + { + let relay_state = relay_state.lock ().await; + match relay_state.p4_server_proxies.get (&server_id) { + Some (p4_state) => { + p4_state.req_channel.send (RequestP2ToP4 { + client_send, + client_recv, + client_id, + }).await.map_err (|_| anyhow::anyhow! ("Can't send request to P4 server"))?; + }, + None => warn! ("That server isn't connected"), + } + } + }, + _ => bail! ("Unknown command type from P2"), + } + + debug! ("Request ended for P2"); + + Ok::<_, anyhow::Error> (()) + }); + } + + debug! ("P2 {} disconnected", peer_id); + }, + 4 => { + let (tx, mut rx) = mpsc::channel (2); + + let p4_state = P4State { + req_channel: tx, + }; + + { + let mut relay_state = relay_state.lock ().await; + relay_state.p4_server_proxies.insert (peer_id, p4_state); + } + + while let Some (req) = rx.recv ().await { + let connection = connection.clone (); + + tokio::spawn (async move { + let RequestP2ToP4 { + client_send, + client_recv, + client_id, + } = req; + + debug! ("P4 {} got a request from P2 {}", peer_id, req.client_id); + + let (mut server_send, mut server_recv) = connection.open_bi ().await?; + + let req_buf = [2u8, client_id, 0, 0]; + server_send.write_all (&req_buf).await?; + + let mut resp_buf = [0u8, 0, 0, 0]; + server_recv.read_exact (&mut resp_buf).await?; + + let status_code = resp_buf [0]; + if status_code != 20 { + bail! ("P4 rejected request from {}", client_id); + } + + debug! ("Relaying bytes..."); + + let ptth_conn = PtthNewConnection { + client_send, + client_recv, + server_send, + server_recv, + }.build (); + + ptth_conn.uplink_task.await??; + ptth_conn.downlink_task.await??; + + debug! ("Request ended for P4"); + Ok::<_, anyhow::Error> (()) + }); + } + + debug! ("P4 {} disconnected", peer_id); + }, + _ => bail! ("Unknown QUIC client type"), + } + + debug! ("Peer {} disconnected", peer_id); + Ok::<_, anyhow::Error> (()) +}