// Repository snapshot header (exported from a repository web viewer):
//   Files @ a2b6b8e94778
//   Location: CSY/reowolf/src/runtime2/consensus.rs
//   Revision a2b6b8e94778, 11.8 KiB (the viewer labels it
//   application/rls-services+xml, but the content is Rust source)
//   MH
//   Commit message: initial rewrite of component using new ExecTree and Consensus

use crate::protocol::eval::ValueGroup;
use crate::runtime2::branch::{BranchId, ExecTree, QueueKind};
use crate::runtime2::ConnectorId;
use crate::runtime2::inbox2::{DataHeader, SyncHeader};
use crate::runtime2::port::{Port, PortIdLocal};
use crate::runtime2::scheduler::ComponentCtxFancy;
use super::inbox2::PortAnnotation;

/// Per-branch bookkeeping kept by [`Consensus`]: for every port the component
/// owns, what the branch has speculatively assumed about it and which message
/// (if any) has been associated with it.
struct BranchAnnotation {
    /// One entry per owned port; copied from the parent when a branch forks.
    port_mapping: Vec<PortAnnotation>,
}

/// The consensus algorithm. Currently only implemented to find the component
/// with the highest ID within the sync region and letting it handle all the
/// local solutions.
///
/// The type itself serves as an experiment to see how code should be organized.
// TODO: Flatten all datastructures
// TODO: Have a "branch+port position hint" in case multiple operations are
//  performed on the same port to prevent repeated lookups
pub(crate) struct Consensus {
    highest_connector_id: ConnectorId,
    branch_annotations: Vec<BranchAnnotation>,
    workspace_ports: Vec<PortIdLocal>,
}

/// Result of checking a branch's actual behaviour against the speculative
/// port mapping it has assumed so far.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Consistency {
    /// The branch's behaviour agrees with its speculative mapping.
    Valid,
    /// The branch contradicts its speculative mapping; the caller must handle it.
    Inconsistent,
}

impl Consensus {
    /// Creates a consensus instance that is not participating in a sync round.
    pub fn new() -> Self {
        return Self {
            highest_connector_id: ConnectorId::new_invalid(),
            branch_annotations: Vec::new(),
            // Scratch buffer for scanning message contents; empty between calls.
            workspace_ports: Vec::new(),
        };
    }

    // --- Controlling sync round and branches

    /// Returns whether the consensus algorithm is running in sync mode, which
    /// is the case whenever at least one branch annotation exists.
    pub fn is_in_sync(&self) -> bool {
        return !self.branch_annotations.is_empty();
    }

    /// Sets up the consensus algorithm for a new synchronous round. The
    /// provided ports should be the ports the component owns at the start of
    /// the sync round.
    pub fn start_sync(&mut self, ports: &[Port]) {
        debug_assert!(self.branch_annotations.is_empty());
        debug_assert!(!self.highest_connector_id.is_valid());

        // The first "branch" (the non-sync one) stores our ports. New branches
        // then receive a copy of their parent's annotation when they fork.
        self.branch_annotations.push(BranchAnnotation{
            port_mapping: ports.iter()
                .map(|v| PortAnnotation{
                    port_id: v.self_id,
                    registered_id: None,
                    expected_firing: None,
                })
                .collect(),
        });
    }

    /// Notifies the consensus algorithm that a new branch has appeared. It must
    /// be called for each forked branch in the execution tree, in creation
    /// order. The new branch inherits a copy of its parent's port mapping.
    pub fn notify_of_new_branch(&mut self, parent_branch_id: BranchId, new_branch_id: BranchId) {
        // When called correctly, the new branch's index equals the current
        // length of `branch_annotations` every time we are notified.
        debug_assert!(self.branch_annotations.len() == new_branch_id.index as usize);
        let parent_branch_annotations = &self.branch_annotations[parent_branch_id.index as usize];
        let new_branch_annotations = BranchAnnotation{
            port_mapping: parent_branch_annotations.port_mapping.clone(),
        };
        self.branch_annotations.push(new_branch_annotations);
    }

    /// Notifies the consensus algorithm that a branch has reached the end of
    /// the sync block. This performs a final consistency check: for every port
    /// with a speculative firing assumption, the branch must actually have
    /// fired (sent or received) exactly when it assumed it would. The caller
    /// has to handle an `Inconsistent` result.
    pub fn notify_of_finished_branch(&self, branch_id: BranchId) -> Consistency {
        let branch = &self.branch_annotations[branch_id.index as usize];
        for mapping in &branch.port_mapping {
            if let Some(expected_firing) = mapping.expected_firing {
                let did_fire = mapping.registered_id.is_some();
                if expected_firing != did_fire {
                    // Firing on a port assumed silent is rejected when the
                    // message is sent or received. The only mismatch left is
                    // an expected firing that never happened.
                    debug_assert!(!did_fire);
                    return Consistency::Inconsistent;
                }
            }
        }

        return Consistency::Valid;
    }

    /// Notifies the consensus algorithm that a particular branch has assumed
    /// a speculative value for its port mapping.
    ///
    /// Returns `Inconsistent` if the branch had already assumed the opposite
    /// for this port.
    ///
    /// # Panics
    /// Panics if the branch does not own `port_id`.
    pub fn notify_of_speculative_mapping(&mut self, branch_id: BranchId, port_id: PortIdLocal, does_fire: bool) -> Consistency {
        let branch = &mut self.branch_annotations[branch_id.index as usize];
        for mapping in &mut branch.port_mapping {
            if mapping.port_id == port_id {
                return match mapping.expected_firing {
                    None => {
                        // Not yet mapped, perform speculative mapping
                        mapping.expected_firing = Some(does_fire);
                        Consistency::Valid
                    },
                    Some(current) if current == does_fire => Consistency::Valid,
                    Some(_) => Consistency::Inconsistent,
                };
            }
        }

        unreachable!("notify_of_speculative_mapping called with unowned port");
    }

    /// Ends the sync round for the given branch. Not yet implemented.
    // NOTE(review): presumably fills `final_ports` with the ports owned at
    // the end of the round and resets the sync state — confirm when writing.
    pub fn end_sync(&mut self, branch_id: BranchId, final_ports: &mut Vec<PortIdLocal>) {
        todo!("write");
    }

    // --- Handling messages

    /// Prepares a message for sending. The caller should have made sure that
    /// sending the message is consistent with the speculative state.
    ///
    /// Returns the sync and data headers to attach to the message. The data
    /// header carries the branch's port mapping as it was *before* this send.
    /// Afterwards the source port is marked as firing, registered to
    /// `branch_id`.
    ///
    /// # Panics
    /// Panics if the branch or the component does not own `source_port_id`.
    pub fn handle_message_to_send(&mut self, branch_id: BranchId, source_port_id: PortIdLocal, content: &ValueGroup, ctx: &mut ComponentCtxFancy) -> (SyncHeader, DataHeader) {
        debug_assert!(self.is_in_sync());
        let branch = &mut self.branch_annotations[branch_id.index as usize];

        // Look the port up once. This also covers release builds, where a
        // missing port used to skip the mapping update silently.
        let port_index = branch.port_mapping.iter()
            .position(|v| v.port_id == source_port_id)
            .expect("handle_message_to_send called with unowned port");

        {
            let mapping = &branch.port_mapping[port_index];
            // Sending must agree with any speculative assumption.
            debug_assert!(matches!(mapping.expected_firing, None | Some(true)));
            // TODO: Handle multiple firings. Right now we just assign the
            //  current branch to the `None` value because we know we can only
            //  send once.
            debug_assert!(mapping.registered_id.is_none());
        }

        // Check for ports that are being sent
        debug_assert!(self.workspace_ports.is_empty());
        find_ports_in_value_group(content, &mut self.workspace_ports);
        if !self.workspace_ports.is_empty() {
            todo!("handle sending ports");
            self.workspace_ports.clear();
        }

        let sync_header = SyncHeader{
            sending_component_id: ctx.id,
            highest_component_id: self.highest_connector_id,
        };

        let port_info = ctx.get_port_by_id(source_port_id)
            .expect("handle_message_to_send: component does not own the source port");
        let data_header = DataHeader{
            expected_mapping: branch.port_mapping.clone(),
            target_port: port_info.peer_id,
            new_mapping: branch_id
        };

        let mapping = &mut branch.port_mapping[port_index];
        mapping.expected_firing = Some(true);
        mapping.registered_id = Some(branch_id);

        return (sync_header, data_header);
    }

    /// Handles the sync header of a received message. Not yet implemented.
    pub fn handle_received_sync_header(&mut self, sync_header: &SyncHeader, ctx: &mut ComponentCtxFancy) {
        todo!("should check IDs and maybe send sync messages");
    }

    /// Checks the data header against the stored port mappings and the
    /// execution tree to find which branches may receive the data message's
    /// contents. Matching branch IDs are appended to `target_ids`.
    ///
    /// This function is generally called for freshly received messages that
    /// should be matched against previously halted branches.
    pub fn handle_received_data_header(&mut self, exec_tree: &ExecTree, data_header: &DataHeader, target_ids: &mut Vec<BranchId>) {
        for branch in exec_tree.iter_queue(QueueKind::AwaitingMessage) {
            // The branch must await this port *and* have a compatible mapping.
            if branch.awaiting_port == data_header.target_port
                && self.branch_can_receive(branch.id, data_header)
            {
                target_ids.push(branch.id);
            }
        }
    }

    /// Records that `branch_id` has received the message described by
    /// `data_header`. The target port is marked as firing and registered to
    /// the sender's branch.
    ///
    /// # Panics
    /// Panics if the branch does not own the target port.
    pub fn notify_of_received_message(&mut self, branch_id: BranchId, data_header: &DataHeader, content: &ValueGroup) {
        debug_assert!(self.branch_can_receive(branch_id, data_header));
        let branch = &mut self.branch_annotations[branch_id.index as usize];
        for mapping in &mut branch.port_mapping {
            if mapping.port_id == data_header.target_port {
                // Receiving on a port assumed to be silent is a caller error.
                debug_assert!(matches!(mapping.expected_firing, None | Some(true)));
                // Mirror the send path: the port has now fired. Otherwise a
                // later "silent" speculation would be accepted here and only
                // detected (via debug assert) when the branch finishes.
                mapping.expected_firing = Some(true);
                mapping.registered_id = Some(data_header.new_mapping);

                // Check for sent ports
                debug_assert!(self.workspace_ports.is_empty());
                find_ports_in_value_group(content, &mut self.workspace_ports);
                if !self.workspace_ports.is_empty() {
                    todo!("handle received ports");
                    self.workspace_ports.clear();
                }

                return;
            }
        }

        // If we get here, the branch does not own the port, which means the
        // caller made a mistake.
        unreachable!("incorrect notify_of_received_message");
    }

    /// Compares the branch's port mapping with the mapping carried by the data
    /// message. For every port present in both, the registered IDs must be
    /// equal. Ports that only one side knows about are ignored. If all of them
    /// match, the branch can receive the message.
    pub(crate) fn branch_can_receive(&self, branch_id: BranchId, data_header: &DataHeader) -> bool {
        let annotation = &self.branch_annotations[branch_id.index as usize];
        return data_header.expected_mapping.iter().all(|expected| {
            annotation.port_mapping.iter()
                .filter(|current| current.port_id == expected.port_id)
                .all(|current| current.registered_id == expected.registered_id)
        });
    }
}

/// Collects every port referenced by `value_group` into `ports`.
///
/// `ports` is cleared first. Values that refer to one of the group's heap
/// regions (arrays, messages, strings, structs and unions) are searched
/// recursively. Each port is recorded at most once, in the order it is first
/// encountered.
pub(crate) fn find_ports_in_value_group(value_group: &ValueGroup, ports: &mut Vec<PortIdLocal>) {
    use crate::protocol::eval::Value;

    // Visits one value: records it if it is a port, and descends into its
    // heap region if it refers to one.
    fn visit(group: &ValueGroup, value: &Value, found: &mut Vec<PortIdLocal>) {
        match value {
            Value::Input(id) | Value::Output(id) => {
                let port = PortIdLocal::new(id.0.u32_suffix);
                if !found.contains(&port) {
                    found.push(port);
                }
            },
            Value::Array(region_idx)
            | Value::Message(region_idx)
            | Value::String(region_idx)
            | Value::Struct(region_idx)
            | Value::Union(_, region_idx) => {
                let region = &group.regions[*region_idx as usize];
                for nested in region {
                    visit(group, nested, found);
                }
            },
            // No other kind of value can contain a port.
            _ => {},
        }
    }

    ports.clear();
    for top_level in &value_group.values {
        visit(value_group, top_level, ports);
    }
}