From ebd352ab717ebd13b700b61e66110d6910babe4b 2020-07-15 14:14:09 From: Christopher Esterhuyse Date: 2020-07-15 14:14:09 Subject: [PATCH] smarter (safe) caching and reuse of temporary map structures during the sync round --- diff --git a/src/runtime/communication.rs b/src/runtime/communication.rs index 45bc7fce1c6e4f49c64ef2c07544b32831b2fd6c..c9170e011072e702733d98616ef64feae61546f6 100644 --- a/src/runtime/communication.rs +++ b/src/runtime/communication.rs @@ -1,7 +1,14 @@ use super::*; use crate::common::*; +use core::ops::Deref; +use core::ops::DerefMut; //////////////// +// Guard protecting an incrementally unfoldable slice of MapTempGuard elements +struct MapTempsGuard<'a, K, V>(&'a mut [HashMap]); +// Type protecting a temporary map; At the start and end of the Guard's lifetime, self.0.is_empty() +struct MapTempGuard<'a, K, V>(&'a mut HashMap); + #[derive(Default)] struct GetterBuffer { getters_and_sends: Vec<(PortId, SendPayloadMsg)>, @@ -62,6 +69,37 @@ impl ReplaceBoolTrue for bool { } //////////////// +impl<'a, K, V> MapTempsGuard<'a, K, V> { + fn reborrow(&mut self) -> MapTempsGuard<'_, K, V> { + MapTempsGuard(self.0) + } + fn split_first_mut(&mut self) -> (MapTempGuard<'_, K, V>, MapTempsGuard<'_, K, V>) { + let (head, tail) = self.0.split_first_mut().expect("Cache exhausted"); + (MapTempGuard::new(head), MapTempsGuard(tail)) + } +} +impl<'a, K, V> MapTempGuard<'a, K, V> { + fn new(map: &'a mut HashMap) -> Self { + map.clear(); + Self(map) + } +} +impl<'a, K, V> Drop for MapTempGuard<'a, K, V> { + fn drop(&mut self) { + self.0.clear() + } +} +impl<'a, K, V> Deref for MapTempGuard<'a, K, V> { + type Target = HashMap; + fn deref(&self) -> &::Target { + self.0 + } +} +impl<'a, K, V> DerefMut for MapTempGuard<'a, K, V> { + fn deref_mut(&mut self) -> &mut ::Target { + self.0 + } +} impl RoundCtxTrait for RoundCtx { fn get_deadline(&self) -> &Option { &self.deadline @@ -407,6 +445,10 @@ impl Connector { } log!(cu.logger, "Done translating native batches into branches"); + let mut pcb_temps_owner = <[HashMap; 3]>::default(); + let mut pcb_temps = MapTempsGuard(&mut pcb_temps_owner); + let mut bn_temp_owner = >::default(); + // run all proto components to their sync blocker log!( cu.logger, @@ -415,11 +457,11 @@ impl Connector { ); for (&proto_component_id, proto_component) in branching_proto_components.iter_mut() { let BranchingProtoComponent { ports, branches } = proto_component; - let mut swap = HashMap::default(); + let (swap, mut pcb_temps) = pcb_temps.split_first_mut(); + let (blocked, _pcb_temps) = pcb_temps.split_first_mut(); // initially, no components have .ended==true - let mut blocked = HashMap::default(); // drain from branches --> blocked - let cd = CyclicDrainer::new(branches, &mut swap, &mut blocked); + let cd = CyclicDrainer::new(branches, swap.0, blocked.0); BranchingProtoComponent::drain_branches_to_blocked( cd, cu, @@ -428,7 +470,7 @@ impl Connector { ports, )?; // swap the blocked branches back - std::mem::swap(&mut blocked, branches); + std::mem::swap(blocked.0, branches); if branches.is_empty() { log!(cu.logger, "{:?} has become inconsistent!", proto_component_id); if let Some(parent) = comm.neighborhood.parent { @@ -478,9 +520,10 @@ impl Connector { } Some(Route::LocalComponent(ComponentId::Native)) => branching_native.feed_msg( cu, - &mut rctx.solution_storage, + rctx, getter, &send_payload_msg, + MapTempGuard::new(&mut bn_temp_owner), ), Some(Route::LocalComponent(ComponentId::Proto(proto_component_id))) => { if let Some(branching_component) = @@ -493,6 +536,7 @@ impl Connector { proto_component_id, getter, &send_payload_msg, + pcb_temps.reborrow(), )?; if branching_component.branches.is_empty() { log!( @@ -664,15 +708,16 @@ impl BranchingNative { fn feed_msg( &mut self, cu: &mut ConnectorUnphased, - solution_storage: &mut SolutionStorage, + round_ctx: &mut RoundCtx, getter: PortId, send_payload_msg: &SendPayloadMsg, + bn_temp: MapTempGuard<'_, Predicate, NativeBranch>, ) { log!(cu.logger, "feeding native getter {:?} {:?}", getter, &send_payload_msg); assert!(cu.port_info.polarities.get(&getter).copied() == Some(Getter)); - let mut draining = HashMap::default(); + let mut draining = bn_temp; let finished = &mut self.branches; - std::mem::swap(&mut draining, finished); + std::mem::swap(draining.0, finished); for (predicate, mut branch) in draining.drain() { log!(cu.logger, "visiting native branch {:?} with {:?}", &branch, &predicate); // check if this branch expects to receive it @@ -689,7 +734,7 @@ impl BranchingNative { &branch.gotten ); let subtree_id = SubtreeId::LocalComponent(ComponentId::Native); - solution_storage.submit_and_digest_subtree_solution( + round_ctx.solution_storage.submit_and_digest_subtree_solution( &mut *cu.logger, subtree_id, predicate.clone(), @@ -910,16 +955,16 @@ impl BranchingProtoComponent { Ok(()) }) } - fn branch_merge_func( - mut a: ProtoComponentBranch, - b: &mut ProtoComponentBranch, - ) -> ProtoComponentBranch { - if b.ended && !a.ended { - a.ended = true; - std::mem::swap(&mut a, b); - } - a - } + // fn branch_merge_func( + // mut a: ProtoComponentBranch, + // b: &mut ProtoComponentBranch, + // ) -> ProtoComponentBranch { + // if b.ended && !a.ended { + // a.ended = true; + // std::mem::swap(&mut a, b); + // } + // a + // } fn feed_msg( &mut self, cu: &mut ConnectorUnphased, @@ -927,6 +972,7 @@ impl BranchingProtoComponent { proto_component_id: ProtoComponentId, getter: PortId, send_payload_msg: &SendPayloadMsg, + mut pcb_temps: MapTempsGuard<'_, Predicate, ProtoComponentBranch>, ) -> Result<(), UnrecoverableSyncError> { let logger = &mut *cu.logger; log!( @@ -937,8 +983,8 @@ impl BranchingProtoComponent { &send_payload_msg ); let BranchingProtoComponent { branches, ports } = self; - let mut unblocked = HashMap::default(); - let mut blocked = HashMap::default(); + let (mut unblocked, mut pcb_temps) = pcb_temps.split_first_mut(); + let (mut blocked, mut pcb_temps) = pcb_temps.split_first_mut(); // partition drain from branches -> {unblocked, blocked} log!(logger, "visiting {} blocked branches...", branches.len()); for (predicate, mut branch) in branches.drain() { @@ -982,8 +1028,8 @@ impl BranchingProtoComponent { } log!(logger, "blocked {:?} unblocked {:?}", blocked.len(), unblocked.len()); // drain from unblocked --> blocked - let mut swap = HashMap::default(); - let cd = CyclicDrainer::new(&mut unblocked, &mut swap, &mut blocked); + let (swap, _pcb_temps) = pcb_temps.split_first_mut(); + let cd = CyclicDrainer::new(unblocked.0, swap.0, blocked.0); BranchingProtoComponent::drain_branches_to_blocked( cd, cu, @@ -992,7 +1038,7 @@ impl BranchingProtoComponent { ports, )?; // swap the blocked branches back - std::mem::swap(&mut blocked, branches); + std::mem::swap(blocked.0, branches); log!(cu.logger, "component settles down with branches: {:?}", branches.keys()); Ok(()) } @@ -1014,7 +1060,7 @@ impl BranchingProtoComponent { ended: false, untaken_choice: None, }; - Self { ports, branches: hashmap! { Predicate::default() => branch } } + Self { ports, branches: hashmap! { Predicate::default() => branch } } } } impl SolutionStorage {