From b20c3a55156d11dd1d510b50b4711261079c552a 2020-06-24 15:44:18 From: Christopher Esterhuyse Date: 2020-06-24 15:44:18 Subject: [PATCH] dummy scatter/gather wave in --- diff --git a/src/lib.rs b/src/lib.rs index 95563b12e61e50e7648faf0da4662131252fb41a..6845fca2aa2b71b7e658c5de82a94c1c3ec7a0c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ mod runtime; pub use common::{ConnectorId, EndpointPolarity, Polarity, PortId}; pub use protocol::ProtocolDescription; -pub use runtime::{error, Connector, FileLogger, VecLogger}; +pub use runtime::{error, Connector, DummyLogger, FileLogger, VecLogger}; // #[cfg(feature = "ffi")] // pub use runtime::ffi; diff --git a/src/runtime/communication.rs b/src/runtime/communication.rs index ae4c9985bebf731f91caeaf5bec0a4158c378339..e7db616c1ac626a4471c7c0ba5a446b86ad02755 100644 --- a/src/runtime/communication.rs +++ b/src/runtime/communication.rs @@ -37,6 +37,9 @@ struct CyclicDrainInner<'a, K: Eq + Hash, V> { swap: &'a mut HashMap, output: &'a mut HashMap, } +trait PayloadMsgSender { + fn send(&mut self, port_info: &PortInfo, putter: &PortId, msg: SendPayloadMsg); +} //////////////// impl Connector { @@ -122,8 +125,8 @@ impl Connector { } } } - - // TODO make cu immutable + // private function. mutates state but returns with round + // result ASAP (allows for convenient error return with ?) fn connected_sync( cu: &mut ConnectorUnphased, comm: &mut ConnectorCommunication, @@ -131,13 +134,14 @@ impl Connector { ) -> Result, SyncError> { use SyncError as Se; let mut deadline = timeout.map(|to| Instant::now() + to); - // 1. run all proto components to Nonsync blockers log!( cu.logger, "~~~ SYNC called with timeout {:?}; starting round {}", &timeout, comm.round_index ); + + // 1. run all proto components to Nonsync blockers let mut branching_proto_components = HashMap::::default(); let mut unrun_components: Vec<(ProtoComponentId, ProtoComponent)> = @@ -262,7 +266,6 @@ impl Connector { branching_proto_components.len() ); for (&proto_component_id, proto_component) in branching_proto_components.iter_mut() { - let ConnectorUnphased { port_info, proto_description, .. } = cu; let BranchingProtoComponent { ports, branches } = proto_component; let mut swap = HashMap::default(); let mut blocked = HashMap::default(); @@ -515,7 +518,6 @@ impl BranchingNative { &mut self, cu: &mut ConnectorUnphased, solution_storage: &mut SolutionStorage, - // payloads_to_get: &mut Vec<(PortId, CommMsgContents)>, getter: PortId, send_payload_msg: SendPayloadMsg, ) { @@ -617,7 +619,7 @@ impl BranchingProtoComponent { cd: CyclicDrainer, cu: &mut ConnectorUnphased, solution_storage: &mut SolutionStorage, - payloads_to_get: &mut Vec<(PortId, SendPayloadMsg)>, + payload_msg_sender: &mut impl PayloadMsgSender, proto_component_id: ProtoComponentId, ports: &HashSet, ) { @@ -683,9 +685,8 @@ impl BranchingProtoComponent { } else { // keep in "unblocked" log!(cu.logger, "Proto component {:?} putting payload {:?} on port {:?} (using var {:?})", proto_component_id, &payload, putter, var); - let getter = *cu.port_info.peers.get(&putter).unwrap(); let msg = SendPayloadMsg { predicate: predicate.clone(), payload }; - payloads_to_get.push((getter, msg)); + payload_msg_sender.send(&cu.port_info, &putter, msg); drainer.add_input(predicate, branch); } } @@ -697,7 +698,7 @@ impl BranchingProtoComponent { cu: &mut ConnectorUnphased, solution_storage: &mut SolutionStorage, proto_component_id: ProtoComponentId, - payloads_to_get: &mut Vec<(PortId, SendPayloadMsg)>, + payload_msg_sender: &mut impl PayloadMsgSender, getter: PortId, send_payload_msg: SendPayloadMsg, ) { @@ -756,7 +757,7 @@ impl BranchingProtoComponent { cd, cu, solution_storage, - payloads_to_get, + payload_msg_sender, proto_component_id, ports, ); @@ -806,7 +807,7 @@ impl SolutionStorage { self.old_local.clear(); self.new_local.clear(); } - pub(crate) fn reset(&mut self, subtree_ids: impl Iterator) { + fn reset(&mut self, subtree_ids: impl Iterator) { self.subtree_id_to_index.clear(); self.subtree_solutions.clear(); self.old_local.clear(); @@ -880,12 +881,18 @@ impl SolutionStorage { } } } +impl PayloadMsgSender for Vec<(PortId, SendPayloadMsg)> { + fn send(&mut self, port_info: &PortInfo, putter: &PortId, msg: SendPayloadMsg) { + let getter = *port_info.peers.get(putter).unwrap(); + self.push((getter, msg)); + } +} impl SyncProtoContext<'_> { - pub fn is_firing(&mut self, port: PortId) -> Option { + pub(crate) fn is_firing(&mut self, port: PortId) -> Option { let var = self.port_info.firing_var_for(port); self.predicate.query(var) } - pub fn read_msg(&mut self, port: PortId) -> Option<&Payload> { + pub(crate) fn read_msg(&mut self, port: PortId) -> Option<&Payload> { self.inbox.get(&port) } } diff --git a/src/runtime/endpoints.rs b/src/runtime/endpoints.rs index 57e0a05e44819dcbd232f59958c116c54704a2c9..57b49ba006be144f813c2056ffb1f95a07213a5f 100644 --- a/src/runtime/endpoints.rs +++ b/src/runtime/endpoints.rs @@ -17,7 +17,7 @@ fn would_block(err: &std::io::Error) -> bool { err.kind() == std::io::ErrorKind::WouldBlock } impl Endpoint { - pub fn try_recv( + pub(super) fn try_recv( &mut self, logger: &mut dyn Logger, ) -> Result, EndpointError> { @@ -50,28 +50,28 @@ impl Endpoint { }, } } - pub fn send(&mut self, msg: &T) -> Result<(), EndpointError> { + pub(super) fn send(&mut self, msg: &T) -> Result<(), EndpointError> { bincode::serialize_into(&mut self.stream, msg).map_err(|_| EndpointError::BrokenEndpoint) } } impl EndpointManager { - pub fn index_iter(&self) -> Range { + pub(super) fn index_iter(&self) -> Range { 0..self.num_endpoints() } - pub fn num_endpoints(&self) -> usize { + pub(super) fn num_endpoints(&self) -> usize { self.endpoint_exts.len() } - pub fn send_to_setup(&mut self, index: usize, msg: &Msg) -> Result<(), ConnectError> { + pub(super) fn send_to_setup(&mut self, index: usize, msg: &Msg) -> Result<(), ConnectError> { let endpoint = &mut self.endpoint_exts[index].endpoint; endpoint.send(msg).map_err(|err| { ConnectError::EndpointSetupError(endpoint.stream.local_addr().unwrap(), err) }) } - pub fn send_to(&mut self, index: usize, msg: &Msg) -> Result<(), EndpointError> { + pub(super) fn send_to(&mut self, index: usize, msg: &Msg) -> Result<(), EndpointError> { self.endpoint_exts[index].endpoint.send(msg) } - pub fn try_recv_any_comms( + pub(super) fn try_recv_any_comms( &mut self, logger: &mut dyn Logger, deadline: Option, @@ -84,7 +84,7 @@ impl EndpointManager { Err(Trae::EndpointError { error: _, index }) => Err(Se::BrokenEndpoint(index)), } } - pub fn try_recv_any_setup( + pub(super) fn try_recv_any_setup( &mut self, logger: &mut dyn Logger, deadline: Option, @@ -119,10 +119,10 @@ impl EndpointManager { .map_err(|error| Trea::EndpointError { error, index })? { endptlog!(logger, "RECV polled_undrained {:?}", &msg); - // if !endpoint.inbox.is_empty() { - // there may be another message waiting! - self.polled_undrained.insert(index); - // } + if !endpoint.inbox.is_empty() { + // there may be another message waiting! + self.polled_undrained.insert(index); + } return Ok((index, msg)); } } @@ -143,17 +143,11 @@ impl EndpointManager { index, self.polled_undrained.iter() ); - if event.is_error() { - return Err(Trea::EndpointError { - error: EndpointError::BrokenEndpoint, - index, - }); - } } self.events.clear(); } } - pub fn undelay_all(&mut self) { + pub(super) fn undelay_all(&mut self) { if self.undelayed_messages.is_empty() { // fast path std::mem::swap(&mut self.delayed_messages, &mut self.undelayed_messages); @@ -174,7 +168,7 @@ impl From for MonitoredReader { } } impl MonitoredReader { - pub fn bytes_read(&self) -> usize { + pub(super) fn bytes_read(&self) -> usize { self.bytes } } diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index b5f63c8958d8bc0e79d923734538dbd57c1ba0f8..c1acf7cc7637f20dce218a1a06d9bf0ed1e60447 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -51,7 +51,13 @@ pub enum SetupMsg { LeaderWave { wave_leader: ConnectorId }, LeaderAnnounce { tree_leader: ConnectorId }, YouAreMyParent, + SessionGather { unoptimized_map: HashMap }, + SessionScatter { optimized_map: HashMap }, } + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct SessionInfo {} + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct CommMsg { pub round_index: usize, @@ -192,7 +198,7 @@ pub struct SyncProtoContext<'a> { } //////////////// impl VecSet { - fn iter(&self) -> impl Iterator { + fn iter(&self) -> std::slice::Iter { self.vec.iter() } fn contains(&self, element: &T) -> bool { diff --git a/src/runtime/setup.rs b/src/runtime/setup.rs index abbbbabd1cd3d31924605d421103113d3f6449e0..97f9673898fe7bb6cc82e6499ec79bf3b55817e4 100644 --- a/src/runtime/setup.rs +++ b/src/runtime/setup.rs @@ -3,15 +3,6 @@ use crate::runtime::*; use std::io::ErrorKind::WouldBlock; impl Connector { - pub fn new_simple( - proto_description: Arc, - connector_id: ConnectorId, - ) -> Self { - let logger = Box::new(DummyLogger); - // let logger = Box::new(DummyLogger); - let surplus_sockets = 2; - Self::new(logger, proto_description, connector_id, surplus_sockets) - } pub fn new( mut logger: Box, proto_description: Arc, @@ -61,45 +52,46 @@ impl Connector { } pub fn connect(&mut self, timeout: Option) -> Result<(), ConnectError> { use ConnectError::*; - let Self { unphased: up, phased } = self; + let Self { unphased: cu, phased } = self; match phased { ConnectorPhased::Communication { .. } => { - log!(up.logger, "Call to connecting in connected state"); + log!(cu.logger, "Call to connecting in connected state"); Err(AlreadyConnected) } ConnectorPhased::Setup { endpoint_setups, .. } => { - log!(up.logger, "~~~ CONNECT called timeout {:?}", timeout); + log!(cu.logger, "~~~ CONNECT called timeout {:?}", timeout); let deadline = timeout.map(|to| Instant::now() + to); // connect all endpoints in parallel; send and receive peer ids through ports let mut endpoint_manager = new_endpoint_manager( - &mut *up.logger, + &mut *cu.logger, endpoint_setups, - &mut up.port_info, + &mut cu.port_info, deadline, )?; log!( - up.logger, + cu.logger, "Successfully connected {} endpoints", endpoint_manager.endpoint_exts.len() ); // leader election and tree construction let neighborhood = init_neighborhood( - up.id_manager.connector_id, - &mut *up.logger, + cu.id_manager.connector_id, + &mut *cu.logger, &mut endpoint_manager, deadline, )?; - log!(up.logger, "Successfully created neighborhood {:?}", &neighborhood); - log!(up.logger, "connect() finished. setup phase complete"); - // TODO session optimization goes here - self.phased = ConnectorPhased::Communication(ConnectorCommunication { + log!(cu.logger, "Successfully created neighborhood {:?}", &neighborhood); + let mut comm = ConnectorCommunication { round_index: 0, endpoint_manager, neighborhood, mem_inbox: Default::default(), native_batches: vec![Default::default()], round_result: Ok(None), - }); + }; + session_optimize(cu, &mut comm, deadline)?; + log!(cu.logger, "connect() finished. setup phase complete"); + self.phased = ConnectorPhased::Communication(comm); Ok(()) } } @@ -297,8 +289,8 @@ fn init_neighborhood( em: &mut EndpointManager, deadline: Option, ) -> Result { - use {ConnectError::*, Msg::SetupMsg as S, SetupMsg::*}; //////////////////////////////// + use {ConnectError::*, Msg::SetupMsg as S, SetupMsg::*}; #[derive(Debug)] struct WaveState { parent: Option, @@ -416,10 +408,15 @@ fn init_neighborhood( } } } - S(YouAreMyParent) | S(MyPortInfo(_)) => unreachable!(), - comm_msg @ Msg::CommMsg { .. } => { - log!(logger, "delaying msg {:?} during election algorithm", comm_msg); - em.delayed_messages.push((recv_index, comm_msg)); + msg @ S(YouAreMyParent) | msg @ S(MyPortInfo(_)) => { + log!(logger, "Endpont {:?} sent unexpected msg! {:?}", recv_index, &msg); + return Err(SetupAlgMisbehavior); + } + msg @ S(SessionScatter { .. }) + | msg @ S(SessionGather { .. }) + | msg @ Msg::CommMsg { .. } => { + log!(logger, "delaying msg {:?} during election algorithm", msg); + em.delayed_messages.push((recv_index, msg)); } } } @@ -443,7 +440,6 @@ fn init_neighborhood( let (recv_index, msg) = em.try_recv_any_setup(logger, deadline)?; log!(logger, "Received from index {:?} msg {:?}", &recv_index, &msg); match msg { - S(LeaderWave { .. }) => { /* old message */ } S(LeaderAnnounce { .. }) => { // not a child log!( @@ -468,10 +464,14 @@ fn init_neighborhood( } children.push(recv_index); } - S(MyPortInfo(_)) => unreachable!(), - comm_msg @ Msg::CommMsg { .. } => { - log!(logger, "delaying msg {:?} during election algorithm", comm_msg); - em.delayed_messages.push((recv_index, comm_msg)); + msg @ S(MyPortInfo(_)) | msg @ S(LeaderWave { .. }) => { + log!(logger, "discarding old message {:?} during election", msg); + } + msg @ S(SessionScatter { .. }) + | msg @ S(SessionGather { .. }) + | msg @ Msg::CommMsg { .. } => { + log!(logger, "delaying msg {:?} during election", msg); + em.delayed_messages.push((recv_index, msg)); } } } @@ -481,3 +481,138 @@ fn init_neighborhood( log!(logger, "Neighborhood constructed {:?}", &neighborhood); Ok(neighborhood) } + +fn session_optimize( + cu: &mut ConnectorUnphased, + comm: &mut ConnectorCommunication, + deadline: Option, +) -> Result<(), ConnectError> { + //////////////////////////////////////// + use {ConnectError::*, Msg::SetupMsg as S, SetupMsg::*}; + //////////////////////////////////////// + log!(cu.logger, "Beginning session optimization"); + // populate session_info_map from a message per child + let mut unoptimized_map: HashMap = Default::default(); + let mut awaiting: HashSet = comm.neighborhood.children.iter().copied().collect(); + comm.endpoint_manager.undelay_all(); + while !awaiting.is_empty() { + log!( + cu.logger, + "Session gather loop. awaiting info from children {:?}...", + awaiting.iter() + ); + let (recv_index, msg) = + comm.endpoint_manager.try_recv_any_setup(&mut *cu.logger, deadline)?; + log!(cu.logger, "Received from index {:?} msg {:?}", &recv_index, &msg); + match msg { + S(SessionGather { unoptimized_map: child_unoptimized_map }) => { + if !awaiting.remove(&recv_index) { + log!( + cu.logger, + "Wasn't expecting session info from {:?}. Got {:?}", + recv_index, + &child_unoptimized_map + ); + return Err(SetupAlgMisbehavior); + } + unoptimized_map.extend(child_unoptimized_map.into_iter()); + } + msg @ S(YouAreMyParent) + | msg @ S(MyPortInfo(..)) + | msg @ S(LeaderAnnounce { .. }) + | msg @ S(LeaderWave { .. }) => { + log!(cu.logger, "discarding old message {:?} during election", msg); + } + msg @ S(SessionScatter { .. }) => { + log!( + cu.logger, + "Endpoint {:?} sent unexpected scatter! {:?} I've not contributed yet!", + recv_index, + &msg + ); + return Err(SetupAlgMisbehavior); + } + msg @ Msg::CommMsg(..) => { + log!(cu.logger, "delaying msg {:?} during session optimization", msg); + comm.endpoint_manager.delayed_messages.push((recv_index, msg)); + } + } + } + log!( + cu.logger, + "Gathered all children's maps. ConnectorId set is... {:?}", + unoptimized_map.keys() + ); + let my_session_info = SessionInfo {}; + unoptimized_map.insert(cu.id_manager.connector_id, my_session_info); + log!(cu.logger, "Inserting my own info. Unoptimized subtree map is {:?}", &unoptimized_map); + + // acquire the optimized info... + let optimized_map = if let Some(parent) = comm.neighborhood.parent { + // ... as a message from my parent + log!(cu.logger, "Forwarding gathered info to parent {:?}", parent); + let msg = S(SessionGather { unoptimized_map }); + comm.endpoint_manager.send_to_setup(parent, &msg)?; + 'scatter_loop: loop { + log!( + cu.logger, + "Session scatter recv loop. awaiting info from children {:?}...", + awaiting.iter() + ); + let (recv_index, msg) = + comm.endpoint_manager.try_recv_any_setup(&mut *cu.logger, deadline)?; + log!(cu.logger, "Received from index {:?} msg {:?}", &recv_index, &msg); + match msg { + S(SessionScatter { optimized_map }) => { + if recv_index != parent { + log!(cu.logger, "I expected the scatter from my parent only!"); + return Err(SetupAlgMisbehavior); + } + break 'scatter_loop optimized_map; + } + msg @ Msg::CommMsg { .. } => { + log!(cu.logger, "delaying msg {:?} during scatter recv", msg); + comm.endpoint_manager.delayed_messages.push((recv_index, msg)); + } + msg @ S(SessionGather { .. }) + | msg @ S(YouAreMyParent) + | msg @ S(MyPortInfo(..)) + | msg @ S(LeaderAnnounce { .. }) + | msg @ S(LeaderWave { .. }) => { + log!(cu.logger, "discarding old message {:?} during election", msg); + } + } + } + } else { + // by computing it myself + log!(cu.logger, "I am the leader! I will optimize this session"); + leader_session_map_optimize(unoptimized_map)? + }; + log!( + cu.logger, + "Optimized info map is {:?}. Sending to children {:?}", + &optimized_map, + comm.neighborhood.children.iter() + ); + let optimized_info = + optimized_map.get(&cu.id_manager.connector_id).expect("HEY NO INFO FOR ME?").clone(); + let msg = S(SessionScatter { optimized_map }); + for &child in comm.neighborhood.children.iter() { + comm.endpoint_manager.send_to_setup(child, &msg)?; + } + apply_optimizations(cu, comm, optimized_info)?; + log!(cu.logger, "Session optimization complete"); + Ok(()) +} +fn leader_session_map_optimize( + unoptimized_map: HashMap, +) -> Result, ConnectError> { + Ok(unoptimized_map) +} +fn apply_optimizations( + _cu: &mut ConnectorUnphased, + _comm: &mut ConnectorCommunication, + _session_info: SessionInfo, +) -> Result<(), ConnectError> { + Ok(()) +} diff --git a/src/runtime/tests.rs b/src/runtime/tests.rs index 25e12dd835bc4ffb4b474fffc06a0dc880fae9cb..a2a15b4fc3a328901f84ca4e7140de92f7f5c874 100644 --- a/src/runtime/tests.rs +++ b/src/runtime/tests.rs @@ -34,7 +34,7 @@ fn file_logged_connector(connector_id: ConnectorId, dir_path: &Path) -> Connecto #[test] fn basic_connector() { - Connector::new_simple(MINIMAL_PROTO.clone(), 0); + Connector::new(Box::new(DummyLogger), MINIMAL_PROTO.clone(), 0, 0); } #[test] @@ -86,9 +86,9 @@ fn single_node_connect() { } #[test] -fn multithreaded_connect() { +fn minimal_net_connect() { let sock_addr = next_test_addr(); - let test_log_path = Path::new("./logs/multithreaded_connect"); + let test_log_path = Path::new("./logs/minimal_net_connect"); scope(|s| { s.spawn(|_| { let mut c = file_logged_connector(0, test_log_path); @@ -278,9 +278,10 @@ fn connector_pair_nondet() { #[test] fn cannot_use_moved_ports() { - let test_log_path = Path::new("./logs/cannot_use_moved_ports"); /* - native p|-->|g sync - */ + /* + native p|-->|g sync + */ + let test_log_path = Path::new("./logs/cannot_use_moved_ports"); let mut c = file_logged_connector(0, test_log_path); let [p, g] = c.new_port_pair(); c.add_component(b"sync", &[g, p]).unwrap(); @@ -291,10 +292,11 @@ fn cannot_use_moved_ports() { #[test] fn sync_sync() { - let test_log_path = Path::new("./logs/sync_sync"); /* - native p0|-->|g0 sync - g1|<--|p1 - */ + /* + native p0|-->|g0 sync + g1|<--|p1 + */ + let test_log_path = Path::new("./logs/sync_sync"); let mut c = file_logged_connector(0, test_log_path); let [p0, g0] = c.new_port_pair(); let [p1, g1] = c.new_port_pair(); @@ -333,11 +335,11 @@ fn double_net_connect() { #[test] fn distributed_msg_bounce() { - let test_log_path = Path::new("./logs/distributed_msg_bounce"); /* native[0] | sync 0.p|-->|1.p native[1] 0.g|<--|1.g */ + let test_log_path = Path::new("./logs/distributed_msg_bounce"); let sock_addrs = [next_test_addr(), next_test_addr()]; scope(|s| { s.spawn(|_| {