diff --git a/src/macros.rs b/src/macros.rs index 92d83ade6ecb93a88f8995bdbc60df3af0a9415d..afa88b8b0a305d34db5acc42a6af8f846a5d0a6c 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -1,13 +1,16 @@ macro_rules! endptlog { ($logger:expr, $($arg:tt)*) => {{ if cfg!(feature = "endpoint_logging") { - let w = $logger.line_writer(); - let _ = writeln!(w, $($arg)*); + if let Some(w) = $logger.line_writer() { + let _ = writeln!(w, $($arg)*); + } } }}; } macro_rules! log { ($logger:expr, $($arg:tt)*) => {{ - let _ = writeln!($logger.line_writer(), $($arg)*); + if let Some(w) = $logger.line_writer() { + let _ = writeln!(w, $($arg)*); + } }}; } diff --git a/src/runtime/logging.rs b/src/runtime/logging.rs index ee2ee375a06d685ac28a6f30c6d416238192b1ff..40fee3e2aa9fc159f3b4e27ce10d9afc8643a3b9 100644 --- a/src/runtime/logging.rs +++ b/src/runtime/logging.rs @@ -12,20 +12,20 @@ impl VecLogger { } ///////////////// impl Logger for DummyLogger { - fn line_writer(&mut self) -> &mut dyn std::io::Write { - self + fn line_writer(&mut self) -> Option<&mut dyn std::io::Write> { + None } } impl Logger for VecLogger { - fn line_writer(&mut self) -> &mut dyn std::io::Write { + fn line_writer(&mut self) -> Option<&mut dyn std::io::Write> { let _ = write!(&mut self.1, "CID({}) at {:?} ", self.0, Instant::now()); - self + Some(self) } } impl Logger for FileLogger { - fn line_writer(&mut self) -> &mut dyn std::io::Write { + fn line_writer(&mut self) -> Option<&mut dyn std::io::Write> { let _ = write!(&mut self.1, "CID({}) at {:?} ", self.0, Instant::now()); - &mut self.1 + Some(&mut self.1) } } /////////////////// @@ -46,11 +46,3 @@ impl std::io::Write for VecLogger { Ok(data.len()) } } -impl std::io::Write for DummyLogger { - fn flush(&mut self) -> Result<(), std::io::Error> { - Ok(()) - } - fn write(&mut self, bytes: &[u8]) -> Result { - Ok(bytes.len()) - } -} diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index f856116712da79ab851b3033074a25c5568e7c12..65bae87c604690c262705ded9e82a98df4279d49 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -24,7 +24,7 @@ pub struct Connector { phased: ConnectorPhased, } pub trait Logger: Debug { - fn line_writer(&mut self) -> &mut dyn std::io::Write; + fn line_writer(&mut self) -> Option<&mut dyn std::io::Write>; } #[derive(Debug)] pub struct VecLogger(ConnectorId, Vec); @@ -154,14 +154,14 @@ struct ProtoComponent { } #[derive(Debug, Clone)] struct NetEndpointSetup { - local_port: PortId, + getter_for_incoming: PortId, sock_addr: SocketAddr, endpoint_polarity: EndpointPolarity, } #[derive(Debug, Clone)] struct UdpEndpointSetup { - local_port: PortId, + getter_for_incoming: PortId, local_addr: SocketAddr, peer_addr: SocketAddr, } @@ -193,7 +193,7 @@ struct SpecVarStream { #[derive(Debug)] struct EndpointManager { // invariants: - // 1. endpoint N is registered READ | WRITE with poller + // 1. net and udp endpoints are registered with poll. Poll token computed with TargetToken::into // 2. Events is empty poll: Poll, events: Events, diff --git a/src/runtime/setup.rs b/src/runtime/setup.rs index 76c48fc1c28f5a64bb63beb42edfdb9489df94d0..682cc3f09741ed4fa5fc4cd9d7a48dbc87b3be20 100644 --- a/src/runtime/setup.rs +++ b/src/runtime/setup.rs @@ -27,30 +27,36 @@ impl Connector { } pub fn new_udp_port( &mut self, - polarity: Polarity, local_addr: SocketAddr, peer_addr: SocketAddr, - ) -> Result { + ) -> Result<[PortId; 2], WrongStateError> { let Self { unphased: cu, phased } = self; match phased { ConnectorPhased::Communication(..) => Err(WrongStateError), ConnectorPhased::Setup(setup) => { let udp_index = setup.udp_endpoint_setups.len(); - let [port_nat, port_udp] = - [cu.id_manager.new_port_id(), cu.id_manager.new_port_id()]; - cu.native_ports.insert(port_nat); - cu.port_info.peers.insert(port_nat, port_udp); - cu.port_info.peers.insert(port_udp, port_nat); - cu.port_info.routes.insert(port_nat, Route::LocalComponent(ComponentId::Native)); - cu.port_info.routes.insert(port_udp, Route::UdpEndpoint { index: udp_index }); - cu.port_info.polarities.insert(port_nat, polarity); - cu.port_info.polarities.insert(port_udp, !polarity); + let mut npid = || cu.id_manager.new_port_id(); + let [nin, nout, uin, uout] = [npid(), npid(), npid(), npid()]; + cu.native_ports.insert(nin); + cu.native_ports.insert(nout); + cu.port_info.polarities.insert(nin, Getter); + cu.port_info.polarities.insert(nout, Putter); + cu.port_info.polarities.insert(uin, Getter); + cu.port_info.polarities.insert(uout, Putter); + cu.port_info.peers.insert(nin, uout); + cu.port_info.peers.insert(nout, uin); + cu.port_info.peers.insert(uin, nout); + cu.port_info.peers.insert(uout, nin); + cu.port_info.routes.insert(nin, Route::LocalComponent(ComponentId::Native)); + cu.port_info.routes.insert(nout, Route::LocalComponent(ComponentId::Native)); + cu.port_info.routes.insert(uin, Route::UdpEndpoint { index: udp_index }); + cu.port_info.routes.insert(uout, Route::UdpEndpoint { index: udp_index }); setup.udp_endpoint_setups.push(UdpEndpointSetup { local_addr, peer_addr, - local_port: port_nat, + getter_for_incoming: nin, }); - Ok(port_nat) + Ok([nout, nin]) } } } @@ -80,7 +86,7 @@ impl Connector { setup.net_endpoint_setups.push(NetEndpointSetup { sock_addr, endpoint_polarity, - local_port, + getter_for_incoming: local_port, }); Ok(local_port) } @@ -143,17 +149,38 @@ fn new_endpoint_manager( deadline: &Option, ) -> Result { //////////////////////////////////////////// - use std::sync::atomic::AtomicBool; + use std::sync::atomic::{AtomicBool, Ordering::SeqCst}; use ConnectError as Ce; const BOTH: Interest = Interest::READABLE.add(Interest::WRITABLE); + const WAKER_PERIOD: Duration = Duration::from_millis(300); + struct WakerState { + continue_signal: AtomicBool, + waker: mio::Waker, + } + impl WakerState { + fn waker_loop(&self) { + while self.continue_signal.load(SeqCst) { + std::thread::sleep(WAKER_PERIOD); + let _ = self.waker.wake(); + } + } + fn waker_stop(&self) { + self.continue_signal.store(false, SeqCst); + // TODO keep waker registered? + } + } struct Todo { + // becomes completed once sent_local_port && recv_peer_port.is_some() + // we send local port if we haven't already and we receive a writable event + // we recv peer port if we haven't already and we receive a readbale event todo_endpoint: TodoEndpoint, endpoint_setup: NetEndpointSetup, sent_local_port: bool, // true <-> I've sent my local port recv_peer_port: Option, // Some(..) <-> I've received my peer's port } struct UdpTodo { - local_port: PortId, + // becomes completed once we receive our first writable event + getter_for_incoming: PortId, sock: UdpSocket, } enum TodoEndpoint { @@ -163,19 +190,14 @@ fn new_endpoint_manager( //////////////////////////////////////////// // 1. Start to construct EndpointManager - const WAKER_PERIOD: Duration = Duration::from_millis(300); - struct WakerState { - continue_signal: AtomicBool, - waker: mio::Waker, - } - let mut waker_state: Option> = None; let mut poll = Poll::new().map_err(|_| Ce::PollInitFailed)?; - let mut events = Events::with_capacity(net_endpoint_setups.len() * 2 + 4); + let mut events = + Events::with_capacity((net_endpoint_setups.len() + udp_endpoint_setups.len()) * 2 + 4); let [mut net_polled_undrained, udp_polled_undrained] = [VecSet::default(), VecSet::default()]; let mut delayed_messages = vec![]; - // 2. create a registered (TcpListener/Endpoint) for passive / active respectively + // 2. Create net/udp TODOs, each already registered with poll let mut net_todos = net_endpoint_setups .iter() .enumerate() @@ -211,19 +233,12 @@ fn new_endpoint_manager( poll.registry() .register(&mut sock, TokenTarget::UdpEndpoint { index }.into(), Interest::WRITABLE) .unwrap(); - Ok(UdpTodo { sock, local_port: endpoint_setup.local_port }) + Ok(UdpTodo { sock, getter_for_incoming: endpoint_setup.getter_for_incoming }) }) .collect::, ConnectError>>()?; - // 3. Using poll to drive progress: - // - accept an incoming connection for each TcpListener (turning them into endpoints too) - // - for each endpoint, send the local PortId - // - for each endpoint, recv the peer's PortId, and - - // all in connect_failed are NOT registered with Poll - let mut connect_failed: HashSet = Default::default(); - // TODO register udps, and all them to incomplete list - + // Initially, (1) no net connections have failed, and (2) all udp and net endpoint setups are incomplete + let mut net_connect_retry_later: HashSet = Default::default(); let mut setup_incomplete: HashSet = { let net_todo_targets_iter = (0..net_todos.len()).map(|index| TokenTarget::NetEndpoint { index }); @@ -231,6 +246,7 @@ fn new_endpoint_manager( (0..udp_todos.len()).map(|index| TokenTarget::UdpEndpoint { index }); net_todo_targets_iter.chain(udp_todo_targets_iter).collect() }; + // progress by reacting to poll events. continue until every endpoint is set up while !setup_incomplete.is_empty() { let remaining = if let Some(deadline) = deadline { Some(deadline.checked_duration_since(Instant::now()).ok_or(Ce::Timeout)?) @@ -241,19 +257,15 @@ fn new_endpoint_manager( for event in events.iter() { let token = event.token(); let token_target = TokenTarget::from(token); - if !setup_incomplete.contains(&token_target) { - // spurious wakeup - continue; - } match token_target { TokenTarget::Waker => { log!( logger, "Notification from waker. connect_failed is {:?}", - connect_failed.iter() + net_connect_retry_later.iter() ); assert!(waker_state.is_some()); - for net_index in connect_failed.drain() { + for net_index in net_connect_retry_later.drain() { let net_todo = &mut net_todos[net_index]; log!( logger, @@ -277,6 +289,10 @@ fn new_endpoint_manager( } } TokenTarget::UdpEndpoint { index } => { + if !setup_incomplete.contains(&token_target) { + // spurious wakeup. this endpoint has already been set up! + continue; + } let udp_todo: &UdpTodo = &udp_todos[index]; if event.is_error() { return Err(Ce::BindFailed(udp_todo.sock.local_addr().unwrap())); @@ -288,17 +304,15 @@ fn new_endpoint_manager( if let TodoEndpoint::Accepting(listener) = &mut net_todo.todo_endpoint { // FIRST try complete this connection match listener.accept() { - Err(e) if would_block(&e) => { - log!(logger, "Spurious wakeup on listener {:?}", index) - } + Err(e) if would_block(&e) => continue, // spurious wakeup Err(_) => { log!(logger, "accept() failure on index {}", index); return Err(Ce::AcceptFailed(listener.local_addr().unwrap())); } Ok((mut stream, peer_addr)) => { - // success! + // successfully accepted the active peer + // reusing the token, but now for the stream and not the listener poll.registry().deregister(listener).unwrap(); - // reusing original token as-is poll.registry().register(&mut stream, token, BOTH).unwrap(); log!( logger, @@ -322,16 +336,16 @@ fn new_endpoint_manager( )); } // this actively-connecting endpoint failed to connect! - if connect_failed.insert(index) { + if net_connect_retry_later.insert(index) { log!( logger, "Connection failed for {:?}. List is {:?}", index, - connect_failed.iter() + net_connect_retry_later.iter() ); poll.registry().deregister(&mut net_endpoint.stream).unwrap(); } else { - // spurious wakeup. + // spurious wakeup. already scheduled to retry connect later continue; } if waker_state.is_none() { @@ -346,30 +360,31 @@ fn new_endpoint_manager( }); let moved_arc = arc.clone(); waker_state = Some(arc); - std::thread::spawn(move || { - while moved_arc - .continue_signal - .load(std::sync::atomic::Ordering::SeqCst) - { - std::thread::sleep(WAKER_PERIOD); - let _ = moved_arc.waker.wake(); - } - }); + std::thread::spawn(move || moved_arc.waker_loop()); } continue; } // event wasn't ERROR - if connect_failed.contains(&index) { - // spurious wakeup + if net_connect_retry_later.contains(&index) { + // spurious wakeup. already scheduled to retry connect later continue; } - let local_polarity = - *port_info.polarities.get(&net_todo.endpoint_setup.local_port).unwrap(); + if !setup_incomplete.contains(&token_target) { + // spurious wakeup. this endpoint has already been completed! + if event.is_readable() { + net_polled_undrained.insert(index); + } + continue; + } + let local_polarity = *port_info + .polarities + .get(&net_todo.endpoint_setup.getter_for_incoming) + .unwrap(); if event.is_writable() && !net_todo.sent_local_port { // can write and didn't send setup msg yet? Do so! let msg = Msg::SetupMsg(SetupMsg::MyPortInfo(MyPortInfo { polarity: local_polarity, - port: net_todo.endpoint_setup.local_port, + port: net_todo.endpoint_setup.getter_for_incoming, })); net_endpoint .send(&msg) @@ -405,19 +420,21 @@ fn new_endpoint_manager( ); if peer_info.polarity == local_polarity { return Err(ConnectError::PortPeerPolarityMismatch( - net_todo.endpoint_setup.local_port, + net_todo.endpoint_setup.getter_for_incoming, )); } net_todo.recv_peer_port = Some(peer_info.port); // 1. finally learned the peer of this port! - port_info - .peers - .insert(net_todo.endpoint_setup.local_port, peer_info.port); + port_info.peers.insert( + net_todo.endpoint_setup.getter_for_incoming, + peer_info.port, + ); // 2. learned the info of this peer port port_info.polarities.insert(peer_info.port, peer_info.polarity); - port_info - .peers - .insert(peer_info.port, net_todo.endpoint_setup.local_port); + port_info.peers.insert( + peer_info.port, + net_todo.endpoint_setup.getter_for_incoming, + ); if let Some(route) = port_info.routes.get(&peer_info.port) { // check just for logging purposes log!( @@ -455,12 +472,9 @@ fn new_endpoint_manager( events.clear(); } log!(logger, "Endpoint setup complete! Cleaning up and building structures"); - if let Some(arc) = waker_state { - log!(logger, "Sending waker the stop signal"); - arc.continue_signal.store(false, std::sync::atomic::Ordering::SeqCst); - // TODO leave the waker registered? + if let Some(ws) = waker_state.take() { + ws.waker_stop(); } - let net_endpoint_exts = net_todos .into_iter() .enumerate() @@ -475,21 +489,21 @@ fn new_endpoint_manager( } _ => unreachable!(), }, - getter_for_incoming: endpoint_setup.local_port, + getter_for_incoming: endpoint_setup.getter_for_incoming, }) .collect(); let udp_endpoint_exts = udp_todos .into_iter() .enumerate() .map(|(index, udp_todo)| { - let UdpTodo { mut sock, local_port } = udp_todo; + let UdpTodo { mut sock, getter_for_incoming } = udp_todo; let token = TokenTarget::UdpEndpoint { index }.into(); poll.registry().reregister(&mut sock, token, Interest::READABLE).unwrap(); UdpEndpointExt { sock, outgoing_payloads: Default::default(), incoming_round_spec_var: None, - getter_for_incoming: local_port, + getter_for_incoming, incoming_payloads: Default::default(), } }) diff --git a/src/runtime/tests.rs b/src/runtime/tests.rs index 8211d6b180ad89472f0924eaa670a80b9910867f..384273edf24b9575018c7f2a0d61cf6cd367db16 100644 --- a/src/runtime/tests.rs +++ b/src/runtime/tests.rs @@ -8,10 +8,11 @@ use reowolf::{ }; use std::{fs::File, net::SocketAddr, path::Path, sync::Arc, time::Duration}; ////////////////////////////////////////// +const MS100: Option = Some(Duration::from_millis(100)); +const MS300: Option = Some(Duration::from_millis(300)); const SEC1: Option = Some(Duration::from_secs(1)); const SEC5: Option = Some(Duration::from_secs(5)); const SEC15: Option = Some(Duration::from_secs(15)); -const MS300: Option = Some(Duration::from_millis(300)); fn next_test_addr() -> SocketAddr { use std::{ net::{Ipv4Addr, SocketAddrV4}, @@ -664,8 +665,8 @@ fn udp_self_connect() { let test_log_path = Path::new("./logs/udp_self_connect"); let sock_addrs = [next_test_addr(), next_test_addr()]; let mut c = file_logged_connector(0, test_log_path); - c.new_udp_port(Putter, sock_addrs[0], sock_addrs[1]).unwrap(); - c.new_udp_port(Getter, sock_addrs[1], sock_addrs[0]).unwrap(); + c.new_udp_port(sock_addrs[0], sock_addrs[1]).unwrap(); + c.new_udp_port(sock_addrs[1], sock_addrs[0]).unwrap(); c.connect(SEC1).unwrap(); } @@ -674,7 +675,7 @@ fn solo_udp_put_success() { let test_log_path = Path::new("./logs/solo_udp_put_success"); let sock_addrs = [next_test_addr(), next_test_addr()]; let mut c = file_logged_connector(0, test_log_path); - let p0 = c.new_udp_port(Putter, sock_addrs[0], sock_addrs[1]).unwrap(); + let [p0, _] = c.new_udp_port(sock_addrs[0], sock_addrs[1]).unwrap(); c.connect(SEC1).unwrap(); c.put(p0, TEST_MSG.clone()).unwrap(); c.sync(MS300).unwrap(); @@ -685,7 +686,7 @@ fn solo_udp_get_fail() { let test_log_path = Path::new("./logs/solo_udp_get_fail"); let sock_addrs = [next_test_addr(), next_test_addr()]; let mut c = file_logged_connector(0, test_log_path); - let p0 = c.new_udp_port(Getter, sock_addrs[0], sock_addrs[1]).unwrap(); + let [_, p0] = c.new_udp_port(sock_addrs[0], sock_addrs[1]).unwrap(); c.connect(SEC1).unwrap(); c.get(p0).unwrap(); c.sync(MS300).unwrap_err(); @@ -701,7 +702,7 @@ fn reowolf_to_udp() { barrier.wait(); // reowolf thread let mut c = file_logged_connector(0, test_log_path); - let p0 = c.new_udp_port(Putter, sock_addrs[0], sock_addrs[1]).unwrap(); + let [p0, _] = c.new_udp_port(sock_addrs[0], sock_addrs[1]).unwrap(); c.connect(SEC1).unwrap(); c.put(p0, TEST_MSG.clone()).unwrap(); c.sync(MS300).unwrap(); @@ -736,10 +737,10 @@ fn udp_to_reowolf() { barrier.wait(); // reowolf thread let mut c = file_logged_connector(0, test_log_path); - let p0 = c.new_udp_port(Getter, sock_addrs[0], sock_addrs[1]).unwrap(); + let [_, p0] = c.new_udp_port(sock_addrs[0], sock_addrs[1]).unwrap(); c.connect(SEC1).unwrap(); c.get(p0).unwrap(); - c.sync(SEC1).unwrap(); + c.sync(SEC5).unwrap(); assert_eq!(c.gotten(p0).unwrap().as_slice(), TEST_MSG_BYTES); barrier.wait(); }); @@ -748,11 +749,50 @@ fn udp_to_reowolf() { // udp thread let udp = std::net::UdpSocket::bind(sock_addrs[1]).unwrap(); udp.connect(sock_addrs[0]).unwrap(); - for _ in 0..5 { + for _ in 0..15 { udp.send(TEST_MSG_BYTES).unwrap(); + std::thread::sleep(MS100.unwrap()); } barrier.wait(); }); }) .unwrap(); } + +#[test] +fn udp_reowolf_swap() { + let test_log_path = Path::new("./logs/udp_reowolf_swap"); + let sock_addrs = [next_test_addr(), next_test_addr()]; + let barrier = std::sync::Barrier::new(2); + scope(|s| { + s.spawn(|_| { + barrier.wait(); + // reowolf thread + let mut c = file_logged_connector(0, test_log_path); + let [p0, p1] = c.new_udp_port(sock_addrs[0], sock_addrs[1]).unwrap(); + c.connect(SEC1).unwrap(); + c.put(p0, TEST_MSG.clone()).unwrap(); + c.get(p1).unwrap(); + c.sync(SEC5).unwrap(); + assert_eq!(c.gotten(p1).unwrap().as_slice(), TEST_MSG_BYTES); + barrier.wait(); + }); + s.spawn(|_| { + barrier.wait(); + // udp thread + let udp = std::net::UdpSocket::bind(sock_addrs[1]).unwrap(); + udp.connect(sock_addrs[0]).unwrap(); + let mut buf = unsafe { + // canonical way to create uninitalized byte buffer + let mut v = Vec::with_capacity(256); + v.set_len(256); + v + }; + udp.send(TEST_MSG_BYTES).unwrap(); + let len = udp.recv(&mut buf).unwrap(); + assert_eq!(TEST_MSG_BYTES, &buf[0..len]); + barrier.wait(); + }); + }) + .unwrap(); +}