diff --git a/Cargo.toml b/Cargo.toml index c28bcfae3d822844be409bf9c946258b4ff18698..28e7ab867343cca968fe958faddcabc70318fd59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ serde = { version = "1.0.112", features = ["derive"] } getrandom = "0.1.14" # tiny crate. used to guess controller-id take_mut = "0.2.2" indexmap = "1.3.0" # hashsets/hashmaps with efficient arbitrary element removal +replace_with = "0.1.5" # network integer-encoding = "1.0.7" diff --git a/src/runtime/endpoints.rs b/src/runtime/endpoints.rs index 57b49ba006be144f813c2056ffb1f95a07213a5f..8a88cb1e48f1ec247572a878ef91d45703e2b1f2 100644 --- a/src/runtime/endpoints.rs +++ b/src/runtime/endpoints.rs @@ -12,10 +12,6 @@ enum TryRecyAnyError { } ///////////////////// - -fn would_block(err: &std::io::Error) -> bool { - err.kind() == std::io::ErrorKind::WouldBlock -} impl Endpoint { pub(super) fn try_recv( &mut self, diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 3a098dcca503e27a53d98cbb55a213603094be10..2a3984726b3efe32bb18e86d6164fa22d6d616d4 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -202,6 +202,9 @@ pub struct SyncProtoContext<'a> { inbox: &'a HashMap, } //////////////// +pub fn would_block(err: &std::io::Error) -> bool { + err.kind() == std::io::ErrorKind::WouldBlock +} impl VecSet { fn iter(&self) -> std::slice::Iter { self.vec.iter() diff --git a/src/runtime/setup.rs b/src/runtime/setup.rs index d28d7576b823ad095337b0449bd89d14789a00ba..80dd36be7bb03fb4c0302988a859284ae5330d53 100644 --- a/src/runtime/setup.rs +++ b/src/runtime/setup.rs @@ -97,7 +97,6 @@ impl Connector { } } } - fn new_endpoint_manager( logger: &mut dyn Logger, endpoint_setups: &[(PortId, EndpointSetup)], @@ -105,16 +104,18 @@ fn new_endpoint_manager( deadline: Option, ) -> Result { //////////////////////////////////////////// + use std::sync::atomic::AtomicBool; use ConnectError::*; const BOTH: Interest = Interest::READABLE.add(Interest::WRITABLE); struct Todo { todo_endpoint: TodoEndpoint, + endpoint_setup: EndpointSetup, local_port: PortId, sent_local_port: bool, // true <-> I've sent my local port recv_peer_port: Option, // Some(..) <-> I've received my peer's port } enum TodoEndpoint { - Listener(TcpListener), + Accepting(TcpListener), Endpoint(Endpoint), } fn init_todo( @@ -132,13 +133,27 @@ fn new_endpoint_manager( let mut listener = TcpListener::bind(endpoint_setup.sock_addr) .map_err(|_| BindFailed(endpoint_setup.sock_addr))?; poll.registry().register(&mut listener, token, BOTH).unwrap(); - TodoEndpoint::Listener(listener) + TodoEndpoint::Accepting(listener) }; - Ok(Todo { todo_endpoint, local_port, sent_local_port: false, recv_peer_port: None }) + Ok(Todo { + todo_endpoint, + local_port, + sent_local_port: false, + recv_peer_port: None, + endpoint_setup: endpoint_setup.clone(), + }) }; + struct WakerState { + continue_signal: Arc, + failed_indices: HashSet, + } //////////////////////////////////////////// // 1. Start to construct EndpointManager + const WAKER_TOKEN: Token = Token(usize::MAX); + const WAKER_PERIOD: Duration = Duration::from_millis(90); + assert!(endpoint_setups.len() < WAKER_TOKEN.0); // using MAX usize as waker token + let mut waker_continue_signal: Option> = None; let mut poll = Poll::new().map_err(|_| PollInitFailed)?; let mut events = Events::with_capacity(endpoint_setups.len() * 2 + 4); let mut polled_undrained = IndexSet::default(); @@ -157,6 +172,7 @@ fn new_endpoint_manager( // - accept an incoming connection for each TcpListener (turning them into endpoints too) // - for each endpoint, send the local PortId // - for each endpoint, recv the peer's PortId, and + let mut connect_failed: HashSet = Default::default(); let mut setup_incomplete: HashSet = (0..todos.len()).collect(); while !setup_incomplete.is_empty() { let remaining = if let Some(deadline) = deadline { @@ -169,40 +185,88 @@ fn new_endpoint_manager( let token = event.token(); let Token(index) = token; let todo: &mut Todo = &mut todos[index]; - if let TodoEndpoint::Listener(listener) = &mut todo.todo_endpoint { - match listener.accept() { - Ok((mut stream, peer_addr)) => { - poll.registry().deregister(listener).unwrap(); - poll.registry().register(&mut stream, token, BOTH).unwrap(); - log!( - logger, - "Endpoint[{}] accepted a connection from {:?}", - index, - peer_addr - ); - let endpoint = Endpoint { stream, inbox: vec![] }; - todo.todo_endpoint = TodoEndpoint::Endpoint(endpoint); + if token == WAKER_TOKEN { + log!(logger, "Notification from waker"); + assert!(waker_continue_signal.is_some()); + for index in connect_failed.drain() { + log!( + logger, + "Restarting connection with endpoint {:?} {:?}", + index, + todo.endpoint_setup.sock_addr + ); + match &mut todo.todo_endpoint { + TodoEndpoint::Endpoint(endpoint) => { + let mut new_stream = TcpStream::connect(todo.endpoint_setup.sock_addr) + .expect("mio::TcpStream connect should not fail!"); + poll.registry().deregister(&mut endpoint.stream).unwrap(); + std::mem::swap(&mut endpoint.stream, &mut new_stream); + poll.registry().register(&mut endpoint.stream, token, BOTH).unwrap(); + } + _ => unreachable!(), } - Err(e) if e.kind() == WouldBlock => {} - Err(_) => return Err(AcceptFailed(listener.local_addr().unwrap())), } - } - match todo { - Todo { - todo_endpoint: TodoEndpoint::Endpoint(endpoint), - local_port, - sent_local_port, - recv_peer_port, - .. - } => { + } else { + // FIRST try convert this into an endpoint + if let TodoEndpoint::Accepting(listener) = &mut todo.todo_endpoint { + match listener.accept() { + Ok((mut stream, peer_addr)) => { + poll.registry().deregister(listener).unwrap(); + poll.registry().register(&mut stream, token, BOTH).unwrap(); + log!( + logger, + "Endpoint[{}] accepted a connection from {:?}", + index, + peer_addr + ); + let endpoint = Endpoint { stream, inbox: vec![] }; + todo.todo_endpoint = TodoEndpoint::Endpoint(endpoint); + } + Err(e) if would_block(&e) => { + log!(logger, "Spurious wakeup on listener {:?}", index) + } + Err(_) => { + log!(logger, "accept() failure on index {}", index); + return Err(AcceptFailed(listener.local_addr().unwrap())); + } + } + } + if let TodoEndpoint::Endpoint(endpoint) = &mut todo.todo_endpoint { + if event.is_error() { + if todo.endpoint_setup.endpoint_polarity == EndpointPolarity::Passive { + // right now you cannot retry an acceptor. + return Err(AcceptFailed(endpoint.stream.local_addr().unwrap())); + } + connect_failed.insert(index); + if waker_continue_signal.is_none() { + log!(logger, "First connect failure. Starting waker thread"); + let waker = + Arc::new(mio::Waker::new(poll.registry(), WAKER_TOKEN).unwrap()); + let wcs = Arc::new(AtomicBool::from(true)); + let wcs2 = wcs.clone(); + std::thread::spawn(move || { + while wcs2.load(std::sync::atomic::Ordering::SeqCst) { + std::thread::sleep(WAKER_PERIOD); + waker.wake().expect("unable to wake"); + } + }); + waker_continue_signal = Some(wcs); + } + continue; + } + if connect_failed.contains(&index) { + // spurious wakeup + continue; + } if !setup_incomplete.contains(&index) { + // spurious wakeup continue; } - let local_polarity = *port_info.polarities.get(local_port).unwrap(); - if event.is_writable() && !*sent_local_port { + let local_polarity = *port_info.polarities.get(&todo.local_port).unwrap(); + if event.is_writable() && !todo.sent_local_port { let msg = Msg::SetupMsg(SetupMsg::MyPortInfo(MyPortInfo { polarity: local_polarity, - port: *local_port, + port: todo.local_port, })); endpoint .send(&msg) @@ -211,9 +275,9 @@ fn new_endpoint_manager( }) .unwrap(); log!(logger, "endpoint[{}] sent msg {:?}", index, &msg); - *sent_local_port = true; + todo.sent_local_port = true; } - if event.is_readable() && recv_peer_port.is_none() { + if event.is_readable() && todo.recv_peer_port.is_none() { let maybe_msg = endpoint.try_recv(logger).map_err(|e| { EndpointSetupError(endpoint.stream.local_addr().unwrap(), e) })?; @@ -226,15 +290,15 @@ fn new_endpoint_manager( log!(logger, "endpoint[{}] got peer info {:?}", index, peer_info); if peer_info.polarity == local_polarity { return Err(ConnectError::PortPeerPolarityMismatch( - *local_port, + todo.local_port, )); } - *recv_peer_port = Some(peer_info.port); + todo.recv_peer_port = Some(peer_info.port); // 1. finally learned the peer of this port! - port_info.peers.insert(*local_port, peer_info.port); + port_info.peers.insert(todo.local_port, peer_info.port); // 2. learned the info of this peer port port_info.polarities.insert(peer_info.port, peer_info.polarity); - port_info.peers.insert(peer_info.port, *local_port); + port_info.peers.insert(peer_info.port, todo.local_port); port_info.routes.insert(peer_info.port, Route::Endpoint { index }); } Some(inappropriate_msg) => { @@ -247,12 +311,11 @@ fn new_endpoint_manager( } } } - if *sent_local_port && recv_peer_port.is_some() { + if todo.sent_local_port && todo.recv_peer_port.is_some() { setup_incomplete.remove(&index); log!(logger, "endpoint[{}] is finished!", index); } } - Todo { todo_endpoint: TodoEndpoint::Listener(_), .. } => unreachable!(), } } events.clear(); @@ -268,11 +331,15 @@ fn new_endpoint_manager( .unwrap(); endpoint } - TodoEndpoint::Listener(..) => unreachable!(), + _ => unreachable!(), }, getter_for_incoming: local_port, }) .collect(); + if let Some(wcs) = waker_continue_signal { + log!(logger, "Sending waker the stop signal"); + wcs.store(false, std::sync::atomic::Ordering::SeqCst); + } Ok(EndpointManager { poll, events,