use std::path::Component; use crate::collections::VecSet; use crate::protocol::eval::ValueGroup; use crate::runtime2::branch::{BranchId, ExecTree, QueueKind}; use crate::runtime2::ConnectorId; use crate::runtime2::inbox2::{DataHeader, MessageFancy, SyncContent, SyncHeader, SyncMessageFancy}; use crate::runtime2::inbox::SyncMessage; use crate::runtime2::port::{Port, PortIdLocal}; use crate::runtime2::scheduler::ComponentCtxFancy; use super::inbox2::PortAnnotation; struct BranchAnnotation { port_mapping: Vec, } pub(crate) struct LocalSolution { component: ConnectorId, final_branch_id: BranchId, port_mapping: Vec<(PortIdLocal, BranchId)>, } // ----------------------------------------------------------------------------- // Consensus // ----------------------------------------------------------------------------- /// The consensus algorithm. Currently only implemented to find the component /// with the highest ID within the sync region and letting it handle all the /// local solutions. /// /// The type itself serves as an experiment to see how code should be organized. // TODO: Flatten all datastructures // TODO: Have a "branch+port position hint" in case multiple operations are // performed on the same port to prevent repeated lookups // TODO: A lot of stuff should be batched. Like checking all the sync headers // and sending "I have a higher ID" messages. 
pub(crate) struct Consensus {
    // Local component's state
    // NOTE(review): generic parameters on the fields below restored from
    // mangled source; inferred from how each field is read/written in the impl.
    highest_connector_id: ConnectorId,
    branch_annotations: Vec<BranchAnnotation>,
    last_finished_handled: Option<BranchId>,
    // Gathered state (in case we are currently the leader of the distributed
    // consensus protocol)
    encountered_peers: VecSet<ConnectorId>,
    local_solutions: Vec<LocalSolution>,
    // Workspaces
    workspace_ports: Vec<PortIdLocal>,
}

/// Result of checking a branch's speculative port state against reality.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum Consistency {
    Valid,
    Inconsistent,
}

impl Consensus {
    pub fn new() -> Self {
        return Self{
            highest_connector_id: ConnectorId::new_invalid(),
            branch_annotations: Vec::new(),
            last_finished_handled: None,
            encountered_peers: VecSet::new(),
            local_solutions: Vec::new(),
            workspace_ports: Vec::new(),
        }
    }

    // --- Controlling sync round and branches

    /// Returns whether the consensus algorithm is running in sync mode
    pub fn is_in_sync(&self) -> bool {
        // A non-empty annotation list implies `start_sync` ran and `end_sync`
        // has not yet cleared the state.
        return !self.branch_annotations.is_empty();
    }

    /// Returns the annotation for a given port in a given branch.
    ///
    /// Panics if the branch does not own the port.
    /// TODO: Remove this once multi-fire is in place
    pub fn get_annotation(&self, branch_id: BranchId, port_id: PortIdLocal) -> &PortAnnotation {
        let branch = &self.branch_annotations[branch_id.index as usize];
        let port = branch.port_mapping.iter().find(|v| v.port_id == port_id).unwrap();
        return port;
    }

    /// Sets up the consensus algorithm for a new synchronous round. The
    /// provided ports should be the ports the component owns at the start of
    /// the sync round.
    pub fn start_sync(&mut self, ports: &[Port]) {
        debug_assert!(!self.highest_connector_id.is_valid());
        debug_assert!(self.branch_annotations.is_empty());
        debug_assert!(self.encountered_peers.is_empty());

        // We'll use the first "branch" (the non-sync one) to store our ports,
        // this allows cloning if we created a new branch.
        self.branch_annotations.push(BranchAnnotation{
            port_mapping: ports.iter()
                .map(|v| PortAnnotation{
                    port_id: v.self_id,
                    registered_id: None,
                    expected_firing: None,
                })
                .collect(),
        });
    }

    /// Notifies the consensus algorithm that a new branch has appeared. Must be
    /// called for each forked branch in the execution tree.
pub fn notify_of_new_branch(&mut self, parent_branch_id: BranchId, new_branch_id: BranchId) { // If called correctly. Then each time we are notified the new branch's // index is the length in `branch_annotations`. debug_assert!(self.branch_annotations.len() == new_branch_id.index as usize); let parent_branch_annotations = &self.branch_annotations[parent_branch_id.index as usize]; let new_branch_annotations = BranchAnnotation{ port_mapping: parent_branch_annotations.port_mapping.clone(), }; self.branch_annotations.push(new_branch_annotations); } /// Notifies the consensus algorithm that a branch has reached the end of /// the sync block. A final check for consistency will be performed that the /// caller has to handle. Note that pub fn notify_of_finished_branch(&self, branch_id: BranchId) -> Consistency { debug_assert!(self.is_in_sync()); let branch = &self.branch_annotations[branch_id.index as usize]; for mapping in &branch.port_mapping { match mapping.expected_firing { Some(expected) => { if expected != mapping.registered_id.is_some() { // Inconsistent speculative state and actual state debug_assert!(mapping.registered_id.is_none()); // because if we did fire on a silent port, we should've caught that earlier return Consistency::Inconsistent; } }, None => {}, } } return Consistency::Valid; } /// Notifies the consensus algorithm that a particular branch has assumed /// a speculative value for its port mapping. 
pub fn notify_of_speculative_mapping(&mut self, branch_id: BranchId, port_id: PortIdLocal, does_fire: bool) -> Consistency { debug_assert!(self.is_in_sync()); let branch = &mut self.branch_annotations[branch_id.index as usize]; for mapping in &mut branch.port_mapping { if mapping.port_id == port_id { match mapping.expected_firing { None => { // Not yet mapped, perform speculative mapping mapping.expected_firing = Some(does_fire); return Consistency::Valid; }, Some(current) => { // Already mapped if current == does_fire { return Consistency::Valid; } else { return Consistency::Inconsistent; } } } } } unreachable!("notify_of_speculative_mapping called with unowned port"); } /// Generates sync messages for any branches that are at the end of the /// sync block. To find these branches, they should've been put in the /// "finished" queue in the execution tree. pub fn handle_new_finished_sync_branches(&mut self, tree: &ExecTree, ctx: &mut ComponentCtxFancy) { debug_assert!(self.is_in_sync()); let mut last_branch_id = self.last_finished_handled; for branch in tree.iter_queue(QueueKind::FinishedSync, last_branch_id) { // Turn the port mapping into a local solution let source_mapping = &self.branch_annotations[branch.id.index as usize].port_mapping; let mut target_mapping = Vec::with_capacity(source_mapping.len()); for port in source_mapping { target_mapping.push(( port.port_id, port.registered_id.unwrap_or(BranchId::new_invalid()) )); } let local_solution = LocalSolution{ component: ctx.id, final_branch_id: branch.id, port_mapping: target_mapping, }; last_branch_id = Some(branch.id); } self.last_finished_handled = last_branch_id; } pub fn end_sync(&mut self, branch_id: BranchId, final_ports: &mut Vec) { debug_assert!(self.is_in_sync()); // TODO: Handle sending and receiving ports final_ports.clear(); let branch = &self.branch_annotations[branch_id.index as usize]; for port in &branch.port_mapping { final_ports.push(port.port_id); } } // --- Handling messages /// Prepares a 
message for sending. Caller should have made sure that /// sending the message is consistent with the speculative state. pub fn handle_message_to_send(&mut self, branch_id: BranchId, source_port_id: PortIdLocal, content: &ValueGroup, ctx: &mut ComponentCtxFancy) -> (SyncHeader, DataHeader) { debug_assert!(self.is_in_sync()); let branch = &mut self.branch_annotations[branch_id.index as usize]; if cfg!(debug_assertions) { let port = branch.port_mapping.iter() .find(|v| v.port_id == source_port_id) .unwrap(); debug_assert!(port.expected_firing == None || port.expected_firing == Some(true)); } // Check for ports that are begin sent debug_assert!(self.workspace_ports.is_empty()); find_ports_in_value_group(content, &mut self.workspace_ports); if !self.workspace_ports.is_empty() { todo!("handle sending ports"); self.workspace_ports.clear(); } // TODO: Handle multiple firings. Right now we just assign the current // branch to the `None` value because we know we can only send once. debug_assert!(branch.port_mapping.iter().find(|v| v.port_id == source_port_id).unwrap().registered_id.is_none()); let sync_header = self.create_sync_header(ctx); let port_info = ctx.get_port_by_id(source_port_id).unwrap(); let data_header = DataHeader{ expected_mapping: branch.port_mapping.clone(), sending_port: port_info.peer_id, target_port: port_info.peer_id, new_mapping: branch_id }; for mapping in &mut branch.port_mapping { if mapping.port_id == source_port_id { mapping.expected_firing = Some(true); mapping.registered_id = Some(branch_id); } } return (sync_header, data_header); } pub fn handle_received_sync_header(&mut self, sync_header: &SyncHeader, ctx: &mut ComponentCtxFancy) { debug_assert!(sync_header.sending_component_id != ctx.id); // not sending to ourselves self.encountered_peers.push(sync_header.sending_component_id); if sync_header.highest_component_id > self.highest_connector_id { // Sender has higher component ID. So should be the target of our // messages. 
We should also let all of our peers know self.highest_connector_id = sync_header.highest_component_id; for encountered_id in self.encountered_peers.iter() { if encountered_id == sync_header.sending_component_id { // Don't need to send it to this one continue } let message = SyncMessageFancy{ sync_header: self.create_sync_header(ctx), target_component_id: encountered_id, content: SyncContent::Notification, }; ctx.submit_message(MessageFancy::Sync(message)); } // But also send our locally combined solution self.forward_local_solutions(ctx); } else if sync_header.highest_component_id < self.highest_connector_id { // Sender has lower leader ID, so it should know about our higher // one. let message = SyncMessageFancy{ sync_header: self.create_sync_header(ctx), target_component_id: sync_header.sending_component_id, content: SyncContent::Notification }; ctx.submit_message(MessageFancy::Sync(message)); } // else: exactly equal, so do nothing } /// Checks data header and consults the stored port mapping and the /// execution tree to see which branches may receive the data message's /// contents. /// /// This function is generally called for freshly received messages that /// should be matched against previously halted branches. 
/// TODO: Rename, name confused me after a day pub fn handle_received_data_header(&mut self, exec_tree: &ExecTree, data_header: &DataHeader, target_ids: &mut Vec) { for branch in exec_tree.iter_queue(QueueKind::AwaitingMessage, None) { if branch.awaiting_port == data_header.target_port { // Found a branch awaiting the message, but we need to make sure // the mapping is correct if self.branch_can_receive(branch.id, data_header) { target_ids.push(branch.id); } } } } pub fn notify_of_received_message(&mut self, branch_id: BranchId, data_header: &DataHeader, content: &ValueGroup) { debug_assert!(self.branch_can_receive(branch_id, data_header)); let branch = &mut self.branch_annotations[branch_id.index as usize]; for mapping in &mut branch.port_mapping { if mapping.port_id == data_header.target_port { // Found the port in which the message should be inserted mapping.registered_id = Some(data_header.new_mapping); // Check for sent ports debug_assert!(self.workspace_ports.is_empty()); find_ports_in_value_group(content, &mut self.workspace_ports); if !self.workspace_ports.is_empty() { todo!("handle received ports"); self.workspace_ports.clear(); } return; } } // If here, then the branch didn't actually own the port? Means the // caller made a mistake unreachable!("incorrect notify_of_received_message"); } /// Matches the mapping between the branch and the data message. If they /// match then the branch can receive the message. 
pub fn branch_can_receive(&self, branch_id: BranchId, data_header: &DataHeader) -> bool { let annotation = &self.branch_annotations[branch_id.index as usize]; for expected in &data_header.expected_mapping { // If we own the port, then we have an entry in the // annotation, check if the current mapping matches for current in &annotation.port_mapping { if expected.port_id == current.port_id { if expected.registered_id != current.registered_id { // IDs do not match, we cannot receive the // message in this branch return false; } } } } return true; } // --- Internal helpers fn send_or_store_local_solution(&mut self, solution: LocalSolution, ctx: &mut ComponentCtxFancy) { if self.highest_connector_id == ctx.id { // We are the leader self.store_local_solution(solution, ctx); } else { // Someone else is the leader let message = SyncMessageFancy{ sync_header: self.create_sync_header(ctx), target_component_id: self.highest_connector_id, content: SyncContent::LocalSolution(solution), }; ctx.submit_message(MessageFancy::Sync(message)); } } /// Stores the local solution internally. This assumes that we are the /// leader. 
fn store_local_solution(&mut self, solution: LocalSolution, _ctx: &ComponentCtxFancy) { debug_assert_eq!(self.highest_connector_id, _ctx.id); self.local_solutions.push(solution); } #[inline] fn create_sync_header(&self, ctx: &ComponentCtxFancy) -> SyncHeader { return SyncHeader{ sending_component_id: ctx.id, highest_component_id: self.highest_connector_id, } } fn forward_local_solutions(&mut self, ctx: &mut ComponentCtxFancy) { debug_assert_ne!(self.highest_connector_id, ctx.id); if !self.local_solutions.is_empty() { for local_solution in self.local_solutions.drain() { let message = SyncMessageFancy{ sync_header: self.create_sync_header(ctx), target_component_id: self.highest_connector_id, content: SyncContent::LocalSolution(local_solution), }; ctx.submit_message(MessageFancy::Sync(message)); } } } } // ----------------------------------------------------------------------------- // Solution storage and algorithms // ----------------------------------------------------------------------------- struct MatchedLocalSolution { final_branch_id: BranchId, port_mapping: Vec<(PortIdLocal, BranchId)>, matches: Vec, } struct ComponentMatches { target_id: ConnectorId, target_index: usize, match_indices: Vec, involved_ports: Vec, } struct ComponentLocalSolutions { component: ConnectorId, solutions: Vec, } // TODO: Flatten? Flatten. Flatten everything. 
pub(crate) struct GlobalSolution { local: Vec } impl GlobalSolution { fn new() -> Self { return Self{ local: Vec::new(), }; } fn add_solution(&mut self, solution: LocalSolution) { let component_id = solution.component; let solution = MatchedLocalSolution{ final_branch_id: solution.final_branch_id, port_mapping: solution.port_mapping, matches: Vec::new(), }; // Create an entry for the solution for the particular component let component_exists = self.local.iter_mut() .enumerate() .find(|(_, v)| v.component == component_id); let (component_index, solution_index) = match component_exists { Some((component_index, storage)) => { // Entry for component exists, so add to solutions let solution_index = storage.solutions.len(); storage.solutions.push(solution); (component_index, solution_index) } None => { // Entry for component does not exist yet let component_index = self.local.len(); self.local.push(ComponentLocalSolutions{ component: component_id, solutions: vec![solution], }); (component_index, 0) } }; // Compare this new solution to other solutions of different components // to see if we get a closed global solution. } } // ----------------------------------------------------------------------------- // Generic Helpers // ----------------------------------------------------------------------------- /// Recursively goes through the value group, attempting to find ports. /// Duplicates will only be added once. pub(crate) fn find_ports_in_value_group(value_group: &ValueGroup, ports: &mut Vec) { // Helper to check a value for a port and recurse if needed. 
use crate::protocol::eval::Value; fn find_port_in_value(group: &ValueGroup, value: &Value, ports: &mut Vec) { match value { Value::Input(port_id) | Value::Output(port_id) => { // This is an actual port let cur_port = PortIdLocal::new(port_id.0.u32_suffix); for prev_port in ports.iter() { if *prev_port == cur_port { // Already added return; } } ports.push(cur_port); }, Value::Array(heap_pos) | Value::Message(heap_pos) | Value::String(heap_pos) | Value::Struct(heap_pos) | Value::Union(_, heap_pos) => { // Reference to some dynamic thing which might contain ports, // so recurse let heap_region = &group.regions[*heap_pos as usize]; for embedded_value in heap_region { find_port_in_value(group, embedded_value, ports); } }, _ => {}, // values we don't care about } } // Clear the ports, then scan all the available values ports.clear(); for value in &value_group.values { find_port_in_value(value_group, value, ports); } }