diff --git a/src/runtime2/component/consensus.rs b/src/runtime2/component/consensus.rs index d965aef2993d1dbe8fca8f539491ff7f8f72ceba..d989a63febc246bd6ed3d643c8cd6fa08b15d9d3 100644 --- a/src/runtime2/component/consensus.rs +++ b/src/runtime2/component/consensus.rs @@ -1,5 +1,8 @@ use crate::protocol::eval::ValueGroup; +use crate::runtime2::scheduler::*; +use crate::runtime2::runtime::*; use crate::runtime2::communication::*; +use crate::runtime2::component::wake_up_if_sleeping; use super::component_pdl::*; @@ -14,42 +17,383 @@ impl PortAnnotation { } } +#[derive(Eq, PartialEq)] +enum Mode { + NonSync, + SyncBusy, + SyncAwaitingSolution, +} + +struct SolutionCombiner { + solution: SyncPartialSolution, + all_present: bool, // set if the `submissions_by` only contains (_, true) entries. +} + +impl SolutionCombiner { + fn new() -> Self { + return Self { + solution: SyncPartialSolution{ + submissions_by: Vec::new(), + channel_mapping: Vec::new(), + decision: RoundDecision::None, + }, + all_present: false, + } + } + + /// Returns a decision for the current round. If there is no decision (yet) + /// then `RoundDecision::None` is returned. + fn get_decision(&self) -> RoundDecision { + if self.all_present { + debug_assert_ne!(self.solution.decision, RoundDecision::None); + return self.solution.decision; + } + + return RoundDecision::None; // even if failure: wait for everyone. + } + + fn combine_with_partial_solution(&mut self, partial: SyncPartialSolution) { + // Combine the submission tracking + for (comp_id, present) in partial.submissions_by { + self.mark_single_component_submission(comp_id, present); + } + + debug_assert_ne!(self.solution.decision, RoundDecision::Solution); + debug_assert_ne!(partial.decision, RoundDecision::Solution); + + // Combine our partial solution with the provided partial solution. + // This algorithm *could* allow overlap in the partial solutions, but + // in practice this means something is going wrong (a component stored + // a local solution *and* transmitted it to the leader, then later + // submitted its partial solution), hence we will do some debug asserts + // for now. + for new_entry in partial.channel_mapping { + let channel_index = if new_entry.putter.is_some() && new_entry.getter.is_some() { + // Channel is completely specified + debug_assert!( + self.find_channel_index_for_partial_entry(new_entry.putter.as_ref().unwrap()).is_none() && + self.find_channel_index_for_partial_entry(new_entry.getter.as_ref().unwrap()).is_none() + ); + let channel_index = self.solution.channel_mapping.len(); + self.solution.channel_mapping.push(new_entry); + + channel_index + } else if let Some(new_port) = new_entry.putter { + // Only putter is present in new entry + match self.find_channel_index_for_partial_entry(&new_port) { + Some(channel_index) => { + let entry = &mut self.solution.channel_mapping[channel_index]; + debug_assert!(entry.putter.is_none()); + entry.putter = Some(new_port); + + channel_index + }, + None => { + let channel_index = self.solution.channel_mapping.len(); + self.solution.channel_mapping.push(SyncSolutionChannel{ + putter: Some(new_port), + getter: None, + }); + + channel_index + } + } + } else if let Some(new_port) = new_entry.getter { + // Only getter is present in new entry + match self.find_channel_index_for_partial_entry(&new_port) { + Some(channel_index) => { + let entry = &mut self.solution.channel_mapping[channel_index]; + debug_assert!(entry.getter.is_none()); + entry.getter = Some(new_port); + + channel_index + }, + None => { + let channel_index = self.solution.channel_mapping.len(); + self.solution.channel_mapping.push(SyncSolutionChannel{ + putter: None, + getter: Some(new_port) + }); + + channel_index + } + } + } else { + unreachable!() + }; + + // Make sure the new entry is consistent + let channel = &self.solution.channel_mapping[channel_index]; + if !Self::channel_is_consistent(channel) { + self.solution.decision = RoundDecision::Failure; + } + } + + // Check to see if we have a global solution already + self.update_all_present(); + if self.all_present && self.solution.decision != RoundDecision::Failure { + debug_assert_eq!(self.solution.decision, RoundDecision::None); + dbg_code!(for entry in &self.solution.channel_mapping { + debug_assert!(entry.putter.is_some() && entry.getter.is_some()); + }); + self.solution.decision = RoundDecision::Solution; + } + } + + /// Combines the currently stored global solution (if any) with the newly + /// provided local solution. Make sure to check the `has_decision` return + /// value afterwards. + fn combine_with_local_solution(&mut self, comp_id: CompId, solution: SyncLocalSolution) { + // Mark the contributions of the component and detect components whose + // submissions we do not yet have + self.mark_single_component_submission(comp_id, true); + for entry in solution.iter() { + self.mark_single_component_submission(entry.peer_comp_id, false); + } + + debug_assert_ne!(self.solution.decision, RoundDecision::Solution); + + // Go through all entries and check if the submitted local solution is + // consistent with our partial solution + let mut had_new_entry = false; + for entry in solution.iter() { + let preexisting_index = self.find_channel_index_for_local_entry(comp_id, entry); + let new_port = SolutionPort{ + self_comp_id: comp_id, + self_port_id: entry.self_port_id, + peer_comp_id: entry.peer_comp_id, + peer_port_id: entry.peer_port_id, + mapping: entry.mapping, + }; + + match preexisting_index { + Some(entry_index) => { + // Add the local solution's entry to the existing entry in + // the global solution. We'll handle any mismatches along + // the way. + let channel = &mut self.solution.channel_mapping[entry_index]; + if entry.is_putter { + // Getter should be present in existing entry + debug_assert!(channel.getter.is_some() && channel.putter.is_none()); + channel.putter = Some(new_port); + } else { + // Putter should be present in existing entry + debug_assert!(channel.putter.is_some() && channel.getter.is_none()); + channel.getter = Some(new_port); + }; + + if !Self::channel_is_consistent(channel) { + self.solution.decision = RoundDecision::Failure; + } + }, + None => { + // No entry yet. So add it + let new_solution = if entry.is_putter { + SolutionChannel{ putter: Some(new_port), getter: None } + } else { + SolutionChannel{ putter: None, getter: Some(new_port) } + }; + self.solution.channel_mapping.push(new_solution); + had_new_entry = true; + } + } + } + + if !had_new_entry { + self.update_all_present(); + if self.all_present && self.solution.decision != RoundDecision::Failure { + // No new entries and every component is present. This implies that + // every component successfully added their local solutions to the + // global solution. Hence: we have a global solution + debug_assert_eq!(self.solution.decision, RoundDecision::None); + dbg_code!(for entry in &self.solution.channel_mapping { + debug_assert!(entry.putter.is_some() && entry.getter.is_some()); + }); + self.solution.decision = RoundDecision::Solution; + } + } + } + + fn mark_single_component_submission(&mut self, comp_id: CompId, will_contribute: bool) { + debug_assert!(!will_contribute || !self.solution.submissions_by.iter().any(|(id, val)| *id == comp_id && *val)); // if submitting a solution, then we do not expect an existing entry + for (entry, has_contributed) in self.solution.submissions_by.iter_mut() { + if *entry == comp_id { + *has_contributed = *has_contributed || will_contribute; + return; + } + } + + self.solution.submissions_by.push((comp_id, will_contribute)); + } + + fn update_all_present(&mut self) { + debug_assert!(!self.all_present); // upheld by caller + for (_, present) in self.solution.submissions_by.iter() { + if !*present { + return; + } + } + + self.all_present = true; + } + + /// Given the partial solution entry of a channel's port, check if there is + /// an entry for the other port. If there is we return its index, and we + /// return `None` otherwise. + fn find_channel_index_for_partial_entry(&self, new_entry: &SyncSolutionPort) -> Option { + fn might_belong_to_same_channel(cur_entry: &SyncSolutionPort, new_entry: &SyncSolutionPort) -> bool { + ( + cur_entry.peer_comp_id == new_entry.self_comp_id && + cur_entry.peer_port_id == new_entry.self_port_id + ) || ( + cur_entry.self_comp_id == new_entry.peer_comp_id && + cur_entry.self_port_id == new_entry.peer_port_id + ) + } + + for (entry_index, cur_entry) in self.solution.channel_mapping.iter().enumerate() { + if new_entry.port_kind == PortKind::Putter { + if let Some(cur_entry) = &cur_entry.getter { + if might_belong_to_same_channel(cur_entry, new_entry) { + return Some(entry_index); + } + } + } else { + if let Some(cur_entry) = &cur_entry.putter { + if might_belong_to_same_channel(cur_entry, new_entry) { + return Some(entry_index); + } + } + } + } + + return None; + } + + /// Given the local solution entry for one end of a channel, check if there + /// is an entry for the other end of the channel such that they can be + /// paired up. + fn find_channel_index_for_local_entry(&self, comp_id: CompId, new_entry: &SyncLocalSolutionEntry) -> Option { + fn might_belong_to_same_channel(cur_entry: &SyncSolutionPort, new_comp_id: CompId, new_entry: &SyncLocalSolutionEntry) -> bool { + ( + new_entry.peer_comp_id == cur_entry.self_comp_id && + new_entry.peer_port_id == cur_entry.self_port_id + ) || ( + new_comp_id == cur_entry.peer_comp_id && + new_entry.self_port_id == cur_entry.peer_port_id + ) + } + + for (entry_index, cur_entry) in self.solution.channel_mapping.iter().enumerate() { + // Note that the check that determines whether two ports belong to + // the same channel is one-sided. That is: port A may decide that + // port B is part of its channel, but port B may consider port A not + // to be part of its channel. Before merging the entries (outside of + // this function) we'll make sure this is not the case. + if new_entry.is_putter { + // Expect getter to be present + if let Some(cur_entry) = &cur_entry.getter { + if might_belong_to_same_channel(cur_entry, comp_id, new_entry) { + return Some(entry_index); + } + } + } else { + if let Some(cur_entry) = &cur_entry.putter { + if might_belong_to_same_channel(cur_entry, comp_id, new_entry) { + return Some(entry_index); + } + } + } + } + + return None; + } + + // Makes sure that two ports agree that they are each other's peers + fn ports_belong_to_same_channel(a: &SyncSolutionPort, b: &SyncSolutionPort) -> bool { + return + a.self_comp_id == b.peer_comp_id && a.self_port_id == b.peer_port_id && + a.peer_comp_id == b.self_comp_id && a.peer_port_id == b.self_port_id + } + + // Makes sure channel is consistently mapped (or not yet fully specified) + fn channel_is_consistent(channel: &SyncSolutionChannel) -> bool { + debug_assert!(channel.putter.is_some() || channel.getter.is_some()); + if channel.putter.is_none() || channel.getter.is_none() { + // Not yet fully specified + return false; + } + + let putter = channel.putter.as_ref().unwrap(); + let getter = channel.getter.as_ref().unwrap(); + return + Self::ports_belong_to_same_channel(putter, getter) && + putter.mapping == getter.mapping; + } +} + /// Tracking consensus state pub struct Consensus { - round: u32, + // General state of consensus manager mapping_counter: u32, + mode: Mode, + // State associated with sync round + round_index: u32, + highest_id: CompId, ports: Vec, + // State associated with arriving at a solution and being a (temporary) + // leader in the consensus round + solution: SolutionCombiner, } impl Consensus { pub(crate) fn new() -> Self { return Self{ - round: 0, - mapping_counter: 0, + round_index: 0, + highest_id: CompId::new_invalid(), ports: Vec::new(), + mapping_counter: 0, + mode: Mode::NonSync, + solution: SolutionCombiner::new(), } } + // ------------------------------------------------------------------------- + // Managing sync state + // ------------------------------------------------------------------------- + pub(crate) fn notify_sync_start(&mut self, comp_ctx: &CompCtx) { - // Make sure we locally still have all of the same ports - self.transfer_ports(comp_ctx); + debug_assert_eq!(self.mode, Mode::NonSync); + self.highest_id = comp_ctx.id; self.mapping_counter = 0; + self.mode = Mode::SyncBusy; + self.make_ports_consistent_with_ctx(comp_ctx); } - pub(crate) fn annotate_message_data(&mut self, port_info: &Port, content: ValueGroup) -> DataMessage { - debug_assert!(self.ports.iter().any(|v| v.id == port_info.self_id)); - let data_header = self.create_data_header(port_info); - let sync_header = self.create_sync_header(); + pub(crate) fn notify_sync_end(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &CompCtx) -> RoundDecision { + debug_assert_eq!(self.mode, Mode::SyncBusy); + self.mode = Mode::SyncAwaitingSolution; - return DataMessage{ data_header, sync_header, content }; - } + // Submit our port mapping as a solution + let mut local_solution = Vec::with_capacity(self.ports.len()); + for port in &self.ports { + if let Some(mapping) = port.mapping { + let port_info = comp_ctx.get_port(port.id); + local_solution.push(SyncLocalSolutionEntry { + self_port_id: port.id, + peer_comp_id: port_info.peer_comp_id, + peer_port_id: port_info.peer_id, + mapping, + port_kind: port_info.kind, + }); + } + } - pub(crate) fn notify_sync_end(&mut self) { - self.round = self.round.wrapping_add(1); - todo!("implement sync end") + let decision = self.handle_local_solution(sched_ctx, comp_ctx, comp_ctx.id, local_solution); + return decision; } - pub(crate) fn transfer_ports(&mut self, comp_ctx: &CompCtx) { + fn make_ports_consistent_with_ctx(&mut self, comp_ctx: &CompCtx) { let mut needs_setting_ports = false; if comp_ctx.ports.len() != self.ports.len() { needs_setting_ports = true; @@ -73,32 +417,238 @@ impl Consensus { } } - fn create_data_header(&mut self, port_info: &Port) -> MessageDataHeader { + // ------------------------------------------------------------------------- + // Handling inbound and outbound messages + // ------------------------------------------------------------------------- + + pub(crate) fn annotate_data_message(&mut self, comp_ctx: &CompCtx, port_info: &Port, content: ValueGroup) -> DataMessage { + debug_assert_eq!(self.mode, Mode::SyncBusy); // can only send between sync start and sync end + debug_assert!(self.round.ports.iter().any(|v| v.id == port_info.self_id)); + let data_header = self.create_data_header_and_update_mapping(port_info); + let sync_header = self.create_sync_header(comp_ctx); + + return DataMessage{ data_header, sync_header, content }; + } + + /// Checks if the data message can be received (due to port annotations), if + /// it can then `true` is returned and the caller is responsible for handing + /// the message of to the PDL code. Otherwise the message cannot be + /// received. + pub(crate) fn try_receive_data_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, message: &DataMessage) -> bool { + debug_assert_eq!(self.mode, Mode::SyncBusy); + debug_assert!(self.round.ports.iter().any(|v| v.id == message.data_header.target_port)); + + // Make sure the expected mapping matches the currently stored mapping + for (expected_id, expected_annotation) in &message.data_header.expected_mapping { + let got_annotation = self.get_annotation(*expected_id); + if got_annotation != expected_annotation { + return false; + } + } + + // Expected mapping matches current mapping, so we will receive the message + self.set_annotation(message.data_header.target_port, message.data_header.new_mapping); + + // Handle the sync header embedded within the data message + self.handle_sync_header(sched_ctx, comp_ctx, &message.sync_header); + + return true; + } + + /// Receives the sync message and updates the consensus state appropriately. + pub(crate) fn receive_sync_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, message: SyncMessage) -> RoundDecision { + // Whatever happens: handle the sync header (possibly changing the + // currently registered leader) + self.handle_sync_header(sched_ctx, comp_ctx, &message.sync_header); + + match message.content { + SyncMessageContent::NotificationOfLeader => { + return RoundDecision::None; + }, + SyncMessageContent::LocalSolution(solution_generator_id, local_solution) => { + return self.handle_local_solution(sched_ctx, comp_ctx, solution_generator_id, local_solution); + }, + SyncMessageContent::PartialSolution(partial_solution) => { + return self.handle_partial_solution(sched_ctx, comp_ctx, partial_solution); + } + SyncMessageContent::GlobalSolution => { + // Global solution has been found + debug_assert_eq!(self.mode, Mode::SyncAwaitingSolution); // leader can only find global- if we submitted local solution + todo!("clear port mapping or something"); + return RoundDecision::Solution; + }, + } + } + + fn handle_sync_header(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, header: &MessageSyncHeader) { + if header.highest_id > self.round.highest_id { + // Sender knows of someone with a higher ID. So store highest ID, + // notify all peers, and forward local solutions + self.round.highest_id = header.highest_id; + for peer in &comp_ctx.peers { + if peer.id == header.sending_id { + continue; + } + + let message = SyncMessage{ + sync_header: self.create_sync_header(comp_ctx), + content: SyncMessageContent::NotificationOfLeader, + }; + peer.handle.inbox.push(Message::Sync(message)); + wake_up_if_sleeping(sched_ctx, peer.id, &peer.handle); + } + + self.forward_local_solutions(sched_ctx, comp_ctx); + } else if header.highest_id < self.round.highest_id { + // Sender has a lower ID, so notify it of our higher one + let message = SyncMessage{ + sync_header: self.create_sync_header(comp_ctx), + content: SyncMessageContent::NotificationOfLeader, + }; + let peer_info = comp_ctx.get_peer(header.sending_id); + peer_info.handle.inbox.push(Message::Sync(message)); + wake_up_if_sleeping(sched_ctx, peer_info.id, &peer_info.handle); + } // else: exactly equal + } + + fn get_annotation(&self, port_id: PortId) -> Option { + for annotation in self.ports.iter() { + if annotation.id == port_id { + return annotation.mapping; + } + } + + debug_assert!(false); + return None; + } + + fn set_annotation(&mut self, port_id: PortId, mapping: u32) { + for annotation in self.ports.iter_mut() { + if annotation.id == port_id { + annotation.mapping = Some(mapping); + } + } + } + + // ------------------------------------------------------------------------- + // Leader-related methods + // ------------------------------------------------------------------------- + + fn forward_local_solutions(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx) { + todo!("implement") + } + + fn handle_local_solution(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &CompCtx, solution_sender_id: CompId, solution: SyncLocalSolution) -> RoundDecision { + if self.highest_id == comp_ctx.id { + // We are the leader + self.solution.combine_with_local_solution(solution_sender_id, solution); + let round_decision = self.solution.get_decision(); + let decision_is_solution = match round_decision { + RoundDecision::None => { + // No solution yet + return RoundDecision::None; + }, + RoundDecision::Solution => true, + RoundDecision::Failure => false, + }; + + // If here then we've reached a decision, broadcast it + for (peer_id, _is_present) in self.solution.solution.submissions_by.iter().copied() { + debug_assert!(_is_present); + if peer_id == comp_ctx.id { + // Do not send the result to ourselves + continue; + } + + let mut handle = sched_ctx.runtime.get_component_public(peer_id); + handle.inbox.push(Message::Sync(SyncMessage{ + sync_header: self.create_sync_header(comp_ctx), + content: if decision_is_solution { + SyncMessageContent::GlobalSolution + } else { + SyncMessageContent::GlobalFailure + }, + })); + wake_up_if_sleeping(sched_ctx, peer_id, &handle); + + let _should_remove = handle.decrement_users(); + debug_assert!(!_should_remove); + } + + return round_decision; + } else { + // Forward the solution + let message = SyncMessage{ + sync_header: self.create_sync_header(comp_ctx), + content: SyncMessageContent::LocalSolution(solution_sender_id, solution), + }; + self.send_to_leader(sched_ctx, comp_ctx, Message::Sync(message)); + return RoundDecision::None; + } + } + + fn handle_partial_solution(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, solution: SyncPartialSolution) -> RoundDecision { + if self.highest_id == comp_ctx.id { + // We are the leader, combine existing and new solution + self.solution.combine_with_partial_solution(solution); + let round_decision = self.solution.get_decision(); + + + return RoundDecision::None; + } else { + // Forward the partial solution + let message = SyncMessage{ + sync_header: self.create_sync_header(comp_ctx), + content: SyncMessageContent::PartialSolution(solution), + }; + self.send_to_leader(sched_ctx, comp_ctx, Message::Sync(message)); + return RoundDecision::None; + } + } + + fn send_to_leader(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &CompCtx, message: Message) { + debug_assert_ne!(self.highest_id, comp_ctx.id); // we're not the leader + let leader_info = sched_ctx.runtime.get_component_public(self.highest_id); + leader_info.inbox.push(message); + wake_up_if_sleeping(sched_ctx, self.highest_id, &leader_info); + } + + // ------------------------------------------------------------------------- + // Creating message headers + // ------------------------------------------------------------------------- + + fn create_data_header_and_update_mapping(&mut self, port_info: &Port) -> MessageDataHeader { let mut expected_mapping = Vec::with_capacity(self.ports.len()); - for port in &self.ports { - if let Some(mapping) = port.mapping { - expected_mapping.push((port.id, mapping)); + let mut port_index = usize::MAX; + for (index, port) in self.round.ports.iter().enumerate() { + if port.id == port_info.self_id { + port_index = index; } + expected_mapping.push((port.id, Some(mapping))); } + let new_mapping = self.take_mapping(); + self.round.ports[port_index].mapping = Some(new_mapping); debug_assert_eq!(port_info.kind, PortKind::Putter); return MessageDataHeader{ expected_mapping, - new_mapping: self.take_mapping(), + new_mapping, source_port: port_info.self_id, target_port: port_info.peer_id, }; } - fn create_sync_header(&self) -> MessageSyncHeader { + fn create_sync_header(&self, comp_ctx: &CompCtx) -> MessageSyncHeader { return MessageSyncHeader{ - sync_round: self.round, + sync_round: self.round.index, + sending_id: comp_ctx.id, + highest_id: self.highest_id, }; } fn take_mapping(&mut self) -> u32 { let mapping = self.mapping_counter; - self.mapping_counter += 1; + self.mapping_counter = self.mapping_counter.wrapping_add(1); return mapping; } } \ No newline at end of file