diff --git a/src/runtime/setup.rs b/src/runtime/setup.rs index 32ea1ac6d99fd1954189264e22daf4605ce0545c..3a45844f1cdcf75dcb342fd0192e990713a05e3f 100644 --- a/src/runtime/setup.rs +++ b/src/runtime/setup.rs @@ -94,6 +94,7 @@ impl Connector { let mut endpoint_manager = new_endpoint_manager( &mut *cu.logger, &setup.net_endpoint_setups, + &setup.udp_endpoint_setups, &mut cu.port_info, &deadline, )?; @@ -129,7 +130,8 @@ impl Connector { } fn new_endpoint_manager( logger: &mut dyn Logger, - endpoint_setups: &[(PortId, NetEndpointSetup)], + net_endpoint_setups: &[(PortId, NetEndpointSetup)], + udp_endpoint_setups: &[(PortId, UdpEndpointSetup)], port_info: &mut PortInfo, deadline: &Option, ) -> Result { @@ -144,35 +146,14 @@ fn new_endpoint_manager( sent_local_port: bool, // true <-> I've sent my local port recv_peer_port: Option, // Some(..) <-> I've received my peer's port } + struct UdpTodo { + local_port: PortId, + sock: UdpSocket, + } enum TodoEndpoint { Accepting(TcpListener), NetEndpoint(NetEndpoint), } - fn init_todo( - token: Token, - local_port: PortId, - endpoint_setup: &NetEndpointSetup, - poll: &mut Poll, - ) -> Result { - let todo_endpoint = if let EndpointPolarity::Active = endpoint_setup.endpoint_polarity { - let mut stream = TcpStream::connect(endpoint_setup.sock_addr) - .expect("mio::TcpStream connect should not fail!"); - poll.registry().register(&mut stream, token, BOTH).unwrap(); - TodoEndpoint::NetEndpoint(NetEndpoint { stream, inbox: vec![] }) - } else { - let mut listener = TcpListener::bind(endpoint_setup.sock_addr) - .map_err(|_| Ce::BindFailed(endpoint_setup.sock_addr))?; - poll.registry().register(&mut listener, token, BOTH).unwrap(); - TodoEndpoint::Accepting(listener) - }; - Ok(Todo { - todo_endpoint, - local_port, - sent_local_port: false, - recv_peer_port: None, - endpoint_setup: endpoint_setup.clone(), - }) - } //////////////////////////////////////////// // 1. Start to construct EndpointManager @@ -184,24 +165,49 @@ fn new_endpoint_manager( let mut waker_state: Option> = None; let mut poll = Poll::new().map_err(|_| Ce::PollInitFailed)?; - let mut events = Events::with_capacity(endpoint_setups.len() * 2 + 4); + let mut events = Events::with_capacity(net_endpoint_setups.len() * 2 + 4); let mut net_polled_undrained = VecSet::default(); let udp_polled_undrained = VecSet::default(); let mut delayed_messages = vec![]; // 2. create a registered (TcpListener/Endpoint) for passive / active respectively - let mut todos = endpoint_setups + let mut todos = net_endpoint_setups .iter() .enumerate() .map(|(index, (local_port, endpoint_setup))| { - init_todo( - TokenTarget::NetEndpoint { index }.into(), - *local_port, - endpoint_setup, - &mut poll, - ) + let token = TokenTarget::NetEndpoint { index }.into(); + let todo_endpoint = if let EndpointPolarity::Active = endpoint_setup.endpoint_polarity { + let mut stream = TcpStream::connect(endpoint_setup.sock_addr) + .expect("mio::TcpStream connect should not fail!"); + poll.registry().register(&mut stream, token, BOTH).unwrap(); + TodoEndpoint::NetEndpoint(NetEndpoint { stream, inbox: vec![] }) + } else { + let mut listener = TcpListener::bind(endpoint_setup.sock_addr) + .map_err(|_| Ce::BindFailed(endpoint_setup.sock_addr))?; + poll.registry().register(&mut listener, token, BOTH).unwrap(); + TodoEndpoint::Accepting(listener) + }; + Ok(Todo { + todo_endpoint, + local_port: *local_port, + sent_local_port: false, + recv_peer_port: None, + endpoint_setup: endpoint_setup.clone(), + }) }) .collect::, ConnectError>>()?; + let udp_todos = udp_endpoint_setups + .iter() + .enumerate() + .map(|(index, (local_port, endpoint_setup))| { + let mut sock = UdpSocket::bind(endpoint_setup.local_addr) + .map_err(|_| Ce::BindFailed(endpoint_setup.local_addr))?; + poll.registry() + .register(&mut sock, TokenTarget::UdpEndpoint { index }.into(), Interest::WRITABLE) + .unwrap(); + Ok(UdpTodo { sock, local_port: *local_port }) + }) + .collect::, ConnectError>>()?; // 3. Using poll to drive progress: // - accept an incoming connection for each TcpListener (turning them into endpoints too) @@ -210,8 +216,12 @@ fn new_endpoint_manager( // all in connect_failed are NOT registered with Poll let mut connect_failed: HashSet = Default::default(); + // TODO register udps, and all them to incomplete list - let mut setup_incomplete: HashSet = (0..todos.len()).collect(); + let mut setup_incomplete: HashSet = (0..todos.len()) + .map(|index| TokenTarget::NetEndpoint { index }) + .chain((0..udp_todos.len()).map(|index| TokenTarget::UdpEndpoint { index })) + .collect(); while !setup_incomplete.is_empty() { let remaining = if let Some(deadline) = deadline { Some(deadline.checked_duration_since(Instant::now()).ok_or(Ce::Timeout)?) @@ -222,6 +232,10 @@ fn new_endpoint_manager( for event in events.iter() { let token = event.token(); let token_target = TokenTarget::from(token); + if !setup_incomplete.contains(&token_target) { + // spurious wakeup + continue; + } match token_target { TokenTarget::Waker => { log!( @@ -230,12 +244,12 @@ fn new_endpoint_manager( connect_failed.iter() ); assert!(waker_state.is_some()); - for net_endpoint_index in connect_failed.drain() { - let todo: &mut Todo = &mut todos[net_endpoint_index]; + for net_index in connect_failed.drain() { + let todo: &mut Todo = &mut todos[net_index]; log!( logger, "Restarting connection with endpoint {:?} {:?}", - net_endpoint_index, + net_index, todo.endpoint_setup.sock_addr ); match &mut todo.todo_endpoint { @@ -244,8 +258,7 @@ fn new_endpoint_manager( TcpStream::connect(todo.endpoint_setup.sock_addr) .expect("mio::TcpStream connect should not fail!"); std::mem::swap(&mut endpoint.stream, &mut new_stream); - let token = - TokenTarget::NetEndpoint { index: net_endpoint_index }.into(); + let token = TokenTarget::NetEndpoint { index: net_index }.into(); poll.registry() .register(&mut endpoint.stream, token, BOTH) .unwrap(); @@ -254,7 +267,13 @@ fn new_endpoint_manager( } } } - TokenTarget::UdpEndpoint { index: _ } => unreachable!(), + TokenTarget::UdpEndpoint { index } => { + let udp_todo: &UdpTodo = &udp_todos[index]; + if event.is_error() { + return Err(Ce::BindFailed(udp_todo.sock.local_addr().unwrap())); + } + setup_incomplete.remove(&token_target); + } TokenTarget::NetEndpoint { index } => { let todo: &mut Todo = &mut todos[index]; if let TodoEndpoint::Accepting(listener) = &mut todo.todo_endpoint { @@ -333,10 +352,6 @@ fn new_endpoint_manager( // spurious wakeup continue; } - if !setup_incomplete.contains(&index) { - // spurious wakeup - continue; - } let local_polarity = *port_info.polarities.get(&todo.local_port).unwrap(); if event.is_writable() && !todo.sent_local_port { // can write and didn't send setup msg yet? Do so! @@ -414,7 +429,7 @@ fn new_endpoint_manager( // is the setup for this net_endpoint now complete? if todo.sent_local_port && todo.recv_peer_port.is_some() { // yes! connected, sent my info and received peer's info - setup_incomplete.remove(&index); + setup_incomplete.remove(&token_target); log!(logger, "endpoint[{}] is finished!", index); } } @@ -429,7 +444,6 @@ fn new_endpoint_manager( arc.continue_signal.store(false, std::sync::atomic::Ordering::SeqCst); // TODO leave the waker registered? } - let udp_endpoint_exts = vec![]; let net_endpoint_exts = todos .into_iter() @@ -448,6 +462,22 @@ fn new_endpoint_manager( getter_for_incoming: local_port, }) .collect(); + let udp_endpoint_exts = udp_todos + .into_iter() + .enumerate() + .map(|(index, udp_todo)| { + let UdpTodo { mut sock, local_port } = udp_todo; + let token = TokenTarget::UdpEndpoint { index }.into(); + poll.registry().reregister(&mut sock, token, Interest::READABLE).unwrap(); + UdpEndpointExt { + sock, + outgoing_payloads: Default::default(), + incoming_round_spec_var: None, + getter_for_incoming: local_port, + incoming_payloads: Default::default(), + } + }) + .collect(); Ok(EndpointManager { poll, events,